From c379a89bcb26bc4838efe70fae04760106e8d081 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 30 Jan 2024 14:29:12 -0800 Subject: [PATCH 001/207] [MLAS AArch64] SQNBitGemm optimization (#19272) 1. Add support for packing 4-bit values 32 at a time for CompInt8. 32 4-bit values can fit into a single 128-bit NEON register. For CompInt8, this enables a more efficient path for block sizes greater than or equal to 32. CompFp32 seems to do better with handling 16 elements at a time, so this 32-value packing is not used there. Pack differently based on compute type. Adjust APIs to handle this. 2. Introduce template argument for whether to handle zero-point. This results in less code for the no zero-point (symmetric) case. However, there is a binary size increase due to the additional template instantiations. --- .../cpu/quantization/matmul_nbits.cc | 130 +++-- onnxruntime/core/mlas/inc/mlas_qnbit.h | 16 +- onnxruntime/core/mlas/lib/sqnbitgemm.cpp | 146 +++--- onnxruntime/core/mlas/lib/sqnbitgemm.h | 27 ++ .../core/mlas/lib/sqnbitgemm_kernel_neon.cpp | 445 ++++++++++++++---- .../test/mlas/bench/bench_sqnbitgemm.cpp | 19 +- .../test/mlas/unittest/test_sqnbitgemm.cpp | 7 +- 7 files changed, 558 insertions(+), 232 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 72948c74d7877..166f5c8f52f54 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -9,6 +9,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" + #ifdef ORT_NEURAL_SPEED #include "contrib_ops/cpu/quantization/neural_speed_gemm.h" #endif @@ -16,6 +17,39 @@ namespace onnxruntime { namespace contrib { +namespace { +int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level_attr) { + const auto accuracy_level = std::clamp(accuracy_level_attr, + static_cast(CompMostAccurate), + static_cast(CompLeastAccurate)); + +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. + return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + + // Find a supported accuracy level that is not less accurate than the one given. + // CompMostAccurate is always supported with the fallback implementation. + // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. 
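(As a side note on the packing layout this patch introduces: 32 4-bit values occupy 16 bytes, so a whole CompInt8 sub-block fits in a single 128-bit NEON load, and the low and high nibbles can then be separated with one mask and one shift. The sketch below is a minimal standalone illustration of that idea, assuming the low-nibble/high-nibble layout described in the packing comments later in this patch; the helper name is hypothetical and this is not code from the patch.)

```cpp
// Minimal standalone sketch (not part of this patch): unpacking one 32-value
// sub-block that was packed 32 4-bit values at a time.
#include <arm_neon.h>

#include <cstdint>

// Assumed layout (matches the packing comments in this patch): packed byte i
// holds value i in its low nibble and value i + 16 in its high nibble.
inline void UnpackSubBlk32(const uint8_t* packed16bytes, int8_t (&out)[32])
{
    const uint8x16_t low_mask = vdupq_n_u8(0x0F);
    const uint8x16_t bytes = vld1q_u8(packed16bytes);  // one 128-bit load covers the sub-block
    const int8x16_t lo = vreinterpretq_s8_u8(vandq_u8(bytes, low_mask));  // values 0..15
    const int8x16_t hi = vreinterpretq_s8_u8(vshrq_n_u8(bytes, 4));       // values 16..31
    vst1q_s8(out, lo);
    vst1q_s8(out + 16, hi);
}
```

In the patch itself, CompFp32 keeps the 16-value layout, and a separate HasZeroPoint template parameter lets the symmetric (no zero-point) case drop the zero-point subtraction at compile time.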
+ int64_t effective_accuracy_level = accuracy_level; + for (; effective_accuracy_level > CompMostAccurate; --effective_accuracy_level) { + const auto compute_type = static_cast(effective_accuracy_level); + if (MlasIsSQNBitGemmAvailable(nbits, block_size, compute_type)) { + break; + } + } + + return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) +} +} // namespace + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -24,7 +58,7 @@ class MatMulNBits final : public OpKernel { N_{narrow(info.GetAttr("N"))}, block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, - accuracy_level_{info.GetAttr("accuracy_level")} { + accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -58,17 +92,22 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + bool is_asym_{false}; bool all_constant_{false}; -#endif + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + if (!all_constant_) { return Status::OK(); } @@ -116,11 +155,17 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) if (input_idx == 1) { - packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_); - if (packed_b_size_ == 0) return Status::OK(); + const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } + packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); + if (packed_b_size_ == 0) { + return Status::OK(); + } auto qptr = tensor.DataRaw(); packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, qptr, packed_b_.get()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, qptr, packed_b_.get()); if (prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); prepacked_weights->buffer_sizes_.push_back(packed_b_size_); @@ -136,7 +181,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prepacked_buffers, int input_idx, /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; -#ifdef ORT_NEURAL_SPEED + +#if defined(ORT_NEURAL_SPEED) + // Pack three tensors into one buffer if (input_idx == 1) { used_shared_buffers = true; @@ -159,6 +206,7 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } #endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } @@ -167,8 +215,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); -#ifdef ORT_NEURAL_SPEED - if (packed_b_.get()) { + +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -234,37 +284,43 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const bool 
has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); - if (has_single_b_matrix && packed_b_) { - for (int64_t accuracy_level = accuracy_level_; - accuracy_level >= static_cast(CompMostAccurate); - --accuracy_level) { - const auto compute_type = static_cast(accuracy_level); - if (MlasIsSQNBitGemmAvailable(M, N, K, nbits_, block_size_, compute_type)) { - IAllocatorUniquePtr workspace{}; - if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, - nbits_, block_size_, compute_type); - workspace_size > 0) { - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); - workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); - } + if (has_single_b_matrix) { + const auto compute_type = static_cast(accuracy_level_); + + if (MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + IAllocatorUniquePtr workspace{}; + if (const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, batch_count, + nbits_, block_size_, compute_type); + workspace_size > 0) { + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); + } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].QuantBData = packed_b_.get(); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + const void* b_data = [&]() -> const void* { + if (packed_b_) { + return packed_b_.get(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); + const Tensor* b = ctx->Input(1); + return b->DataRaw(); + }(); + + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].QuantBData = b_data; + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; } + + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + + return Status::OK(); } } diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index 047011e70bd4d..32e9cc98106d5 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -37,9 +37,7 @@ typedef enum { CompMostAccurate = CompUndef, CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_COMPUTE_TYPE; - -using MLAS_SQNBIT_GEMM_COMPUTE_TYPE = MLAS_SQNBIT_COMPUTE_TYPE; // TODO consolidate these +} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. @@ -102,18 +100,12 @@ MlasSQNBitGemmBatch( /** * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. 
* - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -153,13 +145,15 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -169,6 +163,7 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[in] QuantBData quantized B data * @param[out] PackedQuantBData packed quantized B data * @param[in] ThreadPool optional thread pool to use @@ -179,6 +174,7 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool = nullptr diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp index 0d8a5692359a6..38c31c8841761 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.cpp @@ -39,23 +39,17 @@ enum SQNBitGemmVariant { SQNBitGemmVariant GetSQNBitGemmVariant( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - MLAS_UNREFERENCED_PARAMETER(N); - MLAS_UNREFERENCED_PARAMETER(K); - if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { if (ComputeType == CompFp32 || ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8 && M == 1) { + } else if (ComputeType == CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -67,9 +61,6 @@ GetSQNBitGemmVariant( bool MLASCALL MlasIsSQNBitGemmAvailable( - size_t M, - size_t N, - size_t K, size_t BlkBitWidth, size_t BlkLen, MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType @@ -80,7 +71,7 @@ MlasIsSQNBitGemmAvailable( return false; } - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { @@ -164,7 +155,7 @@ MlasSQNBitGemmBatchWorkspaceSize( MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(Variant, M, N, K, BlkLen); if (PerGemmWorkspaceStride == 0) { @@ 
-178,91 +169,24 @@ MlasSQNBitGemmBatchWorkspaceSize( return WorkspaceSize + Alignment - 1; } -namespace -{ - -void -SQ4BitGemmPackQuantBData( - size_t N, - size_t K, - size_t BlkLen, - const std::byte* QuantBDataBegin, - std::byte* PackedQuantBDataBegin, - MLAS_THREADPOOL* ThreadPool -) -{ - constexpr size_t BlkBitWidth = 4; - - assert(BlkLen % 16 == 0); - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - const size_t Iterations = N * BlockCountK; // one iteration per block - - MlasTrySimpleParallel( - ThreadPool, Iterations, - [&](ptrdiff_t tid) { - const size_t n = tid / BlockCountK; - const size_t k_blk = tid % BlockCountK; - - const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; - const std::byte* QuantBData = QuantBDataBegin + data_offset; - std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; - - // - // Pack 16 4-bit values (8 bytes) at a time like this: - // - // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | - // => - // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | - // - for (size_t kk = 0; kk < BlkLen; kk += 16) { - for (size_t byte_pair_idx = 0; byte_pair_idx < 4; ++byte_pair_idx) { - const std::byte src0 = QuantBData[byte_pair_idx]; - const std::byte src1 = QuantBData[byte_pair_idx + 4]; - - std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; - std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; - - dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); - dst1 = (src0 >> 4) | ((src1 >> 4) << 4); - } - - QuantBData += 8; - PackedQuantBData += 8; - } - } - ); -} - -} // namespace - size_t MLASCALL MlasSQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, - size_t BlkLen + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - // Ensure that a general implementation is available on this platform. - // For now, all implementations share the same packed format. - { - // Currently, there are implementations specific to M = 1, so pick a more general M > 1. - constexpr size_t M = 2; - // A CompUndef implementation should be available if any is available. 
- constexpr MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType = CompUndef; - const bool HasGeneralImplementation = - MlasIsSQNBitGemmAvailable(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (!HasGeneralImplementation) { - return 0; - } + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return 0; } - if (BlkBitWidth == 4) { - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->SQ4BitGemmPackQuantBDataSize( + N, K, BlkLen, ComputeType + ); } return 0; @@ -274,20 +198,28 @@ MlasSQNBitGemmPackQuantBData( size_t K, size_t BlkBitWidth, size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBData, MLAS_THREADPOOL* ThreadPool ) { - if (BlkBitWidth == 4) { - SQ4BitGemmPackQuantBData( + const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + if (Dispatch == nullptr) { + return; + } + + if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBData != nullptr) { + Dispatch->SQ4BitGemmPackQuantBData( N, K, BlkLen, + ComputeType, static_cast(QuantBData), static_cast(PackedQuantBData), ThreadPool ); + return; } } @@ -512,7 +444,37 @@ SQ4BitGemm_CompInt8( return; } - assert(false && "not implemented for M > 1"); + // This is a naive M > 1 implementation that repeatedly calls the M=1 kernel. + // TODO Replace it with an optimized implementation. + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* a_row = QuantA; + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? nullptr : Bias + n; + + for (size_t m = 0; m < RangeCountM; ++m) { + GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompInt8( + BlkLen, + a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + + c_blk += ldc; + a_row += lda; + } + } } typedef void(InitializeWorkspaceFn)( @@ -594,7 +556,7 @@ MlasSQNBitGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm.h b/onnxruntime/core/mlas/lib/sqnbitgemm.h index a66db79dc290a..3992bc3e452a3 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm.h +++ b/onnxruntime/core/mlas/lib/sqnbitgemm.h @@ -99,6 +99,33 @@ Q8BlkAlignment() // struct MLAS_SQNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ + typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. 
See MlasSQNBitGemmPackQuantBData(). */ + typedef void(SQ4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + // // CompFp32 kernel function prototypes. // diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp index 69fd427fa574a..c4c54a9be34d8 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp @@ -15,14 +15,115 @@ Module Name: --*/ -#include "sqnbitgemm.h" - #include #include #include #include +#include "sqnbitgemm.h" + +// +// Quantized B data packing function implementation. +// + +namespace +{ + +size_t +SQ4BitGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; +} + +void +SQ4BitGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkLen, + MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t Iterations = N * BlockCountK; // one iteration per block + + const size_t SubBlkLen = (ComputeType == CompInt8) + ? ((BlkLen == 16) ? 16 : 32) + : 16; + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const size_t data_offset = n * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset; + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset; + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +} // namespace + +// +// General helpers. 
+// + namespace { @@ -95,7 +196,16 @@ LoadFloatData(const float* src, size_t count, float32x4_t (&dst)[Capacity / 4]) } } -template +} // namespace + +// +// CompFp32 kernel implementation. +// + +namespace +{ + +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompFp32( size_t BlkLen, @@ -112,11 +222,11 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( ) { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); const uint8x8_t LowMask = vdup_n_u8(0x0F); @@ -137,7 +247,8 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true for (size_t k = 0; k < CountK; k += BlkLen) { const size_t k_blk_len = std::min(CountK - k, BlkLen); @@ -147,8 +258,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( [&](size_t i) { scale[i] = QuantBScale[i * StrideQuantBScale]; } ); - float offset[NCols]; // Includes zero point and float conversion offset of 16. - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] float offset[NCols]; // Includes zero point and float conversion offset of 16. + // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -157,11 +269,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( : (zp_packed & std::byte{0x0F}); offset[i] = 16.0f + std::to_integer(zp); }); - } else { - UnrolledLoop([&](size_t i) { - constexpr float zp = 8.0f; - offset[i] = 16.0f + zp; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { @@ -187,8 +294,6 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( bv_u8[i][1] = vshr_n_u8(bv_packed[i], 4); }); - // dequantize B - // shift left 3 and widen to 16 bits uint16x8_t bv_u16[NCols][2]; UnrolledLoop([&](size_t i) { @@ -217,10 +322,17 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( }); // subtract float conversion offset (16) and zero point - UnrolledLoop([&](size_t i) { - const float32x4_t offset_v = vdupq_n_f32(offset[i]); - UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const float32x4_t offset_v = vdupq_n_f32(offset[i]); + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } else { + const float32x4_t offset_v = vdupq_n_f32(16.0f + 8.0f); + UnrolledLoop([&](size_t i) { + UnrolledLoop<4>([&](size_t j) { bv[i][j] = vsubq_f32(bv[i][j], offset_v); }); + }); + } // multiply by scale UnrolledLoop([&](size_t i) { @@ -237,7 +349,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( // increment pointers to next block QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -258,8 +372,9 @@ ComputeDotProducts_BlkBitWidth4_CompFp32( } } -MLAS_FORCEINLINE void -SQ4BitGemmM1Kernel_CompFp32( +template +void 
+SQ4BitGemmM1Kernel_CompFp32_Impl( size_t BlkLen, const float* A, const std::byte* QuantBData, @@ -295,7 +410,7 @@ SQ4BitGemmM1Kernel_CompFp32( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompFp32( + ComputeDotProducts_BlkBitWidth4_CompFp32( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -306,7 +421,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -319,7 +434,7 @@ SQ4BitGemmM1Kernel_CompFp32( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompFp32<1>( + ComputeDotProducts_BlkBitWidth4_CompFp32<1, HasZeroPoint>( BlkLen, ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -330,7 +445,7 @@ SQ4BitGemmM1Kernel_CompFp32( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -339,6 +454,49 @@ SQ4BitGemmM1Kernel_CompFp32( } } +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompFp32_Impl( + BlkLen, + A, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + MLAS_FORCEINLINE void Q4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, @@ -353,6 +511,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( { auto impl0_reference = [&]() { constexpr size_t BlkBitWidth = 4; + constexpr size_t SubBlkLen = 16; float* Dst = FpData; @@ -378,11 +537,11 @@ Q4BitBlkDequantBForSgemm_CompFp32( : 8; for (size_t kk = 0; kk < kklen; ++kk) { - const size_t packed_idx = kk % 16; + const size_t packed_idx = kk % SubBlkLen; - const bool is_low_half = packed_idx < 8; - const size_t packed_byte_idx = packed_idx % 8; - const size_t packed_range_offset = (kk / 16) * 8; + const bool is_low_half = packed_idx < (SubBlkLen / 2); + const size_t packed_byte_idx = packed_idx % (SubBlkLen / 2); + const size_t packed_range_offset = (kk / SubBlkLen) * (SubBlkLen / 2); const std::byte b_packed = b_data[packed_range_offset + packed_byte_idx]; const std::byte b_byte = is_low_half ? (b_packed & std::byte{0x0F}) : (b_packed >> 4); @@ -415,7 +574,7 @@ Q4BitBlkDequantBForSgemm_CompFp32( } // -// CompInt8 kernel implementation and related helpers +// CompInt8 kernel implementation. // template @@ -431,8 +590,6 @@ QuantizeBlock( assert(BlkLen % SubBlkLen == 0); - constexpr size_t VectorCount = SubBlkLen / 4; - // // Scan block values first to determine scale. 
// @@ -443,16 +600,16 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - float32x4_t abs_a[VectorCount]; - UnrolledLoop([&](size_t i) { + float32x4_t abs_a[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { abs_a[i] = vabsq_f32(a[i]); }); // find amax of SubBlkLen elements - for (size_t interval = VectorCount / 2; interval > 0; interval /= 2) { + for (size_t interval = SubBlkLen / 4 / 2; interval > 0; interval /= 2) { for (size_t i = 0; i < interval; ++i) { abs_a[i] = vmaxq_f32(abs_a[i], abs_a[i + interval]); } @@ -477,19 +634,19 @@ QuantizeBlock( for (k = 0; k < ElementCount; k += SubBlkLen) { const size_t SubBlkElementCount = std::min(ElementCount - k, SubBlkLen); - float32x4_t a[VectorCount]{}; + float32x4_t a[SubBlkLen / 4]{}; LoadFloatData(A + k, SubBlkElementCount, a); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { a[i] = vmulq_n_f32(a[i], scale_reciprocal); }); - int32x4_t a_s32[VectorCount]; - UnrolledLoop([&](size_t i) { + int32x4_t a_s32[SubBlkLen / 4]; + UnrolledLoop([&](size_t i) { a_s32[i] = vcvtaq_s32_f32(a[i]); }); - UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t i) { QuantAData[k + i * 4 + 0] = static_cast(vgetq_lane_s32(a_s32[i], 0)); QuantAData[k + i * 4 + 1] = static_cast(vgetq_lane_s32(a_s32[i], 1)); QuantAData[k + i * 4 + 2] = static_cast(vgetq_lane_s32(a_s32[i], 2)); @@ -530,7 +687,7 @@ QuantizeARow_CompInt8( } } -template +template MLAS_FORCEINLINE void ComputeDotProducts_BlkBitWidth4_CompInt8( size_t BlkLen, @@ -546,20 +703,22 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( const float* BiasPtr ) { - static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); - constexpr size_t BlkBitWidth = 4; - constexpr size_t SubBlkLen = 16; // number of block elements to process in a sub-block iteration - assert(BlkLen % SubBlkLen == 0); + static_assert(NCols == 1 || NCols == 4, "NCols must be 1 or 4"); + static_assert(SubBlkLen == 16 || SubBlkLen == 32, "SubBlkLen must be 16 or 32"); - const uint8x8_t LowMask = vdup_n_u8(0x0F); + assert(BlkLen >= SubBlkLen && BlkLen % SubBlkLen == 0); + + [[maybe_unused]] const uint8x8_t LowMaskU8x8 = vdup_n_u8(0x0F); // only used if SubBlkLen == 16 + [[maybe_unused]] const uint8x16_t LowMaskU8x16 = vdupq_n_u8(0x0F); // only used if SubBlkLen == 32 const std::byte* QuantA = QuantARowPtr; const std::byte* QuantBData = QuantBDataColPtr; const float* QuantBScale = QuantBScaleColPtr; - size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + // only used if HasZeroPoint == true float32x4_t acc[NCols]{}; @@ -572,8 +731,8 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( float b_scale[NCols]; UnrolledLoop([&](size_t i) { b_scale[i] = QuantBScale[i * StrideQuantBScale]; }); - int8_t b_zp[NCols]; - if (QuantBZeroPointColPtr != nullptr) { + [[maybe_unused]] int8_t b_zp[NCols]; // only used if HasZeroPoint == true + if constexpr (HasZeroPoint) { UnrolledLoop([&](size_t i) { const std::byte zp_packed = QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; @@ -581,42 +740,73 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( ? 
std::to_integer(zp_packed >> 4) : std::to_integer(zp_packed & std::byte{0x0F}); }); - } else { - UnrolledLoop([&](size_t i) { - b_zp[i] = 8; - }); } for (size_t k_idx_in_blk = 0; k_idx_in_blk < k_blk_len; k_idx_in_blk += SubBlkLen) { // load A row vector - int8x16_t av = vld1q_s8(a_data + k_idx_in_blk); + int8x16_t av[SubBlkLen / 16]; + UnrolledLoop([&](size_t i) { + av[i] = vld1q_s8(a_data + k_idx_in_blk + i * 16); + }); // load B column vectors - uint8x8_t bv_packed[NCols]; + int8x16_t bv[NCols][SubBlkLen / 16]; + const size_t b_data_block_offset = k_idx_in_blk * BlkBitWidth / 8; - UnrolledLoop([&](size_t i) { - bv_packed[i] = vld1_u8( - reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset - ); - }); - int8x16_t bv[NCols]; - UnrolledLoop([&](size_t i) { - const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMask)); - const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); - bv[i] = vcombine_s8(lo, hi); - }); + if constexpr (SubBlkLen == 16) { + uint8x8_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + UnrolledLoop([&](size_t i) { + const int8x8_t lo = vreinterpret_s8_u8(vand_u8(bv_packed[i], LowMaskU8x8)); + const int8x8_t hi = vreinterpret_s8_u8(vshr_n_u8(bv_packed[i], 4)); + bv[i][0] = vcombine_s8(lo, hi); + }); + } else { + static_assert(SubBlkLen == 32); + + uint8x16_t bv_packed[NCols]; + UnrolledLoop([&](size_t i) { + bv_packed[i] = vld1q_u8( + reinterpret_cast(QuantBData) + i * StrideQuantBData + b_data_block_offset + ); + }); + + UnrolledLoop([&](size_t i) { + bv[i][0] = vreinterpretq_s8_u8(vandq_u8(bv_packed[i], LowMaskU8x16)); + bv[i][1] = vreinterpretq_s8_u8(vshrq_n_u8(bv_packed[i], 4)); + }); + } // subtract B zero point - UnrolledLoop([&](size_t i) { - const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); - bv[i] = vsubq_s8(bv[i], zp_v); - }); + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const int8x16_t zp_v = vdupq_n_s8(b_zp[i]); + UnrolledLoop([&](size_t j) { + bv[i][j] = vsubq_s8(bv[i][j], zp_v); + }); + }); + } else { + const int8x16_t zp_v = vdupq_n_s8(8); + + UnrolledLoop([&](size_t i) { + UnrolledLoop([&](size_t j) { + bv[i][j] = vsubq_s8(bv[i][j], zp_v); + }); + }); + } // compute quantized dot product int32x4_t dot[NCols]{}; UnrolledLoop([&](size_t i) { - dot[i] = vdotq_s32(dot[i], av, bv[i]); + UnrolledLoop([&](size_t j) { + dot[i] = vdotq_s32(dot[i], av[j], bv[i][j]); + }); }); // convert dot product result to float @@ -636,7 +826,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( QuantA += Q8BlkSize(BlkLen); QuantBData += MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); QuantBScale += 1; - QuantBZeroPointIdx += 1; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } } if constexpr (NCols == 4) { @@ -657,9 +849,9 @@ ComputeDotProducts_BlkBitWidth4_CompInt8( } } -MLAS_FORCEINLINE +template void -SQ4BitGemmM1Kernel_CompInt8( +SQ4BitGemmM1Kernel_CompInt8_Impl( size_t BlkLen, const std::byte* QuantA, const std::byte* QuantBData, @@ -673,7 +865,6 @@ SQ4BitGemmM1Kernel_CompInt8( ) { constexpr size_t BlkBitWidth = 4; - constexpr size_t NCols = 4; const std::byte* QuantARowPtr = QuantA; float* CRowPtr = C; @@ -695,7 +886,7 @@ SQ4BitGemmM1Kernel_CompInt8( int64_t nblk = static_cast(CountN) - NCols; while (nblk >= 0) { - ComputeDotProducts_BlkBitWidth4_CompInt8( + ComputeDotProducts_BlkBitWidth4_CompInt8( BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, 
CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -706,7 +897,7 @@ SQ4BitGemmM1Kernel_CompInt8( QuantBDataColPtr += NCols * StrideQuantBData; QuantBScaleColPtr += NCols * StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += NCols * StrideQuantBZeroPoint; } @@ -719,7 +910,7 @@ SQ4BitGemmM1Kernel_CompInt8( // left over columns less than `NCols`? nblk += NCols; for (int64_t n = 0; n < nblk; ++n) { - ComputeDotProducts_BlkBitWidth4_CompInt8<1>( + ComputeDotProducts_BlkBitWidth4_CompInt8<1, SubBlkLen, HasZeroPoint>( BlkLen, QuantARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, @@ -730,7 +921,7 @@ SQ4BitGemmM1Kernel_CompInt8( QuantBDataColPtr += StrideQuantBData; QuantBScaleColPtr += StrideQuantBScale; - if (QuantBZeroPointColPtr != nullptr) { + if constexpr (HasZeroPoint) { QuantBZeroPointColPtr += StrideQuantBZeroPoint; } @@ -739,6 +930,94 @@ SQ4BitGemmM1Kernel_CompInt8( } } +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + SQ4BitGemmM1Kernel_CompInt8_Impl<4, 16, HasZeroPoint>( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_Impl<4, 32, HasZeroPoint>( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + +MLAS_FORCEINLINE +void +SQ4BitGemmM1Kernel_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen( + BlkLen, + QuantA, + QuantBData, + QuantBScale, + QuantBZeroPoint, + C, + CountN, + CountK, + BlockStrideQuantB, + Bias + ); + } +} + } // namespace // @@ -748,8 +1027,12 @@ SQ4BitGemmM1Kernel_CompInt8( const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { MLAS_SQNBIT_GEMM_DISPATCH d; + d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32; + d.SQ4BitGemmM1Kernel_CompInt8 = SQ4BitGemmM1Kernel_CompInt8; d.QuantizeARow_CompInt8 = QuantizeARow_CompInt8; diff --git a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp index 668d7a0611367..b7b453415838a 100644 --- a/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/bench/bench_sqnbitgemm.cpp @@ -61,10 +61,11 @@ void SQNBITGEMM(benchmark::State& state) { } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + if (const auto PackedQuantBDataSize = 
MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData.data(), PackedQuantBData.get(), tp.get()); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + tp.get()); } MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; @@ -87,7 +88,9 @@ void SQNBITGEMM(benchmark::State& state) { } } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +static void SQ4BitGemmArgs(benchmark::internal::Benchmark* b) { + constexpr size_t BlkBitWidth = 4; + b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "ComputeType"}); ArgsProductWithFilter(b, @@ -96,19 +99,17 @@ static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { {1, 1024, 2048}, // M {4096, 11008}, // N {4096, 11008}, // K - {8}, // Threads + {1, 8}, // Threads {int64_t{false}, int64_t{true}}, // Symmetric {int64_t{CompFp32}, int64_t{CompInt8}}}, // ComputeType - [](const std::vector& args) { + [&](const std::vector& args) { return MlasIsSQNBitGemmAvailable( - // M, N, K - narrow(args[1]), narrow(args[2]), narrow(args[3]), // BlkBitWidth, BlkLen - 4, narrow(args[0]), + BlkBitWidth, narrow(args[0]), // ComputeType static_cast(args[6])); }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(SQNBITGEMM<4>)->Apply(SQ4BitGemmArgs)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp index 4fb8ab41745d5..ed09d7ee92b2a 100644 --- a/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp +++ b/onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp @@ -259,10 +259,11 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* PackedQuantBData = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen); + if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, QuantBData, PackedQuantBData, GetMlasThreadPool()); + MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBData, + GetMlasThreadPool()); } if (ComputeType == CompFp32) { @@ -330,7 +331,7 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture Date: Tue, 30 Jan 2024 15:59:37 -0800 Subject: [PATCH 002/207] Move einsum's test data to constexpr variables (#19320) ### Description emscripten's C++ compiler has difficulty on compiling einsum_test.cc because the file has too many local variables. So I moved them to constexpr. 
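A minimal sketch of the refactoring pattern this change applies, using C++20 std::span as a stand-in for the gsl::span used in the repository; the kEquation/kShape/kExpected names are illustrative, though the values mirror the first test case in the diff below:

```cpp
// Minimal sketch of the pattern: move test inputs from function-local vectors
// to namespace-scope constexpr arrays and keep only cheap views per test case.
#include <array>
#include <cstdint>
#include <span>         // stand-in for the gsl::span used in the repository
#include <string_view>

struct EinsumTestCase {
  std::string_view equation;
  std::span<const int64_t> shape;
  std::span<const float> expected;
};

static constexpr std::string_view kEquation = "abc,cd->abc";
static constexpr std::array<int64_t, 3> kShape{2, 2, 2};
static constexpr std::array<float, 8> kExpected{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f};

static constexpr std::array<EinsumTestCase, 1> kCases{{
    {kEquation, kShape, kExpected},
}};
```

Because the equations, shapes, and expected values live in namespace-scope constexpr arrays, each test case struct only stores lightweight views, so the test function no longer declares the large number of local variables that emscripten's compiler struggled with.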
--- cmake/onnxruntime_unittests.cmake | 3 +- .../test/providers/cpu/math/einsum_test.cc | 1670 +++++++++++++---- 2 files changed, 1316 insertions(+), 357 deletions(-) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 351ea1a95581b..714f35380ca02 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -825,8 +825,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") ) endif() list(REMOVE_ITEM all_tests "${TEST_SRC_DIR}/providers/cpu/reduction/reduction_ops_test.cc" - "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc" - "${TEST_SRC_DIR}/providers/cpu/math/einsum_test.cc") + "${TEST_SRC_DIR}/providers/cpu/tensor/grid_sample_test.cc") endif() set(test_all_args) diff --git a/onnxruntime/test/providers/cpu/math/einsum_test.cc b/onnxruntime/test/providers/cpu/math/einsum_test.cc index 05b936a41e3c1..4e968d3de6b8a 100644 --- a/onnxruntime/test/providers/cpu/math/einsum_test.cc +++ b/onnxruntime/test/providers/cpu/math/einsum_test.cc @@ -769,374 +769,1334 @@ TEST(Einsum, ExplicitEinsumAsTensorContraction_Half) { // for two and three inputs (most common use-case of Einsum operator) struct EinsumTestCase { - std::string equation; - std::vector shape; - std::vector expected; - EinsumTestCase(const std::string& eq, const std::vector& sh, const std::vector& exp) : equation(eq), shape(sh), expected(exp) {} + std::string_view equation; + gsl::span shape; + gsl::span expected; }; +static constexpr std::string_view equation0 = "abc,cd->abc"; +static constexpr std::array shape0{2, 2, 2}; +static constexpr std::array expected0{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}; +static constexpr std::string_view equation1 = "abc,cd->abd"; +static constexpr std::array shape1{2, 2, 2}; +static constexpr std::array expected1{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}; +static constexpr std::string_view equation2 = "abc,cd->acd"; +static constexpr std::array shape2{2, 2, 2}; +static constexpr std::array expected2{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}; +static constexpr std::string_view equation3 = "abc,dc->abd"; +static constexpr std::array shape3{2, 2, 2}; +static constexpr std::array expected3{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}; +static constexpr std::string_view equation4 = "abc,dc->abc"; +static constexpr std::array shape4{2, 2, 2}; +static constexpr std::array expected4{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}; +static constexpr std::string_view equation5 = "abc,dc->acd"; +static constexpr std::array shape5{2, 2, 2}; +static constexpr std::array expected5{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}; +static constexpr std::string_view equation6 = "acb,cd->acd"; +static constexpr std::array shape6{2, 2, 2}; +static constexpr std::array expected6{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}; +static constexpr std::string_view equation7 = "acb,cd->abc"; +static constexpr std::array shape7{2, 2, 2}; +static constexpr std::array expected7{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}; +static constexpr std::string_view equation8 = "acb,cd->abd"; +static constexpr std::array shape8{2, 2, 2}; +static constexpr std::array expected8{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}; +static constexpr std::string_view equation9 = "acb,dc->acd"; +static constexpr std::array shape9{2, 2, 2}; +static constexpr std::array expected9{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}; +static constexpr std::string_view equation10 = "acb,dc->abd"; +static constexpr std::array shape10{2, 2, 2}; +static constexpr std::array expected10{2.f, 6.f, 3.f, 11.f, 
6.f, 26.f, 7.f, 31.f}; +static constexpr std::string_view equation11 = "acb,dc->abc"; +static constexpr std::array shape11{2, 2, 2}; +static constexpr std::array expected11{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}; +static constexpr std::string_view equation12 = "bac,cd->bac"; +static constexpr std::array shape12{2, 2, 2}; +static constexpr std::array expected12{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}; +static constexpr std::string_view equation13 = "bac,cd->bad"; +static constexpr std::array shape13{2, 2, 2}; +static constexpr std::array expected13{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}; +static constexpr std::string_view equation14 = "bac,cd->bcd"; +static constexpr std::array shape14{2, 2, 2}; +static constexpr std::array expected14{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}; +static constexpr std::string_view equation15 = "bac,dc->bad"; +static constexpr std::array shape15{2, 2, 2}; +static constexpr std::array expected15{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}; +static constexpr std::string_view equation16 = "bac,dc->bac"; +static constexpr std::array shape16{2, 2, 2}; +static constexpr std::array expected16{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}; +static constexpr std::string_view equation17 = "bac,dc->bcd"; +static constexpr std::array shape17{2, 2, 2}; +static constexpr std::array expected17{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}; +static constexpr std::string_view equation18 = "bca,cd->bcd"; +static constexpr std::array shape18{2, 2, 2}; +static constexpr std::array expected18{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}; +static constexpr std::string_view equation19 = "bca,cd->bac"; +static constexpr std::array shape19{2, 2, 2}; +static constexpr std::array expected19{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}; +static constexpr std::string_view equation20 = "bca,cd->bad"; +static constexpr std::array shape20{2, 2, 2}; +static constexpr std::array expected20{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}; +static constexpr std::string_view equation21 = "bca,dc->bcd"; +static constexpr std::array shape21{2, 2, 2}; +static constexpr std::array expected21{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}; +static constexpr std::string_view equation22 = "bca,dc->bad"; +static constexpr std::array shape22{2, 2, 2}; +static constexpr std::array expected22{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}; +static constexpr std::string_view equation23 = "bca,dc->bac"; +static constexpr std::array shape23{2, 2, 2}; +static constexpr std::array expected23{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}; +static constexpr std::string_view equation24 = "cab,cd->cad"; +static constexpr std::array shape24{2, 2, 2}; +static constexpr std::array expected24{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}; +static constexpr std::string_view equation25 = "cab,cd->cbd"; +static constexpr std::array shape25{2, 2, 2}; +static constexpr std::array expected25{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}; +static constexpr std::string_view equation26 = "cab,dc->cad"; +static constexpr std::array shape26{2, 2, 2}; +static constexpr std::array expected26{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}; +static constexpr std::string_view equation27 = "cab,dc->cbd"; +static constexpr std::array shape27{2, 2, 2}; +static constexpr std::array expected27{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}; +static constexpr std::string_view equation28 = "cba,cd->cbd"; +static constexpr std::array shape28{2, 2, 2}; +static constexpr std::array expected28{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}; +static constexpr 
std::string_view equation29 = "cba,cd->cad"; +static constexpr std::array shape29{2, 2, 2}; +static constexpr std::array expected29{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}; +static constexpr std::string_view equation30 = "cba,dc->cbd"; +static constexpr std::array shape30{2, 2, 2}; +static constexpr std::array expected30{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}; +static constexpr std::string_view equation31 = "cba,dc->cad"; +static constexpr std::array shape31{2, 2, 2}; +static constexpr std::array expected31{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}; +static constexpr std::array case0 = {{ + {equation0, shape0, expected0}, + {equation1, shape1, expected1}, + {equation2, shape2, expected2}, + {equation3, shape3, expected3}, + {equation4, shape4, expected4}, + {equation5, shape5, expected5}, + {equation6, shape6, expected6}, + {equation7, shape7, expected7}, + {equation8, shape8, expected8}, + {equation9, shape9, expected9}, + {equation10, shape10, expected10}, + {equation11, shape11, expected11}, + {equation12, shape12, expected12}, + {equation13, shape13, expected13}, + {equation14, shape14, expected14}, + {equation15, shape15, expected15}, + {equation16, shape16, expected16}, + {equation17, shape17, expected17}, + {equation18, shape18, expected18}, + {equation19, shape19, expected19}, + {equation20, shape20, expected20}, + {equation21, shape21, expected21}, + {equation22, shape22, expected22}, + {equation23, shape23, expected23}, + {equation24, shape24, expected24}, + {equation25, shape25, expected25}, + {equation26, shape26, expected26}, + {equation27, shape27, expected27}, + {equation28, shape28, expected28}, + {equation29, shape29, expected29}, + {equation30, shape30, expected30}, + {equation31, shape31, expected31}, +}}; + +static constexpr std::string_view equation32 = "abc,cd,def->abd"; +static constexpr std::array shape32{2, 2, 2}; +static constexpr std::array expected32{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation33 = "abc,cd,def->abe"; +static constexpr std::array shape33{2, 2, 2}; +static constexpr std::array expected33{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation34 = "abc,cd,def->acd"; +static constexpr std::array shape34{2, 2, 2}; +static constexpr std::array expected34{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation35 = "abc,cd,def->ace"; +static constexpr std::array shape35{2, 2, 2}; +static constexpr std::array expected35{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation36 = "abc,cd,dfe->abd"; +static constexpr std::array shape36{2, 2, 2}; +static constexpr std::array expected36{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation37 = "abc,cd,dfe->abf"; +static constexpr std::array shape37{2, 2, 2}; +static constexpr std::array expected37{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation38 = "abc,cd,dfe->acd"; +static constexpr std::array shape38{2, 2, 2}; +static constexpr std::array expected38{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation39 = "abc,cd,dfe->acf"; +static constexpr std::array shape39{2, 2, 2}; +static constexpr std::array expected39{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation40 = "abc,cd,edf->abe"; +static constexpr std::array shape40{2, 2, 2}; +static 
constexpr std::array expected40{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation41 = "abc,cd,edf->abd"; +static constexpr std::array shape41{2, 2, 2}; +static constexpr std::array expected41{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation42 = "abc,cd,edf->ace"; +static constexpr std::array shape42{2, 2, 2}; +static constexpr std::array expected42{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation43 = "abc,cd,edf->acd"; +static constexpr std::array shape43{2, 2, 2}; +static constexpr std::array expected43{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation44 = "abc,cd,efd->abe"; +static constexpr std::array shape44{2, 2, 2}; +static constexpr std::array expected44{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation45 = "abc,cd,efd->abf"; +static constexpr std::array shape45{2, 2, 2}; +static constexpr std::array expected45{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation46 = "abc,cd,efd->ace"; +static constexpr std::array shape46{2, 2, 2}; +static constexpr std::array expected46{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation47 = "abc,cd,efd->acf"; +static constexpr std::array shape47{2, 2, 2}; +static constexpr std::array expected47{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation48 = "abc,cd,fde->abf"; +static constexpr std::array shape48{2, 2, 2}; +static constexpr std::array expected48{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation49 = "abc,cd,fde->abd"; +static constexpr std::array shape49{2, 2, 2}; +static constexpr std::array expected49{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation50 = "abc,cd,fde->acf"; +static constexpr std::array shape50{2, 2, 2}; +static constexpr std::array expected50{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation51 = "abc,cd,fde->acd"; +static constexpr std::array shape51{2, 2, 2}; +static constexpr std::array expected51{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation52 = "abc,cd,fed->abf"; +static constexpr std::array shape52{2, 2, 2}; +static constexpr std::array expected52{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation53 = "abc,cd,fed->abe"; +static constexpr std::array shape53{2, 2, 2}; +static constexpr std::array expected53{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation54 = "abc,cd,fed->acf"; +static constexpr std::array shape54{2, 2, 2}; +static constexpr std::array expected54{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation55 = "abc,cd,fed->ace"; +static constexpr std::array shape55{2, 2, 2}; +static constexpr std::array expected55{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation56 = "abc,dc,def->abd"; +static constexpr std::array shape56{2, 2, 2}; +static constexpr std::array expected56{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation57 = "abc,dc,def->abe"; +static constexpr std::array shape57{2, 2, 2}; +static constexpr std::array 
expected57{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation58 = "abc,dc,def->acd"; +static constexpr std::array shape58{2, 2, 2}; +static constexpr std::array expected58{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation59 = "abc,dc,def->ace"; +static constexpr std::array shape59{2, 2, 2}; +static constexpr std::array expected59{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation60 = "abc,dc,dfe->abd"; +static constexpr std::array shape60{2, 2, 2}; +static constexpr std::array expected60{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation61 = "abc,dc,dfe->abf"; +static constexpr std::array shape61{2, 2, 2}; +static constexpr std::array expected61{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation62 = "abc,dc,dfe->acd"; +static constexpr std::array shape62{2, 2, 2}; +static constexpr std::array expected62{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation63 = "abc,dc,dfe->acf"; +static constexpr std::array shape63{2, 2, 2}; +static constexpr std::array expected63{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation64 = "abc,dc,edf->abe"; +static constexpr std::array shape64{2, 2, 2}; +static constexpr std::array expected64{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation65 = "abc,dc,edf->abd"; +static constexpr std::array shape65{2, 2, 2}; +static constexpr std::array expected65{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation66 = "abc,dc,edf->ace"; +static constexpr std::array shape66{2, 2, 2}; +static constexpr std::array expected66{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation67 = "abc,dc,edf->acd"; +static constexpr std::array shape67{2, 2, 2}; +static constexpr std::array expected67{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation68 = "abc,dc,efd->abe"; +static constexpr std::array shape68{2, 2, 2}; +static constexpr std::array expected68{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation69 = "abc,dc,efd->abf"; +static constexpr std::array shape69{2, 2, 2}; +static constexpr std::array expected69{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation70 = "abc,dc,efd->ace"; +static constexpr std::array shape70{2, 2, 2}; +static constexpr std::array expected70{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation71 = "abc,dc,efd->acf"; +static constexpr std::array shape71{2, 2, 2}; +static constexpr std::array expected71{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation72 = "abc,dc,fde->abf"; +static constexpr std::array shape72{2, 2, 2}; +static constexpr std::array expected72{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation73 = "abc,dc,fde->abd"; +static constexpr std::array shape73{2, 2, 2}; +static constexpr std::array expected73{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation74 = "abc,dc,fde->acf"; +static constexpr std::array shape74{2, 2, 2}; +static constexpr std::array expected74{20.f, 52.f, 64.f, 
192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation75 = "abc,dc,fde->acd"; +static constexpr std::array shape75{2, 2, 2}; +static constexpr std::array expected75{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation76 = "abc,dc,fed->abf"; +static constexpr std::array shape76{2, 2, 2}; +static constexpr std::array expected76{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation77 = "abc,dc,fed->abe"; +static constexpr std::array shape77{2, 2, 2}; +static constexpr std::array expected77{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation78 = "abc,dc,fed->acf"; +static constexpr std::array shape78{2, 2, 2}; +static constexpr std::array expected78{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation79 = "abc,dc,fed->ace"; +static constexpr std::array shape79{2, 2, 2}; +static constexpr std::array expected79{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation80 = "acb,cd,def->acd"; +static constexpr std::array shape80{2, 2, 2}; +static constexpr std::array expected80{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation81 = "acb,cd,def->ace"; +static constexpr std::array shape81{2, 2, 2}; +static constexpr std::array expected81{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation82 = "acb,cd,def->abd"; +static constexpr std::array shape82{2, 2, 2}; +static constexpr std::array expected82{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation83 = "acb,cd,def->abe"; +static constexpr std::array shape83{2, 2, 2}; +static constexpr std::array expected83{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation84 = "acb,cd,dfe->acd"; +static constexpr std::array shape84{2, 2, 2}; +static constexpr std::array expected84{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation85 = "acb,cd,dfe->acf"; +static constexpr std::array shape85{2, 2, 2}; +static constexpr std::array expected85{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation86 = "acb,cd,dfe->abd"; +static constexpr std::array shape86{2, 2, 2}; +static constexpr std::array expected86{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation87 = "acb,cd,dfe->abf"; +static constexpr std::array shape87{2, 2, 2}; +static constexpr std::array expected87{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation88 = "acb,cd,edf->ace"; +static constexpr std::array shape88{2, 2, 2}; +static constexpr std::array expected88{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation89 = "acb,cd,edf->acd"; +static constexpr std::array shape89{2, 2, 2}; +static constexpr std::array expected89{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation90 = "acb,cd,edf->abe"; +static constexpr std::array shape90{2, 2, 2}; +static constexpr std::array expected90{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation91 = "acb,cd,edf->abd"; +static constexpr std::array shape91{2, 2, 2}; +static constexpr std::array expected91{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 
468.f}; +static constexpr std::string_view equation92 = "acb,cd,efd->ace"; +static constexpr std::array shape92{2, 2, 2}; +static constexpr std::array expected92{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation93 = "acb,cd,efd->acf"; +static constexpr std::array shape93{2, 2, 2}; +static constexpr std::array expected93{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation94 = "acb,cd,efd->abe"; +static constexpr std::array shape94{2, 2, 2}; +static constexpr std::array expected94{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation95 = "acb,cd,efd->abf"; +static constexpr std::array shape95{2, 2, 2}; +static constexpr std::array expected95{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation96 = "acb,cd,fde->acf"; +static constexpr std::array shape96{2, 2, 2}; +static constexpr std::array expected96{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation97 = "acb,cd,fde->acd"; +static constexpr std::array shape97{2, 2, 2}; +static constexpr std::array expected97{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation98 = "acb,cd,fde->abf"; +static constexpr std::array shape98{2, 2, 2}; +static constexpr std::array expected98{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation99 = "acb,cd,fde->abd"; +static constexpr std::array shape99{2, 2, 2}; +static constexpr std::array expected99{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +static constexpr std::string_view equation100 = "acb,cd,fed->acf"; +static constexpr std::array shape100{2, 2, 2}; +static constexpr std::array expected100{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation101 = "acb,cd,fed->ace"; +static constexpr std::array shape101{2, 2, 2}; +static constexpr std::array expected101{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation102 = "acb,cd,fed->abf"; +static constexpr std::array shape102{2, 2, 2}; +static constexpr std::array expected102{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation103 = "acb,cd,fed->abe"; +static constexpr std::array shape103{2, 2, 2}; +static constexpr std::array expected103{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation104 = "acb,dc,def->acd"; +static constexpr std::array shape104{2, 2, 2}; +static constexpr std::array expected104{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +static constexpr std::string_view equation105 = "acb,dc,def->ace"; +static constexpr std::array shape105{2, 2, 2}; +static constexpr std::array expected105{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; + +static constexpr std::string_view equation106 = "acb,dc,def->abd"; +static constexpr std::array shape106{2, 2, 2}; +static constexpr std::array expected106{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation107 = "acb,dc,def->abe"; +static constexpr std::array shape107{2, 2, 2}; +static constexpr std::array expected107{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation108 = "acb,dc,dfe->acd"; +static constexpr std::array shape108{2, 2, 2}; +static constexpr std::array expected108{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 
858.f}; +static constexpr std::string_view equation109 = "acb,dc,dfe->acf"; +static constexpr std::array shape109{2, 2, 2}; +static constexpr std::array expected109{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +static constexpr std::string_view equation110 = "acb,dc,dfe->abd"; +static constexpr std::array shape110{2, 2, 2}; +static constexpr std::array expected110{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation111 = "acb,dc,dfe->abf"; +static constexpr std::array shape111{2, 2, 2}; +static constexpr std::array expected111{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation112 = "acb,dc,edf->ace"; +static constexpr std::array shape112{2, 2, 2}; +static constexpr std::array expected112{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation113 = "acb,dc,edf->acd"; +static constexpr std::array shape113{2, 2, 2}; +static constexpr std::array expected113{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation114 = "acb,dc,edf->abe"; +static constexpr std::array shape114{2, 2, 2}; +static constexpr std::array expected114{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation115 = "acb,dc,edf->abd"; +static constexpr std::array shape115{2, 2, 2}; +static constexpr std::array expected115{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation116 = "acb,dc,efd->ace"; +static constexpr std::array shape116{2, 2, 2}; +static constexpr std::array expected116{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation117 = "acb,dc,efd->acf"; +static constexpr std::array shape117{2, 2, 2}; +static constexpr std::array expected117{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation118 = "acb,dc,efd->abe"; +static constexpr std::array shape118{2, 2, 2}; +static constexpr std::array expected118{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation119 = "acb,dc,efd->abf"; +static constexpr std::array shape119{2, 2, 2}; +static constexpr std::array expected119{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation120 = "acb,dc,fde->acf"; +static constexpr std::array shape120{2, 2, 2}; +static constexpr std::array expected120{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation121 = "acb,dc,fde->acd"; +static constexpr std::array shape121{2, 2, 2}; +static constexpr std::array expected121{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation122 = "acb,dc,fde->abf"; +static constexpr std::array shape122{2, 2, 2}; +static constexpr std::array expected122{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation123 = "acb,dc,fde->abd"; +static constexpr std::array shape123{2, 2, 2}; +static constexpr std::array expected123{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation124 = "acb,dc,fed->acf"; +static constexpr std::array shape124{2, 2, 2}; +static constexpr std::array expected124{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation125 = "acb,dc,fed->ace"; +static constexpr std::array shape125{2, 2, 2}; +static constexpr std::array expected125{12.f, 20.f, 110.f, 190.f, 
108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation126 = "acb,dc,fed->abf"; +static constexpr std::array shape126{2, 2, 2}; +static constexpr std::array expected126{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation127 = "acb,dc,fed->abe"; +static constexpr std::array shape127{2, 2, 2}; +static constexpr std::array expected127{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation128 = "bac,cd,def->bad"; +static constexpr std::array shape128{2, 2, 2}; +static constexpr std::array expected128{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation129 = "bac,cd,def->bae"; +static constexpr std::array shape129{2, 2, 2}; +static constexpr std::array expected129{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation130 = "bac,cd,def->bcd"; +static constexpr std::array shape130{2, 2, 2}; +static constexpr std::array expected130{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation131 = "bac,cd,def->bce"; +static constexpr std::array shape131{2, 2, 2}; +static constexpr std::array expected131{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation132 = "bac,cd,dfe->bad"; +static constexpr std::array shape132{2, 2, 2}; +static constexpr std::array expected132{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}; +static constexpr std::string_view equation133 = "bac,cd,dfe->baf"; +static constexpr std::array shape133{2, 2, 2}; +static constexpr std::array expected133{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}; +static constexpr std::string_view equation134 = "bac,cd,dfe->bcd"; +static constexpr std::array shape134{2, 2, 2}; +static constexpr std::array expected134{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}; +static constexpr std::string_view equation135 = "bac,cd,dfe->bcf"; +static constexpr std::array shape135{2, 2, 2}; +static constexpr std::array expected135{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}; +static constexpr std::string_view equation136 = "bac,cd,edf->bae"; +static constexpr std::array shape136{2, 2, 2}; +static constexpr std::array expected136{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation137 = "bac,cd,edf->bad"; +static constexpr std::array shape137{2, 2, 2}; +static constexpr std::array expected137{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation138 = "bac,cd,edf->bce"; +static constexpr std::array shape138{2, 2, 2}; +static constexpr std::array expected138{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation139 = "bac,cd,edf->bcd"; +static constexpr std::array shape139{2, 2, 2}; +static constexpr std::array expected139{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation140 = "bac,cd,efd->bae"; +static constexpr std::array shape140{2, 2, 2}; +static constexpr std::array expected140{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation141 = "bac,cd,efd->baf"; +static constexpr std::array shape141{2, 2, 2}; +static constexpr std::array expected141{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation142 = "bac,cd,efd->bce"; +static constexpr std::array shape142{2, 2, 2}; +static constexpr std::array expected142{8.f, 
24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation143 = "bac,cd,efd->bcf"; +static constexpr std::array shape143{2, 2, 2}; +static constexpr std::array expected143{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation144 = "bac,cd,fde->baf"; +static constexpr std::array shape144{2, 2, 2}; +static constexpr std::array expected144{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}; +static constexpr std::string_view equation145 = "bac,cd,fde->bad"; +static constexpr std::array shape145{2, 2, 2}; +static constexpr std::array expected145{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}; +static constexpr std::string_view equation146 = "bac,cd,fde->bcf"; +static constexpr std::array shape146{2, 2, 2}; +static constexpr std::array expected146{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}; +static constexpr std::string_view equation147 = "bac,cd,fde->bcd"; +static constexpr std::array shape147{2, 2, 2}; +static constexpr std::array expected147{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}; +static constexpr std::string_view equation148 = "bac,cd,fed->baf"; +static constexpr std::array shape148{2, 2, 2}; +static constexpr std::array expected148{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}; +static constexpr std::string_view equation149 = "bac,cd,fed->bae"; +static constexpr std::array shape149{2, 2, 2}; +static constexpr std::array expected149{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}; +static constexpr std::string_view equation150 = "bac,cd,fed->bcf"; +static constexpr std::array shape150{2, 2, 2}; +static constexpr std::array expected150{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}; +static constexpr std::string_view equation151 = "bac,cd,fed->bce"; +static constexpr std::array shape151{2, 2, 2}; +static constexpr std::array expected151{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}; +static constexpr std::string_view equation152 = "bac,dc,def->bad"; +static constexpr std::array shape152{2, 2, 2}; +static constexpr std::array expected152{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation153 = "bac,dc,def->bae"; +static constexpr std::array shape153{2, 2, 2}; +static constexpr std::array expected153{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation154 = "bac,dc,def->bcd"; +static constexpr std::array shape154{2, 2, 2}; +static constexpr std::array expected154{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation155 = "bac,dc,def->bce"; +static constexpr std::array shape155{2, 2, 2}; +static constexpr std::array expected155{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation156 = "bac,dc,dfe->bad"; +static constexpr std::array shape156{2, 2, 2}; +static constexpr std::array expected156{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}; +static constexpr std::string_view equation157 = "bac,dc,dfe->baf"; +static constexpr std::array shape157{2, 2, 2}; +static constexpr std::array expected157{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}; +static constexpr std::string_view equation158 = "bac,dc,dfe->bcd"; +static constexpr std::array shape158{2, 2, 2}; +static constexpr std::array expected158{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}; +static constexpr std::string_view equation159 = "bac,dc,dfe->bcf"; +static constexpr std::array shape159{2, 2, 2}; +static constexpr std::array 
expected159{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}; +static constexpr std::string_view equation160 = "bac,dc,edf->bae"; +static constexpr std::array shape160{2, 2, 2}; +static constexpr std::array expected160{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation161 = "bac,dc,edf->bad"; +static constexpr std::array shape161{2, 2, 2}; +static constexpr std::array expected161{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation162 = "bac,dc,edf->bce"; +static constexpr std::array shape162{2, 2, 2}; +static constexpr std::array expected162{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation163 = "bac,dc,edf->bcd"; +static constexpr std::array shape163{2, 2, 2}; +static constexpr std::array expected163{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation164 = "bac,dc,efd->bae"; +static constexpr std::array shape164{2, 2, 2}; +static constexpr std::array expected164{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation165 = "bac,dc,efd->baf"; +static constexpr std::array shape165{2, 2, 2}; +static constexpr std::array expected165{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation166 = "bac,dc,efd->bce"; +static constexpr std::array shape166{2, 2, 2}; +static constexpr std::array expected166{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation167 = "bac,dc,efd->bcf"; +static constexpr std::array shape167{2, 2, 2}; +static constexpr std::array expected167{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation168 = "bac,dc,fde->baf"; +static constexpr std::array shape168{2, 2, 2}; +static constexpr std::array expected168{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}; +static constexpr std::string_view equation169 = "bac,dc,fde->bad"; +static constexpr std::array shape169{2, 2, 2}; +static constexpr std::array expected169{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}; +static constexpr std::string_view equation170 = "bac,dc,fde->bcf"; +static constexpr std::array shape170{2, 2, 2}; +static constexpr std::array expected170{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}; +static constexpr std::string_view equation171 = "bac,dc,fde->bcd"; +static constexpr std::array shape171{2, 2, 2}; +static constexpr std::array expected171{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}; +static constexpr std::string_view equation172 = "bac,dc,fed->baf"; +static constexpr std::array shape172{2, 2, 2}; +static constexpr std::array expected172{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}; +static constexpr std::string_view equation173 = "bac,dc,fed->bae"; +static constexpr std::array shape173{2, 2, 2}; +static constexpr std::array expected173{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}; +static constexpr std::string_view equation174 = "bac,dc,fed->bcf"; +static constexpr std::array shape174{2, 2, 2}; +static constexpr std::array expected174{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}; +static constexpr std::string_view equation175 = "bac,dc,fed->bce"; +static constexpr std::array shape175{2, 2, 2}; +static constexpr std::array expected175{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}; +static constexpr std::string_view equation176 = "bca,cd,def->bcd"; +static constexpr std::array shape176{2, 2, 2}; +static 
constexpr std::array expected176{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation177 = "bca,cd,def->bce"; +static constexpr std::array shape177{2, 2, 2}; +static constexpr std::array expected177{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation178 = "bca,cd,def->bad"; +static constexpr std::array shape178{2, 2, 2}; +static constexpr std::array expected178{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation179 = "bca,cd,def->bae"; +static constexpr std::array shape179{2, 2, 2}; +static constexpr std::array expected179{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation180 = "bca,cd,dfe->bcd"; +static constexpr std::array shape180{2, 2, 2}; +static constexpr std::array expected180{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}; +static constexpr std::string_view equation181 = "bca,cd,dfe->bcf"; +static constexpr std::array shape181{2, 2, 2}; +static constexpr std::array expected181{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}; +static constexpr std::string_view equation182 = "bca,cd,dfe->bad"; +static constexpr std::array shape182{2, 2, 2}; +static constexpr std::array expected182{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}; +static constexpr std::string_view equation183 = "bca,cd,dfe->baf"; +static constexpr std::array shape183{2, 2, 2}; +static constexpr std::array expected183{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}; +static constexpr std::string_view equation184 = "bca,cd,edf->bce"; +static constexpr std::array shape184{2, 2, 2}; +static constexpr std::array expected184{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation185 = "bca,cd,edf->bcd"; +static constexpr std::array shape185{2, 2, 2}; +static constexpr std::array expected185{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation186 = "bca,cd,edf->bae"; +static constexpr std::array shape186{2, 2, 2}; +static constexpr std::array expected186{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation187 = "bca,cd,edf->bad"; +static constexpr std::array shape187{2, 2, 2}; +static constexpr std::array expected187{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +static constexpr std::string_view equation188 = "bca,cd,efd->bce"; +static constexpr std::array shape188{2, 2, 2}; +static constexpr std::array expected188{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation189 = "bca,cd,efd->bcf"; +static constexpr std::array shape189{2, 2, 2}; +static constexpr std::array expected189{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation190 = "bca,cd,efd->bae"; +static constexpr std::array shape190{2, 2, 2}; +static constexpr std::array expected190{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation191 = "bca,cd,efd->baf"; +static constexpr std::array shape191{2, 2, 2}; +static constexpr std::array expected191{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation192 = "bca,cd,fde->bcf"; +static constexpr std::array shape192{2, 2, 2}; +static constexpr std::array expected192{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}; +static constexpr std::string_view equation193 = "bca,cd,fde->bcd"; +static constexpr std::array shape193{2, 
2, 2}; +static constexpr std::array expected193{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}; +static constexpr std::string_view equation194 = "bca,cd,fde->baf"; +static constexpr std::array shape194{2, 2, 2}; +static constexpr std::array expected194{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}; +static constexpr std::string_view equation195 = "bca,cd,fde->bad"; +static constexpr std::array shape195{2, 2, 2}; +static constexpr std::array expected195{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}; +static constexpr std::string_view equation196 = "bca,cd,fed->bcf"; +static constexpr std::array shape196{2, 2, 2}; +static constexpr std::array expected196{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}; +static constexpr std::string_view equation197 = "bca,cd,fed->bce"; +static constexpr std::array shape197{2, 2, 2}; +static constexpr std::array expected197{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}; +static constexpr std::string_view equation198 = "bca,cd,fed->baf"; +static constexpr std::array shape198{2, 2, 2}; +static constexpr std::array expected198{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}; +static constexpr std::string_view equation199 = "bca,cd,fed->bae"; +static constexpr std::array shape199{2, 2, 2}; +static constexpr std::array expected199{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}; +static constexpr std::string_view equation200 = "bca,dc,def->bcd"; +static constexpr std::array shape200{2, 2, 2}; +static constexpr std::array expected200{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +static constexpr std::string_view equation201 = "bca,dc,def->bce"; +static constexpr std::array shape201{2, 2, 2}; +static constexpr std::array expected201{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +static constexpr std::string_view equation202 = "bca,dc,def->bad"; +static constexpr std::array shape202{2, 2, 2}; +static constexpr std::array expected202{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation203 = "bca,dc,def->bae"; +static constexpr std::array shape203{2, 2, 2}; +static constexpr std::array expected203{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation204 = "bca,dc,dfe->bcd"; +static constexpr std::array shape204{2, 2, 2}; +static constexpr std::array expected204{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}; +static constexpr std::string_view equation205 = "bca,dc,dfe->bcf"; +static constexpr std::array shape205{2, 2, 2}; +static constexpr std::array expected205{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}; +static constexpr std::string_view equation206 = "bca,dc,dfe->bad"; +static constexpr std::array shape206{2, 2, 2}; +static constexpr std::array expected206{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}; +static constexpr std::string_view equation207 = "bca,dc,dfe->baf"; +static constexpr std::array shape207{2, 2, 2}; +static constexpr std::array expected207{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}; +static constexpr std::string_view equation208 = "bca,dc,edf->bce"; +static constexpr std::array shape208{2, 2, 2}; +static constexpr std::array expected208{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation209 = "bca,dc,edf->bcd"; +static constexpr std::array shape209{2, 2, 2}; +static constexpr std::array expected209{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation210 = "bca,dc,edf->bae"; +static constexpr 
std::array shape210{2, 2, 2}; +static constexpr std::array expected210{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation211 = "bca,dc,edf->bad"; +static constexpr std::array shape211{2, 2, 2}; +static constexpr std::array expected211{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation212 = "bca,dc,efd->bce"; +static constexpr std::array shape212{2, 2, 2}; +static constexpr std::array expected212{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation213 = "bca,dc,efd->bcf"; +static constexpr std::array shape213{2, 2, 2}; +static constexpr std::array expected213{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation214 = "bca,dc,efd->bae"; +static constexpr std::array shape214{2, 2, 2}; +static constexpr std::array expected214{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation215 = "bca,dc,efd->baf"; +static constexpr std::array shape215{2, 2, 2}; +static constexpr std::array expected215{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation216 = "bca,dc,fde->bcf"; +static constexpr std::array shape216{2, 2, 2}; +static constexpr std::array expected216{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}; +static constexpr std::string_view equation217 = "bca,dc,fde->bcd"; +static constexpr std::array shape217{2, 2, 2}; +static constexpr std::array expected217{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}; +static constexpr std::string_view equation218 = "bca,dc,fde->baf"; +static constexpr std::array shape218{2, 2, 2}; +static constexpr std::array expected218{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}; +static constexpr std::string_view equation219 = "bca,dc,fde->bad"; +static constexpr std::array shape219{2, 2, 2}; +static constexpr std::array expected219{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}; +static constexpr std::string_view equation220 = "bca,dc,fed->bcf"; +static constexpr std::array shape220{2, 2, 2}; +static constexpr std::array expected220{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}; +static constexpr std::string_view equation221 = "bca,dc,fed->bce"; +static constexpr std::array shape221{2, 2, 2}; +static constexpr std::array expected221{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}; +static constexpr std::string_view equation222 = "bca,dc,fed->baf"; +static constexpr std::array shape222{2, 2, 2}; +static constexpr std::array expected222{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}; +static constexpr std::string_view equation223 = "bca,dc,fed->bae"; +static constexpr std::array shape223{2, 2, 2}; +static constexpr std::array expected223{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}; +static constexpr std::string_view equation224 = "cab,cd,def->cad"; +static constexpr std::array shape224{2, 2, 2}; +static constexpr std::array expected224{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation225 = "cab,cd,def->cae"; +static constexpr std::array shape225{2, 2, 2}; +static constexpr std::array expected225{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view equation226 = "cab,cd,def->cbd"; +static constexpr std::array shape226{2, 2, 2}; +static constexpr std::array expected226{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation227 = "cab,cd,def->cbe"; 
+static constexpr std::array shape227{2, 2, 2}; +static constexpr std::array expected227{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation228 = "cab,cd,dfe->cad"; +static constexpr std::array shape228{2, 2, 2}; +static constexpr std::array expected228{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation229 = "cab,cd,dfe->caf"; +static constexpr std::array shape229{2, 2, 2}; +static constexpr std::array expected229{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view equation230 = "cab,cd,dfe->cbd"; +static constexpr std::array shape230{2, 2, 2}; +static constexpr std::array expected230{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation231 = "cab,cd,dfe->cbf"; +static constexpr std::array shape231{2, 2, 2}; +static constexpr std::array expected231{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation232 = "cab,cd,edf->cae"; +static constexpr std::array shape232{2, 2, 2}; +static constexpr std::array expected232{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation233 = "cab,cd,edf->cad"; +static constexpr std::array shape233{2, 2, 2}; +static constexpr std::array expected233{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation234 = "cab,cd,edf->cbe"; +static constexpr std::array shape234{2, 2, 2}; +static constexpr std::array expected234{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation235 = "cab,cd,edf->cbd"; +static constexpr std::array shape235{2, 2, 2}; +static constexpr std::array expected235{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation236 = "cab,cd,efd->cae"; +static constexpr std::array shape236{2, 2, 2}; +static constexpr std::array expected236{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation237 = "cab,cd,efd->caf"; +static constexpr std::array shape237{2, 2, 2}; +static constexpr std::array expected237{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation238 = "cab,cd,efd->cbe"; +static constexpr std::array shape238{2, 2, 2}; +static constexpr std::array expected238{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view equation239 = "cab,cd,efd->cbf"; +static constexpr std::array shape239{2, 2, 2}; +static constexpr std::array expected239{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation240 = "cab,cd,fde->caf"; +static constexpr std::array shape240{2, 2, 2}; +static constexpr std::array expected240{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation241 = "cab,cd,fde->cad"; +static constexpr std::array shape241{2, 2, 2}; +static constexpr std::array expected241{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation242 = "cab,cd,fde->cbf"; +static constexpr std::array shape242{2, 2, 2}; +static constexpr std::array expected242{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation243 = "cab,cd,fde->cbd"; +static constexpr std::array shape243{2, 2, 2}; +static constexpr std::array expected243{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation244 = "cab,cd,fed->caf"; 
+static constexpr std::array shape244{2, 2, 2}; +static constexpr std::array expected244{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation245 = "cab,cd,fed->cae"; +static constexpr std::array shape245{2, 2, 2}; +static constexpr std::array expected245{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation246 = "cab,cd,fed->cbf"; +static constexpr std::array shape246{2, 2, 2}; +static constexpr std::array expected246{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view equation247 = "cab,cd,fed->cbe"; +static constexpr std::array shape247{2, 2, 2}; +static constexpr std::array expected247{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation248 = "cab,dc,def->cad"; +static constexpr std::array shape248{2, 2, 2}; +static constexpr std::array expected248{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation249 = "cab,dc,def->cae"; +static constexpr std::array shape249{2, 2, 2}; +static constexpr std::array expected249{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation250 = "cab,dc,def->cbd"; +static constexpr std::array shape250{2, 2, 2}; +static constexpr std::array expected250{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; + +static constexpr std::string_view equation251 = "cab,dc,def->cbe"; +static constexpr std::array shape251{2, 2, 2}; +static constexpr std::array expected251{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation252 = "cab,dc,dfe->cad"; +static constexpr std::array shape252{2, 2, 2}; +static constexpr std::array expected252{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation253 = "cab,dc,dfe->caf"; +static constexpr std::array shape253{2, 2, 2}; +static constexpr std::array expected253{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation254 = "cab,dc,dfe->cbd"; +static constexpr std::array shape254{2, 2, 2}; +static constexpr std::array expected254{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +static constexpr std::string_view equation255 = "cab,dc,dfe->cbf"; +static constexpr std::array shape255{2, 2, 2}; +static constexpr std::array expected255{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation256 = "cab,dc,edf->cae"; +static constexpr std::array shape256{2, 2, 2}; +static constexpr std::array expected256{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation257 = "cab,dc,edf->cad"; +static constexpr std::array shape257{2, 2, 2}; +static constexpr std::array expected257{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation258 = "cab,dc,edf->cbe"; +static constexpr std::array shape258{2, 2, 2}; +static constexpr std::array expected258{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation259 = "cab,dc,edf->cbd"; +static constexpr std::array shape259{2, 2, 2}; +static constexpr std::array expected259{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation260 = "cab,dc,efd->cae"; +static constexpr std::array shape260{2, 2, 2}; +static constexpr std::array expected260{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation261 = 
"cab,dc,efd->caf"; +static constexpr std::array shape261{2, 2, 2}; +static constexpr std::array expected261{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation262 = "cab,dc,efd->cbe"; +static constexpr std::array shape262{2, 2, 2}; +static constexpr std::array expected262{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation263 = "cab,dc,efd->cbf"; +static constexpr std::array shape263{2, 2, 2}; +static constexpr std::array expected263{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr std::string_view equation264 = "cab,dc,fde->caf"; +static constexpr std::array shape264{2, 2, 2}; +static constexpr std::array expected264{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation265 = "cab,dc,fde->cad"; +static constexpr std::array shape265{2, 2, 2}; +static constexpr std::array expected265{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation266 = "cab,dc,fde->cbf"; +static constexpr std::array shape266{2, 2, 2}; +static constexpr std::array expected266{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation267 = "cab,dc,fde->cbd"; +static constexpr std::array shape267{2, 2, 2}; +static constexpr std::array expected267{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation268 = "cab,dc,fed->caf"; +static constexpr std::array shape268{2, 2, 2}; +static constexpr std::array expected268{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation269 = "cab,dc,fed->cae"; +static constexpr std::array shape269{2, 2, 2}; +static constexpr std::array expected269{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation270 = "cab,dc,fed->cbf"; +static constexpr std::array shape270{2, 2, 2}; +static constexpr std::array expected270{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation271 = "cab,dc,fed->cbe"; +static constexpr std::array shape271{2, 2, 2}; +static constexpr std::array expected271{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr std::string_view equation272 = "cba,cd,def->cbd"; +static constexpr std::array shape272{2, 2, 2}; +static constexpr std::array expected272{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation273 = "cba,cd,def->cbe"; +static constexpr std::array shape273{2, 2, 2}; +static constexpr std::array expected273{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view equation274 = "cba,cd,def->cad"; +static constexpr std::array shape274{2, 2, 2}; +static constexpr std::array expected274{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation275 = "cba,cd,def->cae"; +static constexpr std::array shape275{2, 2, 2}; +static constexpr std::array expected275{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation276 = "cba,cd,dfe->cbd"; +static constexpr std::array shape276{2, 2, 2}; +static constexpr std::array expected276{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}; +static constexpr std::string_view equation277 = "cba,cd,dfe->cbf"; +static constexpr std::array shape277{2, 2, 2}; +static constexpr std::array expected277{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}; +static constexpr std::string_view 
equation278 = "cba,cd,dfe->cad"; +static constexpr std::array shape278{2, 2, 2}; +static constexpr std::array expected278{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}; +static constexpr std::string_view equation279 = "cba,cd,dfe->caf"; +static constexpr std::array shape279{2, 2, 2}; +static constexpr std::array expected279{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}; +static constexpr std::string_view equation280 = "cba,cd,edf->cbe"; +static constexpr std::array shape280{2, 2, 2}; +static constexpr std::array expected280{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation281 = "cba,cd,edf->cbd"; +static constexpr std::array shape281{2, 2, 2}; +static constexpr std::array expected281{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation282 = "cba,cd,edf->cae"; +static constexpr std::array shape282{2, 2, 2}; +static constexpr std::array expected282{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation283 = "cba,cd,edf->cad"; +static constexpr std::array shape283{2, 2, 2}; +static constexpr std::array expected283{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation284 = "cba,cd,efd->cbe"; +static constexpr std::array shape284{2, 2, 2}; +static constexpr std::array expected284{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation285 = "cba,cd,efd->cbf"; +static constexpr std::array shape285{2, 2, 2}; +static constexpr std::array expected285{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation286 = "cba,cd,efd->cae"; +static constexpr std::array shape286{2, 2, 2}; +static constexpr std::array expected286{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view equation287 = "cba,cd,efd->caf"; +static constexpr std::array shape287{2, 2, 2}; +static constexpr std::array expected287{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation288 = "cba,cd,fde->cbf"; +static constexpr std::array shape288{2, 2, 2}; +static constexpr std::array expected288{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}; +static constexpr std::string_view equation289 = "cba,cd,fde->cbd"; +static constexpr std::array shape289{2, 2, 2}; +static constexpr std::array expected289{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}; +static constexpr std::string_view equation290 = "cba,cd,fde->caf"; +static constexpr std::array shape290{2, 2, 2}; +static constexpr std::array expected290{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}; +static constexpr std::string_view equation291 = "cba,cd,fde->cad"; +static constexpr std::array shape291{2, 2, 2}; +static constexpr std::array expected291{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}; +static constexpr std::string_view equation292 = "cba,cd,fed->cbf"; +static constexpr std::array shape292{2, 2, 2}; +static constexpr std::array expected292{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}; +static constexpr std::string_view equation293 = "cba,cd,fed->cbe"; +static constexpr std::array shape293{2, 2, 2}; +static constexpr std::array expected293{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}; +static constexpr std::string_view equation294 = "cba,cd,fed->caf"; +static constexpr std::array shape294{2, 2, 2}; +static constexpr std::array expected294{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}; +static constexpr std::string_view 
equation295 = "cba,cd,fed->cae"; +static constexpr std::array shape295{2, 2, 2}; +static constexpr std::array expected295{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}; +static constexpr std::string_view equation296 = "cba,dc,def->cbd"; +static constexpr std::array shape296{2, 2, 2}; +static constexpr std::array expected296{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation297 = "cba,dc,def->cbe"; +static constexpr std::array shape297{2, 2, 2}; +static constexpr std::array expected297{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation298 = "cba,dc,def->cad"; +static constexpr std::array shape298{2, 2, 2}; +static constexpr std::array expected298{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +static constexpr std::string_view equation299 = "cba,dc,def->cae"; +static constexpr std::array shape299{2, 2, 2}; +static constexpr std::array expected299{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation300 = "cba,dc,dfe->cbd"; +static constexpr std::array shape300{2, 2, 2}; +static constexpr std::array expected300{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}; +static constexpr std::string_view equation301 = "cba,dc,dfe->cbf"; +static constexpr std::array shape301{2, 2, 2}; +static constexpr std::array expected301{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}; +static constexpr std::string_view equation302 = "cba,dc,dfe->cad"; +static constexpr std::array shape302{2, 2, 2}; +static constexpr std::array expected302{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}; +static constexpr std::string_view equation303 = "cba,dc,dfe->caf"; +static constexpr std::array shape303{2, 2, 2}; +static constexpr std::array expected303{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}; +static constexpr std::string_view equation304 = "cba,dc,edf->cbe"; +static constexpr std::array shape304{2, 2, 2}; +static constexpr std::array expected304{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation305 = "cba,dc,edf->cbd"; +static constexpr std::array shape305{2, 2, 2}; +static constexpr std::array expected305{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation306 = "cba,dc,edf->cae"; +static constexpr std::array shape306{2, 2, 2}; +static constexpr std::array expected306{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation307 = "cba,dc,edf->cad"; +static constexpr std::array shape307{2, 2, 2}; +static constexpr std::array expected307{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation308 = "cba,dc,efd->cbe"; +static constexpr std::array shape308{2, 2, 2}; +static constexpr std::array expected308{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation309 = "cba,dc,efd->cbf"; +static constexpr std::array shape309{2, 2, 2}; +static constexpr std::array expected309{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation310 = "cba,dc,efd->cae"; +static constexpr std::array shape310{2, 2, 2}; +static constexpr std::array expected310{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation311 = "cba,dc,efd->caf"; +static constexpr std::array shape311{2, 2, 2}; +static constexpr std::array expected311{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr 
std::string_view equation312 = "cba,dc,fde->cbf"; +static constexpr std::array shape312{2, 2, 2}; +static constexpr std::array expected312{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}; +static constexpr std::string_view equation313 = "cba,dc,fde->cbd"; +static constexpr std::array shape313{2, 2, 2}; +static constexpr std::array expected313{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}; +static constexpr std::string_view equation314 = "cba,dc,fde->caf"; +static constexpr std::array shape314{2, 2, 2}; +static constexpr std::array expected314{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}; +static constexpr std::string_view equation315 = "cba,dc,fde->cad"; +static constexpr std::array shape315{2, 2, 2}; +static constexpr std::array expected315{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}; +static constexpr std::string_view equation316 = "cba,dc,fed->cbf"; +static constexpr std::array shape316{2, 2, 2}; +static constexpr std::array expected316{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}; +static constexpr std::string_view equation317 = "cba,dc,fed->cbe"; +static constexpr std::array shape317{2, 2, 2}; +static constexpr std::array expected317{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}; +static constexpr std::string_view equation318 = "cba,dc,fed->caf"; +static constexpr std::array shape318{2, 2, 2}; +static constexpr std::array expected318{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}; +static constexpr std::string_view equation319 = "cba,dc,fed->cae"; +static constexpr std::array shape319{2, 2, 2}; +static constexpr std::array expected319{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}; +static constexpr std::array case1 = {{{equation32, shape32, expected32}, + {equation33, shape33, expected33}, + {equation34, shape34, expected34}, + {equation35, shape35, expected35}, + {equation36, shape36, expected36}, + {equation37, shape37, expected37}, + {equation38, shape38, expected38}, + {equation39, shape39, expected39}, + {equation40, shape40, expected40}, + {equation41, shape41, expected41}, + {equation42, shape42, expected42}, + {equation43, shape43, expected43}, + {equation44, shape44, expected44}, + {equation45, shape45, expected45}, + {equation46, shape46, expected46}, + {equation47, shape47, expected47}, + {equation48, shape48, expected48}, + {equation49, shape49, expected49}, + {equation50, shape50, expected50}, + {equation51, shape51, expected51}, + {equation52, shape52, expected52}, + {equation53, shape53, expected53}, + {equation54, shape54, expected54}, + {equation55, shape55, expected55}, + {equation56, shape56, expected56}, + {equation57, shape57, expected57}, + {equation58, shape58, expected58}, + {equation59, shape59, expected59}, + {equation60, shape60, expected60}, + {equation61, shape61, expected61}, + {equation62, shape62, expected62}, + {equation63, shape63, expected63}, + {equation64, shape64, expected64}, + {equation65, shape65, expected65}, + {equation66, shape66, expected66}, + {equation67, shape67, expected67}, + {equation68, shape68, expected68}, + {equation69, shape69, expected69}, + {equation70, shape70, expected70}, + {equation71, shape71, expected71}, + {equation72, shape72, expected72}, + {equation73, shape73, expected73}, + {equation74, shape74, expected74}, + {equation75, shape75, expected75}, + {equation76, shape76, expected76}, + {equation77, shape77, expected77}, + {equation78, shape78, expected78}, + {equation79, shape79, expected79}, + {equation80, shape80, expected80}, + {equation81, shape81, expected81}, + 
{equation82, shape82, expected82}, + {equation83, shape83, expected83}, + {equation84, shape84, expected84}, + {equation85, shape85, expected85}, + {equation86, shape86, expected86}, + {equation87, shape87, expected87}, + {equation88, shape88, expected88}, + {equation89, shape89, expected89}, + {equation90, shape90, expected90}, + {equation91, shape91, expected91}, + {equation92, shape92, expected92}, + {equation93, shape93, expected93}, + {equation94, shape94, expected94}, + {equation95, shape95, expected95}, + {equation96, shape96, expected96}, + {equation97, shape97, expected97}, + {equation98, shape98, expected98}, + {equation99, shape99, expected99}, + {equation100, shape100, expected100}, + {equation101, shape101, expected101}, + {equation102, shape102, expected102}, + {equation103, shape103, expected103}, + {equation104, shape104, expected104}, + {equation105, shape105, expected105}, + {equation106, shape106, expected106}, + {equation107, shape107, expected107}, + {equation108, shape108, expected108}, + {equation109, shape109, expected109}, + {equation110, shape110, expected110}, + {equation111, shape111, expected111}, + {equation112, shape112, expected112}, + {equation113, shape113, expected113}, + {equation114, shape114, expected114}, + {equation115, shape115, expected115}, + {equation116, shape116, expected116}, + {equation117, shape117, expected117}, + {equation118, shape118, expected118}, + {equation119, shape119, expected119}, + {equation120, shape120, expected120}, + {equation121, shape121, expected121}, + {equation122, shape122, expected122}, + {equation123, shape123, expected123}, + {equation124, shape124, expected124}, + {equation125, shape125, expected125}, + {equation126, shape126, expected126}, + {equation127, shape127, expected127}, + {equation128, shape128, expected128}, + {equation129, shape129, expected129}, + {equation130, shape130, expected130}, + {equation131, shape131, expected131}, + {equation132, shape132, expected132}, + {equation133, shape133, expected133}, + {equation134, shape134, expected134}, + {equation135, shape135, expected135}, + {equation136, shape136, expected136}, + {equation137, shape137, expected137}, + {equation138, shape138, expected138}, + {equation139, shape139, expected139}, + {equation140, shape140, expected140}, + {equation141, shape141, expected141}, + {equation142, shape142, expected142}, + {equation143, shape143, expected143}, + {equation144, shape144, expected144}, + {equation145, shape145, expected145}, + {equation146, shape146, expected146}, + {equation147, shape147, expected147}, + {equation148, shape148, expected148}, + {equation149, shape149, expected149}, + {equation150, shape150, expected150}, + {equation151, shape151, expected151}, + {equation152, shape152, expected152}, + {equation153, shape153, expected153}, + {equation154, shape154, expected154}, + {equation155, shape155, expected155}, + {equation156, shape156, expected156}, + {equation157, shape157, expected157}, + {equation158, shape158, expected158}, + {equation159, shape159, expected159}, + {equation160, shape160, expected160}, + {equation161, shape161, expected161}, + {equation162, shape162, expected162}, + {equation163, shape163, expected163}, + {equation164, shape164, expected164}, + {equation165, shape165, expected165}, + {equation166, shape166, expected166}, + {equation167, shape167, expected167}, + {equation168, shape168, expected168}, + {equation169, shape169, expected169}, + {equation170, shape170, expected170}, + {equation171, shape171, expected171}, + 
{equation172, shape172, expected172}, + {equation173, shape173, expected173}, + {equation174, shape174, expected174}, + {equation175, shape175, expected175}, + {equation176, shape176, expected176}, + {equation177, shape177, expected177}, + {equation178, shape178, expected178}, + {equation179, shape179, expected179}, + {equation180, shape180, expected180}, + {equation181, shape181, expected181}, + {equation182, shape182, expected182}, + {equation183, shape183, expected183}, + {equation184, shape184, expected184}, + {equation185, shape185, expected185}, + {equation186, shape186, expected186}, + {equation187, shape187, expected187}, + {equation188, shape188, expected188}, + {equation189, shape189, expected189}, + {equation190, shape190, expected190}, + {equation191, shape191, expected191}, + {equation192, shape192, expected192}, + {equation193, shape193, expected193}, + {equation194, shape194, expected194}, + {equation195, shape195, expected195}, + {equation196, shape196, expected196}, + {equation197, shape197, expected197}, + {equation198, shape198, expected198}, + {equation199, shape199, expected199}, + {equation200, shape200, expected200}, + {equation201, shape201, expected201}, + {equation202, shape202, expected202}, + {equation203, shape203, expected203}, + {equation204, shape204, expected204}, + {equation205, shape205, expected205}, + {equation206, shape206, expected206}, + {equation207, shape207, expected207}, + {equation208, shape208, expected208}, + {equation209, shape209, expected209}, + {equation210, shape210, expected210}, + {equation211, shape211, expected211}, + {equation212, shape212, expected212}, + {equation213, shape213, expected213}, + {equation214, shape214, expected214}, + {equation215, shape215, expected215}, + {equation216, shape216, expected216}, + {equation217, shape217, expected217}, + {equation218, shape218, expected218}, + {equation219, shape219, expected219}, + {equation220, shape220, expected220}, + {equation221, shape221, expected221}, + {equation222, shape222, expected222}, + {equation223, shape223, expected223}, + {equation224, shape224, expected224}, + {equation225, shape225, expected225}, + {equation226, shape226, expected226}, + {equation227, shape227, expected227}, + {equation228, shape228, expected228}, + {equation229, shape229, expected229}, + {equation230, shape230, expected230}, + {equation231, shape231, expected231}, + {equation232, shape232, expected232}, + {equation233, shape233, expected233}, + {equation234, shape234, expected234}, + {equation235, shape235, expected235}, + {equation236, shape236, expected236}, + {equation237, shape237, expected237}, + {equation238, shape238, expected238}, + {equation239, shape239, expected239}, + {equation240, shape240, expected240}, + {equation241, shape241, expected241}, + {equation242, shape242, expected242}, + {equation243, shape243, expected243}, + {equation244, shape244, expected244}, + {equation245, shape245, expected245}, + {equation246, shape246, expected246}, + {equation247, shape247, expected247}, + {equation248, shape248, expected248}, + {equation249, shape249, expected249}, + {equation250, shape250, expected250}, + {equation251, shape251, expected251}, + {equation252, shape252, expected252}, + {equation253, shape253, expected253}, + {equation254, shape254, expected254}, + {equation255, shape255, expected255}, + {equation256, shape256, expected256}, + {equation257, shape257, expected257}, + {equation258, shape258, expected258}, + {equation259, shape259, expected259}, + {equation260, shape260, 
expected260}, + {equation261, shape261, expected261}, + {equation262, shape262, expected262}, + {equation263, shape263, expected263}, + {equation264, shape264, expected264}, + {equation265, shape265, expected265}, + {equation266, shape266, expected266}, + {equation267, shape267, expected267}, + {equation268, shape268, expected268}, + {equation269, shape269, expected269}, + {equation270, shape270, expected270}, + {equation271, shape271, expected271}, + {equation272, shape272, expected272}, + {equation273, shape273, expected273}, + {equation274, shape274, expected274}, + {equation275, shape275, expected275}, + {equation276, shape276, expected276}, + {equation277, shape277, expected277}, + {equation278, shape278, expected278}, + {equation279, shape279, expected279}, + {equation280, shape280, expected280}, + {equation281, shape281, expected281}, + {equation282, shape282, expected282}, + {equation283, shape283, expected283}, + {equation284, shape284, expected284}, + {equation285, shape285, expected285}, + {equation286, shape286, expected286}, + {equation287, shape287, expected287}, + {equation288, shape288, expected288}, + {equation289, shape289, expected289}, + {equation290, shape290, expected290}, + {equation291, shape291, expected291}, + {equation292, shape292, expected292}, + {equation293, shape293, expected293}, + {equation294, shape294, expected294}, + {equation295, shape295, expected295}, + {equation296, shape296, expected296}, + {equation297, shape297, expected297}, + {equation298, shape298, expected298}, + {equation299, shape299, expected299}, + {equation300, shape300, expected300}, + {equation301, shape301, expected301}, + {equation302, shape302, expected302}, + {equation303, shape303, expected303}, + {equation304, shape304, expected304}, + {equation305, shape305, expected305}, + {equation306, shape306, expected306}, + {equation307, shape307, expected307}, + {equation308, shape308, expected308}, + {equation309, shape309, expected309}, + {equation310, shape310, expected310}, + {equation311, shape311, expected311}, + {equation312, shape312, expected312}, + {equation313, shape313, expected313}, + {equation314, shape314, expected314}, + {equation315, shape315, expected315}, + {equation316, shape316, expected316}, + {equation317, shape317, expected317}, + {equation318, shape318, expected318}, + {equation319, shape319, expected319}}}; TEST(Einsum, EinsumTransposeMatMulTwoInputsTestSuite) { - std::vector test_cases{ - EinsumTestCase("abc,cd->abc", std::vector{2, 2, 2}, std::vector{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}), - EinsumTestCase("abc,cd->abd", std::vector{2, 2, 2}, std::vector{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}), - EinsumTestCase("abc,cd->acd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}), - EinsumTestCase("abc,dc->abd", std::vector{2, 2, 2}, std::vector{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}), - EinsumTestCase("abc,dc->abc", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}), - EinsumTestCase("abc,dc->acd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}), - EinsumTestCase("acb,cd->acd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}), - EinsumTestCase("acb,cd->abc", std::vector{2, 2, 2}, std::vector{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}), - EinsumTestCase("acb,cd->abd", std::vector{2, 2, 2}, std::vector{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}), - EinsumTestCase("acb,dc->acd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 5.f, 15.f, 
0.f, 18.f, 13.f, 39.f}), - EinsumTestCase("acb,dc->abd", std::vector{2, 2, 2}, std::vector{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}), - EinsumTestCase("acb,dc->abc", std::vector{2, 2, 2}, std::vector{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}), - EinsumTestCase("bac,cd->bac", std::vector{2, 2, 2}, std::vector{0.f, 5.f, 2.f, 15.f, 4.f, 25.f, 6.f, 35.f}), - EinsumTestCase("bac,cd->bad", std::vector{2, 2, 2}, std::vector{2.f, 3.f, 6.f, 11.f, 10.f, 19.f, 14.f, 27.f}), - EinsumTestCase("bac,cd->bcd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 8.f, 12.f, 0.f, 10.f, 24.f, 36.f}), - EinsumTestCase("bac,dc->bad", std::vector{2, 2, 2}, std::vector{1.f, 3.f, 3.f, 13.f, 5.f, 23.f, 7.f, 33.f}), - EinsumTestCase("bac,dc->bac", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 8.f, 20.f, 12.f, 28.f}), - EinsumTestCase("bac,dc->bcd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 4.f, 12.f, 0.f, 20.f, 12.f, 36.f}), - EinsumTestCase("bca,cd->bcd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 10.f, 15.f, 0.f, 9.f, 26.f, 39.f}), - EinsumTestCase("bca,cd->bac", std::vector{2, 2, 2}, std::vector{0.f, 10.f, 1.f, 15.f, 4.f, 30.f, 5.f, 35.f}), - EinsumTestCase("bca,cd->bad", std::vector{2, 2, 2}, std::vector{4.f, 6.f, 6.f, 10.f, 12.f, 22.f, 14.f, 26.f}), - EinsumTestCase("bca,dc->bcd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 5.f, 15.f, 0.f, 18.f, 13.f, 39.f}), - EinsumTestCase("bca,dc->bad", std::vector{2, 2, 2}, std::vector{2.f, 6.f, 3.f, 11.f, 6.f, 26.f, 7.f, 31.f}), - EinsumTestCase("bca,dc->bac", std::vector{2, 2, 2}, std::vector{0.f, 8.f, 2.f, 12.f, 8.f, 24.f, 10.f, 28.f}), - EinsumTestCase("cab,cd->cad", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}), - EinsumTestCase("cab,cd->cbd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}), - EinsumTestCase("cab,dc->cad", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}), - EinsumTestCase("cab,dc->cbd", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f}), - EinsumTestCase("cba,cd->cbd", std::vector{2, 2, 2}, std::vector{0.f, 1.f, 0.f, 5.f, 18.f, 27.f, 26.f, 39.f}), - EinsumTestCase("cba,cd->cad", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 4.f, 20.f, 30.f, 24.f, 36.f}), - EinsumTestCase("cba,dc->cbd", std::vector{2, 2, 2}, std::vector{0.f, 2.f, 0.f, 10.f, 9.f, 27.f, 13.f, 39.f}), - EinsumTestCase("cba,dc->cad", std::vector{2, 2, 2}, std::vector{0.f, 4.f, 0.f, 8.f, 10.f, 30.f, 12.f, 36.f})}; - std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; std::vector m2{0.f, 1.f, 2.f, 3.f}; - for (const auto& tst : test_cases) { + for (const auto& tst : case0) { OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); - test.AddAttribute("equation", tst.equation); + std::string s(tst.equation); + test.AddAttribute("equation", s); test.AddInput("x", {2, 2, 2}, m1); test.AddInput("y", {2, 2}, m2); - test.AddOutput("o", tst.shape, tst.expected); + + std::vector v1(tst.shape.begin(), tst.shape.end()); + std::vector v2(tst.expected.begin(), tst.expected.end()); + test.AddOutput("o", v1, v2); test.Run(); } } -TEST(Einsum, EinsumTransposeMatMulThreeInputsTestSuite) { - std::vector test_cases_set_1{ - EinsumTestCase("abc,cd,def->abd", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("abc,cd,def->abe", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("abc,cd,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 
792.f}), - EinsumTestCase("abc,cd,def->ace", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("abc,cd,dfe->abd", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("abc,cd,dfe->abf", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("abc,cd,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), - EinsumTestCase("abc,cd,dfe->acf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("abc,cd,edf->abe", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("abc,cd,edf->abd", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("abc,cd,edf->ace", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("abc,cd,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("abc,cd,efd->abe", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), - EinsumTestCase("abc,cd,efd->abf", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("abc,cd,efd->ace", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("abc,cd,efd->acf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("abc,cd,fde->abf", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("abc,cd,fde->abd", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("abc,cd,fde->acf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("abc,cd,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("abc,cd,fed->abf", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), - EinsumTestCase("abc,cd,fed->abe", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("abc,cd,fed->acf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("abc,cd,fed->ace", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("abc,dc,def->abd", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("abc,dc,def->abe", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("abc,dc,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("abc,dc,def->ace", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("abc,dc,dfe->abd", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("abc,dc,dfe->abf", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("abc,dc,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("abc,dc,dfe->acf", std::vector{2, 2, 2}, 
std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("abc,dc,edf->abe", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("abc,dc,edf->abd", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("abc,dc,edf->ace", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("abc,dc,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("abc,dc,efd->abe", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("abc,dc,efd->abf", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("abc,dc,efd->ace", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("abc,dc,efd->acf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("abc,dc,fde->abf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("abc,dc,fde->abd", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("abc,dc,fde->acf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("abc,dc,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("abc,dc,fed->abf", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("abc,dc,fed->abe", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("abc,dc,fed->acf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("abc,dc,fed->ace", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("acb,cd,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("acb,cd,def->ace", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("acb,cd,def->abd", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("acb,cd,def->abe", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("acb,cd,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("acb,cd,dfe->acf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("acb,cd,dfe->abd", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("acb,cd,dfe->abf", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("acb,cd,edf->ace", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("acb,cd,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("acb,cd,edf->abe", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("acb,cd,edf->abd", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - 
EinsumTestCase("acb,cd,efd->ace", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("acb,cd,efd->acf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("acb,cd,efd->abe", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("acb,cd,efd->abf", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("acb,cd,fde->acf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("acb,cd,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("acb,cd,fde->abf", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("acb,cd,fde->abd", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - EinsumTestCase("acb,cd,fed->acf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("acb,cd,fed->ace", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("acb,cd,fed->abf", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("acb,cd,fed->abe", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("acb,dc,def->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("acb,dc,def->ace", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f})}; - - std::vector test_cases_set_2{ - EinsumTestCase("acb,dc,def->abd", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("acb,dc,def->abe", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("acb,dc,dfe->acd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("acb,dc,dfe->acf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), - EinsumTestCase("acb,dc,dfe->abd", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("acb,dc,dfe->abf", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("acb,dc,edf->ace", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("acb,dc,edf->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("acb,dc,edf->abe", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("acb,dc,edf->abd", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("acb,dc,efd->ace", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("acb,dc,efd->acf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("acb,dc,efd->abe", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("acb,dc,efd->abf", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("acb,dc,fde->acf", 
std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("acb,dc,fde->acd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("acb,dc,fde->abf", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("acb,dc,fde->abd", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("acb,dc,fed->acf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("acb,dc,fed->ace", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("acb,dc,fed->abf", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("acb,dc,fed->abe", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("bac,cd,def->bad", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("bac,cd,def->bae", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("bac,cd,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), - EinsumTestCase("bac,cd,def->bce", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("bac,cd,dfe->bad", std::vector{2, 2, 2}, std::vector{12.f, 66.f, 36.f, 242.f, 60.f, 418.f, 84.f, 594.f}), - EinsumTestCase("bac,cd,dfe->baf", std::vector{2, 2, 2}, std::vector{29.f, 49.f, 105.f, 173.f, 181.f, 297.f, 257.f, 421.f}), - EinsumTestCase("bac,cd,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 48.f, 264.f, 0.f, 220.f, 144.f, 792.f}), - EinsumTestCase("bac,cd,dfe->bcf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 116.f, 196.f, 90.f, 130.f, 348.f, 588.f}), - EinsumTestCase("bac,cd,edf->bae", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("bac,cd,edf->bad", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("bac,cd,edf->bce", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("bac,cd,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("bac,cd,efd->bae", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 464.f}), - EinsumTestCase("bac,cd,efd->baf", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("bac,cd,efd->bce", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("bac,cd,efd->bcf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("bac,cd,fde->baf", std::vector{2, 2, 2}, std::vector{17.f, 57.f, 61.f, 197.f, 105.f, 337.f, 149.f, 477.f}), - EinsumTestCase("bac,cd,fde->bad", std::vector{2, 2, 2}, std::vector{20.f, 54.f, 60.f, 198.f, 100.f, 342.f, 140.f, 486.f}), - EinsumTestCase("bac,cd,fde->bcf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 68.f, 228.f, 50.f, 130.f, 204.f, 684.f}), - EinsumTestCase("bac,cd,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 80.f, 216.f, 0.f, 180.f, 240.f, 648.f}), - EinsumTestCase("bac,cd,fed->baf", std::vector{2, 2, 2}, std::vector{16.f, 56.f, 56.f, 192.f, 96.f, 328.f, 136.f, 
464.f}), - EinsumTestCase("bac,cd,fed->bae", std::vector{2, 2, 2}, std::vector{26.f, 46.f, 90.f, 158.f, 154.f, 270.f, 218.f, 382.f}), - EinsumTestCase("bac,cd,fed->bcf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 64.f, 224.f, 40.f, 120.f, 192.f, 672.f}), - EinsumTestCase("bac,cd,fed->bce", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 104.f, 184.f, 60.f, 100.f, 312.f, 552.f}), - EinsumTestCase("bac,dc,def->bad", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("bac,dc,def->bae", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("bac,dc,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("bac,dc,def->bce", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("bac,dc,dfe->bad", std::vector{2, 2, 2}, std::vector{6.f, 66.f, 18.f, 286.f, 30.f, 506.f, 42.f, 726.f}), - EinsumTestCase("bac,dc,dfe->baf", std::vector{2, 2, 2}, std::vector{28.f, 44.f, 120.f, 184.f, 212.f, 324.f, 304.f, 464.f}), - EinsumTestCase("bac,dc,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 24.f, 264.f, 0.f, 440.f, 72.f, 792.f}), - EinsumTestCase("bac,dc,dfe->bcf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 112.f, 176.f, 180.f, 260.f, 336.f, 528.f}), - EinsumTestCase("bac,dc,edf->bae", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("bac,dc,edf->bad", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("bac,dc,edf->bce", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("bac,dc,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("bac,dc,efd->bae", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("bac,dc,efd->baf", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("bac,dc,efd->bce", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("bac,dc,efd->bcf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("bac,dc,fde->baf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 68.f, 196.f, 120.f, 344.f, 172.f, 492.f}), - EinsumTestCase("bac,dc,fde->bad", std::vector{2, 2, 2}, std::vector{10.f, 54.f, 30.f, 234.f, 50.f, 414.f, 70.f, 594.f}), - EinsumTestCase("bac,dc,fde->bcf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 64.f, 192.f, 100.f, 260.f, 192.f, 576.f}), - EinsumTestCase("bac,dc,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 40.f, 216.f, 0.f, 360.f, 120.f, 648.f}), - EinsumTestCase("bac,dc,fed->baf", std::vector{2, 2, 2}, std::vector{14.f, 46.f, 58.f, 186.f, 102.f, 326.f, 146.f, 466.f}), - EinsumTestCase("bac,dc,fed->bae", std::vector{2, 2, 2}, std::vector{22.f, 38.f, 90.f, 154.f, 158.f, 270.f, 226.f, 386.f}), - EinsumTestCase("bac,dc,fed->bcf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 56.f, 184.f, 80.f, 240.f, 168.f, 552.f}), - EinsumTestCase("bac,dc,fed->bce", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 88.f, 152.f, 120.f, 200.f, 264.f, 456.f}), - EinsumTestCase("bca,cd,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("bca,cd,def->bce", std::vector{2, 2, 2}, 
std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("bca,cd,def->bad", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("bca,cd,def->bae", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("bca,cd,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 60.f, 330.f, 0.f, 198.f, 156.f, 858.f}), - EinsumTestCase("bca,cd,dfe->bcf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 145.f, 245.f, 81.f, 117.f, 377.f, 637.f}), - EinsumTestCase("bca,cd,dfe->bad", std::vector{2, 2, 2}, std::vector{24.f, 132.f, 36.f, 220.f, 72.f, 484.f, 84.f, 572.f}), - EinsumTestCase("bca,cd,dfe->baf", std::vector{2, 2, 2}, std::vector{58.f, 98.f, 96.f, 160.f, 210.f, 346.f, 248.f, 408.f}), - EinsumTestCase("bca,cd,edf->bce", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("bca,cd,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("bca,cd,edf->bae", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("bca,cd,edf->bad", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - EinsumTestCase("bca,cd,efd->bce", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("bca,cd,efd->bcf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("bca,cd,efd->bae", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("bca,cd,efd->baf", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("bca,cd,fde->bcf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 85.f, 285.f, 45.f, 117.f, 221.f, 741.f}), - EinsumTestCase("bca,cd,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 100.f, 270.f, 0.f, 162.f, 260.f, 702.f}), - EinsumTestCase("bca,cd,fde->baf", std::vector{2, 2, 2}, std::vector{34.f, 114.f, 56.f, 184.f, 122.f, 394.f, 144.f, 464.f}), - EinsumTestCase("bca,cd,fde->bad", std::vector{2, 2, 2}, std::vector{40.f, 108.f, 60.f, 180.f, 120.f, 396.f, 140.f, 468.f}), - EinsumTestCase("bca,cd,fed->bcf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 80.f, 280.f, 36.f, 108.f, 208.f, 728.f}), - EinsumTestCase("bca,cd,fed->bce", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 130.f, 230.f, 54.f, 90.f, 338.f, 598.f}), - EinsumTestCase("bca,cd,fed->baf", std::vector{2, 2, 2}, std::vector{32.f, 112.f, 52.f, 180.f, 112.f, 384.f, 132.f, 452.f}), - EinsumTestCase("bca,cd,fed->bae", std::vector{2, 2, 2}, std::vector{52.f, 92.f, 84.f, 148.f, 180.f, 316.f, 212.f, 372.f}), - EinsumTestCase("bca,dc,def->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("bca,dc,def->bce", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), - EinsumTestCase("bca,dc,def->bad", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("bca,dc,def->bae", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("bca,dc,dfe->bcd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 30.f, 330.f, 0.f, 396.f, 78.f, 858.f}), - EinsumTestCase("bca,dc,dfe->bcf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 140.f, 220.f, 162.f, 234.f, 364.f, 572.f}), - 
EinsumTestCase("bca,dc,dfe->bad", std::vector{2, 2, 2}, std::vector{12.f, 132.f, 18.f, 242.f, 36.f, 572.f, 42.f, 682.f}), - EinsumTestCase("bca,dc,dfe->baf", std::vector{2, 2, 2}, std::vector{56.f, 88.f, 102.f, 158.f, 240.f, 368.f, 286.f, 438.f}), - EinsumTestCase("bca,dc,edf->bce", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("bca,dc,edf->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("bca,dc,edf->bae", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("bca,dc,edf->bad", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("bca,dc,efd->bce", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("bca,dc,efd->bcf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("bca,dc,efd->bae", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("bca,dc,efd->baf", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("bca,dc,fde->bcf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 80.f, 240.f, 90.f, 234.f, 208.f, 624.f}), - EinsumTestCase("bca,dc,fde->bcd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 50.f, 270.f, 0.f, 324.f, 130.f, 702.f}), - EinsumTestCase("bca,dc,fde->baf", std::vector{2, 2, 2}, std::vector{32.f, 96.f, 58.f, 170.f, 136.f, 392.f, 162.f, 466.f}), - EinsumTestCase("bca,dc,fde->bad", std::vector{2, 2, 2}, std::vector{20.f, 108.f, 30.f, 198.f, 60.f, 468.f, 70.f, 558.f}), - EinsumTestCase("bca,dc,fed->bcf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 70.f, 230.f, 72.f, 216.f, 182.f, 598.f}), - EinsumTestCase("bca,dc,fed->bce", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 110.f, 190.f, 108.f, 180.f, 286.f, 494.f}), - EinsumTestCase("bca,dc,fed->baf", std::vector{2, 2, 2}, std::vector{28.f, 92.f, 50.f, 162.f, 116.f, 372.f, 138.f, 442.f}), - EinsumTestCase("bca,dc,fed->bae", std::vector{2, 2, 2}, std::vector{44.f, 76.f, 78.f, 134.f, 180.f, 308.f, 214.f, 366.f}), - EinsumTestCase("cab,cd,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cab,cd,def->cae", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cab,cd,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cab,cd,def->cbe", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cab,cd,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cab,cd,dfe->caf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cab,cd,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cab,cd,dfe->cbf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cab,cd,edf->cae", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cab,cd,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cab,cd,edf->cbe", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 
52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cab,cd,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), - EinsumTestCase("cab,cd,efd->cae", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cab,cd,efd->caf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cab,cd,efd->cbe", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cab,cd,efd->cbf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cab,cd,fde->caf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cab,cd,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cab,cd,fde->cbf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cab,cd,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), - EinsumTestCase("cab,cd,fed->caf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cab,cd,fed->cae", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cab,cd,fed->cbf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cab,cd,fed->cbe", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cab,dc,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cab,dc,def->cae", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cab,dc,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f})}; - - std::vector test_cases_set_3{ - EinsumTestCase("cab,dc,def->cbe", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cab,dc,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cab,dc,dfe->caf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cab,dc,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), - EinsumTestCase("cab,dc,dfe->cbf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cab,dc,edf->cae", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cab,dc,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cab,dc,edf->cbe", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cab,dc,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cab,dc,efd->cae", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cab,dc,efd->caf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cab,dc,efd->cbe", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - 
EinsumTestCase("cab,dc,efd->cbf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), - EinsumTestCase("cab,dc,fde->caf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cab,dc,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cab,dc,fde->cbf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cab,dc,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cab,dc,fed->caf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cab,dc,fed->cae", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cab,dc,fed->cbf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - EinsumTestCase("cab,dc,fed->cbe", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), - EinsumTestCase("cba,cd,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cba,cd,def->cbe", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cba,cd,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cba,cd,def->cae", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cba,cd,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 22.f, 0.f, 110.f, 108.f, 594.f, 156.f, 858.f}), - EinsumTestCase("cba,cd,dfe->cbf", std::vector{2, 2, 2}, std::vector{9.f, 13.f, 45.f, 65.f, 261.f, 441.f, 377.f, 637.f}), - EinsumTestCase("cba,cd,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 88.f, 120.f, 660.f, 144.f, 792.f}), - EinsumTestCase("cba,cd,dfe->caf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 36.f, 52.f, 290.f, 490.f, 348.f, 588.f}), - EinsumTestCase("cba,cd,edf->cbe", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cba,cd,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cba,cd,edf->cae", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cba,cd,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 540.f, 240.f, 648.f}), - EinsumTestCase("cba,cd,efd->cbe", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cba,cd,efd->cbf", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cba,cd,efd->cae", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cba,cd,efd->caf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cba,cd,fde->cbf", std::vector{2, 2, 2}, std::vector{5.f, 13.f, 25.f, 65.f, 153.f, 513.f, 221.f, 741.f}), - EinsumTestCase("cba,cd,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 18.f, 0.f, 90.f, 180.f, 486.f, 260.f, 702.f}), - EinsumTestCase("cba,cd,fde->caf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 20.f, 52.f, 170.f, 570.f, 204.f, 684.f}), - EinsumTestCase("cba,cd,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 72.f, 200.f, 
540.f, 240.f, 648.f}), - EinsumTestCase("cba,cd,fed->cbf", std::vector{2, 2, 2}, std::vector{4.f, 12.f, 20.f, 60.f, 144.f, 504.f, 208.f, 728.f}), - EinsumTestCase("cba,cd,fed->cbe", std::vector{2, 2, 2}, std::vector{6.f, 10.f, 30.f, 50.f, 234.f, 414.f, 338.f, 598.f}), - EinsumTestCase("cba,cd,fed->caf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 16.f, 48.f, 160.f, 560.f, 192.f, 672.f}), - EinsumTestCase("cba,cd,fed->cae", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 24.f, 40.f, 260.f, 460.f, 312.f, 552.f}), - EinsumTestCase("cba,dc,def->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cba,dc,def->cbe", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cba,dc,def->cad", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), - EinsumTestCase("cba,dc,def->cae", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cba,dc,dfe->cbd", std::vector{2, 2, 2}, std::vector{0.f, 44.f, 0.f, 220.f, 54.f, 594.f, 78.f, 858.f}), - EinsumTestCase("cba,dc,dfe->cbf", std::vector{2, 2, 2}, std::vector{18.f, 26.f, 90.f, 130.f, 252.f, 396.f, 364.f, 572.f}), - EinsumTestCase("cba,dc,dfe->cad", std::vector{2, 2, 2}, std::vector{0.f, 88.f, 0.f, 176.f, 60.f, 660.f, 72.f, 792.f}), - EinsumTestCase("cba,dc,dfe->caf", std::vector{2, 2, 2}, std::vector{36.f, 52.f, 72.f, 104.f, 280.f, 440.f, 336.f, 528.f}), - EinsumTestCase("cba,dc,edf->cbe", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cba,dc,edf->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cba,dc,edf->cae", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cba,dc,edf->cad", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cba,dc,efd->cbe", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cba,dc,efd->cbf", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cba,dc,efd->cae", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - EinsumTestCase("cba,dc,efd->caf", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f}), - EinsumTestCase("cba,dc,fde->cbf", std::vector{2, 2, 2}, std::vector{10.f, 26.f, 50.f, 130.f, 144.f, 432.f, 208.f, 624.f}), - EinsumTestCase("cba,dc,fde->cbd", std::vector{2, 2, 2}, std::vector{0.f, 36.f, 0.f, 180.f, 90.f, 486.f, 130.f, 702.f}), - EinsumTestCase("cba,dc,fde->caf", std::vector{2, 2, 2}, std::vector{20.f, 52.f, 40.f, 104.f, 160.f, 480.f, 192.f, 576.f}), - EinsumTestCase("cba,dc,fde->cad", std::vector{2, 2, 2}, std::vector{0.f, 72.f, 0.f, 144.f, 100.f, 540.f, 120.f, 648.f}), - EinsumTestCase("cba,dc,fed->cbf", std::vector{2, 2, 2}, std::vector{8.f, 24.f, 40.f, 120.f, 126.f, 414.f, 182.f, 598.f}), - EinsumTestCase("cba,dc,fed->cbe", std::vector{2, 2, 2}, std::vector{12.f, 20.f, 60.f, 100.f, 198.f, 342.f, 286.f, 494.f}), - EinsumTestCase("cba,dc,fed->caf", std::vector{2, 2, 2}, std::vector{16.f, 48.f, 32.f, 96.f, 140.f, 460.f, 168.f, 552.f}), - EinsumTestCase("cba,dc,fed->cae", std::vector{2, 2, 2}, std::vector{24.f, 40.f, 48.f, 80.f, 220.f, 380.f, 264.f, 456.f})}; - - auto test_lambda = [](const std::vector& test_cases_set) { - 
std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; - std::vector m2{0.f, 1.f, 2.f, 3.f}; - std::vector m3{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; - for (const auto& tst : test_cases_set) { - OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); - test.AddAttribute("equation", tst.equation); - test.AddInput("x", {2, 2, 2}, m1); - test.AddInput("y", {2, 2}, m2); - test.AddInput("z", {2, 2, 2}, m3); - test.AddOutput("o", tst.shape, tst.expected); - test.Run(); - } - }; - - test_lambda(test_cases_set_1); - test_lambda(test_cases_set_2); - test_lambda(test_cases_set_3); +class EinsumTransposeMatMulThreeInputsTest : public testing::TestWithParam { +}; -} // namespace test +TEST_P(EinsumTransposeMatMulThreeInputsTest, EinsumTransposeMatMulThreeInputsTestSuite) { + const auto& tst = GetParam(); + std::vector m1{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + std::vector m2{0.f, 1.f, 2.f, 3.f}; + std::vector m3{0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f}; + OpTester test("Einsum", 12, onnxruntime::kOnnxDomain); + std::string s(tst.equation); + test.AddAttribute("equation", s); + test.AddInput("x", {2, 2, 2}, m1); + test.AddInput("y", {2, 2}, m2); + test.AddInput("z", {2, 2, 2}, m3); + std::vector v1(tst.shape.begin(), tst.shape.end()); + std::vector v2(tst.expected.begin(), tst.expected.end()); + test.AddOutput("o", v1, v2); + test.Run(); +} + +INSTANTIATE_TEST_SUITE_P(EinsumTransposeMatMulThreeInputsTests, EinsumTransposeMatMulThreeInputsTest, testing::ValuesIn(case1)); } // namespace test } // namespace onnxruntime From 90883a366ae2ce5402c8b887638a8b2ae3e0efdd Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 31 Jan 2024 08:28:53 +0800 Subject: [PATCH 003/207] [js/webgpu] Add hardSigmoid activation for fusedConv (#19233) ### Description Add HardSigmoid activation for fusedConv. It will be used by the mobilenetv3-small-100 model.
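For context, HardSigmoid is `y = max(0, min(1, alpha * x + beta))`, with `alpha`/`beta` taken from the `activation_params` attribute and defaulting to `[0.2, 0.5]` when the attribute is absent. A minimal TypeScript sketch of that epilogue (the `hardSigmoid` helper name and the sample inputs below are illustrative only, not part of this patch):

```ts
// Reference-only sketch of the HardSigmoid epilogue that the fused WGSL
// snippet applies per output element; alpha/beta come from the
// `activation_params` attribute and default to [0.2, 0.5].
const hardSigmoid = (x: number, alpha = 0.2, beta = 0.5): number =>
  Math.max(0, Math.min(1, alpha * x + beta));

// With activation_params = [2.0, 5.0], as used in the added test cases,
// inputs >= -2 saturate to 1 and inputs <= -2.5 clamp to 0.
console.log(hardSigmoid(1, 2.0, 5.0));   // 1
console.log(hardSigmoid(-10, 2.0, 5.0)); // 0
```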
--- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 11 +- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 11 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 12 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 37 +++-- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 35 ++++- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 12 +- js/web/test/data/ops/fused-conv.jsonc | 144 ++++++++++++++++++ .../core/optimizer/conv_activation_fusion.cc | 2 +- 8 files changed, 207 insertions(+), 57 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 1a03621512888..e5ca3204d4433 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvAttributes} from '../conv'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -193,10 +193,7 @@ export const createConv2DMatMulProgramInfo = {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; @@ -212,9 +209,7 @@ export const createConv2DMatMulProgramInfo = {name: 'pad', type: 'i32', length: 2}, {name: 'stride', type: 'i32', length: 2}, {name: 'dilation', type: 'i32', length: 2} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); // TODO: support component 2, 3. const components = isVec4 ? 
4 : 1; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index 33e50a9a39cb9..e50733559dbe9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -24,7 +24,7 @@ import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; -import {getActivationSnippet} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; import {biasSnippet, typeSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; @@ -201,10 +201,7 @@ export const createConv2DTransposeMatMulProgramInfo = {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, {type: 'int32', data: filterDims}, {type: 'int32', data: pads} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); @@ -237,9 +234,7 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'filter_dims', type: 'i32', length: filterDims.length}, {name: 'pads', type: 'i32', length: pads.length} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 5881c055ef135..00c1f86d67419 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -23,7 +23,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; import {createTensorShapeVariables, getBroadcastDims, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; -import {getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -449,11 +449,7 @@ export const createMatmulProgramInfo = const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; const programUniforms: ProgramUniform[] = [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } + 
appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), ...createTensorShapeVariables(bShapeTemp)); @@ -481,9 +477,7 @@ export const createMatmulProgramInfo = } const uniforms: UniformsArrayType = [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(activationAttributes, uniforms); const applyActivation = getActivationSnippet(activationAttributes, output.type.value); const declareFunctions = matMulReadWriteFnSource( components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index f81d6577890c5..c0aaaa7ce134b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -7,7 +7,7 @@ import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../ import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; -import {getActivationSnippet} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; /** * naive grouped conv implementation, supports 1d/2d conv @@ -32,10 +32,7 @@ export const createGroupedConvProgramInfo = {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} ]; - if (attributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); - } + appendActivationUniformsData(attributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShape)); @@ -61,9 +58,7 @@ export const createGroupedConvProgramInfo = {name: 'strides', type: 'u32', length: 2}, {name: 'pads', type: 'u32', length: 2}, {name: 'output_channels_per_group', type: 'u32'} ]; - if (attributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(attributes, uniforms); return ` ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} @@ -132,10 +127,13 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.pads}, ...createTensorShapeVariables(xShape), - ...createTensorShapeVariables(wShape), ...createTensorShapeVariables(outputShapeInShader) + {type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]}, + {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]} ]; + appendActivationUniformsData(attributes, programUniforms); + programUniforms.push( + ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), + 
...createTensorShapeVariables(outputShapeInShader)); const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); @@ -147,13 +145,14 @@ export const createGroupedConvVectorizeProgramInfo = inputVars.push(inputVariable('b', inputs[2].dataType, inputs[2].dims, components)); } const processBias = hasBias ? 'value += b[output_channel];' : ''; - + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'strides', type: 'i32', length: 2}, + {name: 'pads', type: 'i32', length: 2}, + ]; + appendActivationUniforms(attributes, uniforms); return ` - ${ - shaderHelper.registerUniform('output_size', 'u32') - .registerUniform('strides', 'i32', 2) - .registerUniform('pads', 'i32', 2) - .declareVariables(...inputVars, output)} + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVars, output)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} let width0 = uniforms.output_shape[3]; @@ -173,7 +172,7 @@ export const createGroupedConvVectorizeProgramInfo = // Use constant instead of uniform can give better performance for w's height/width. for (var w_height: u32 = 0u; w_height < ${wShape[0]}; w_height++) { let x_height = x_corner.x + i32(w_height); - if (x_height >= 0 || u32(x_height) < uniforms.x_shape[1]) { + if (x_height >= 0 && u32(x_height) < uniforms.x_shape[1]) { for (var i = 0; i < ${xNumber}; i++) { let x_width = x_corner.y + i; if (x_width >= 0 && u32(x_width) < uniforms.x_shape[2]) { @@ -185,7 +184,7 @@ export const createGroupedConvVectorizeProgramInfo = for (var w_width: u32 = 0u; w_width < ${wShape[1]}; w_width++) { let w_val = ${w.get('w_height', 'w_width', '0', 'output_channel')}; for (var i = 0u; i < ${outputNumber}u; i++) { - values[i] = fma(x_vals[i * ${attributes.strides[1]}u + w_width], w_val, values[i]); + values[i] = fma(x_vals[i * u32(uniforms.strides[1]) + w_width], w_val, values[i]); } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 2e0aa33a957dc..e1dc9a5e0ab7d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -2,11 +2,16 @@ // Licensed under the MIT License. import {MAX_CLIP, MIN_CLIP} from '../../util'; +import {ProgramUniform} from '../types'; + +import {UniformsArrayType} from './common'; export interface InternalActivationAttributes { readonly activation: string; readonly clipMin?: number; readonly clipMax?: number; + readonly alpha?: number; + readonly beta?: number; } export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { @@ -17,17 +22,41 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; case 'Clip': return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${ + valueType}(uniforms.beta)));`; + case '': + return ''; // TODO: adding other activations that can be fused. 
default: - return ''; + throw new Error(`Unsupported activation ${attributes.activation}`); + } +}; + +export const appendActivationUniformsData = + (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { + if (attributes.activation === 'Clip') { + programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + } else if (attributes.activation === 'HardSigmoid') { + programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!}); + } + }; + +export const appendActivationUniforms = (attributes: InternalActivationAttributes, uniforms: UniformsArrayType) => { + if (attributes.activation === 'Clip') { + uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); + } else if (attributes.activation === 'HardSigmoid') { + uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); } }; export const parseInternalActivationAttributes = (attributes: Record|undefined): InternalActivationAttributes => { const activation = attributes?.activation as string || ''; - - if (activation === 'Clip') { + if (activation === 'HardSigmoid') { + const [alpha, beta] = attributes?.activation_params as [number, number] || [0.2, 0.5]; + return {activation, alpha, beta}; + } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; return {activation, clipMax, clipMin}; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index c946ea6366123..188b88b2510d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -7,7 +7,7 @@ import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; -import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; +import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[], @@ -32,11 +32,7 @@ export const createNaiveMatmulProgramInfo = {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K} ]; - if (activationAttributes.activation === 'Clip') { - programUniforms.push( - {type: 'float32', data: activationAttributes.clipMax!}, - {type: 'float32', data: activationAttributes.clipMin!}); - } + appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), ...createTensorShapeVariables(bShape)); @@ -69,9 +65,7 @@ export const createNaiveMatmulProgramInfo = {name: 'output_size', type: 'u32'}, {name: 'M', type: 'u32'}, {name: 'N', type: 'u32'}, {name: 'K', type: 'u32'} ]; - if (activationAttributes.activation === 'Clip') { - uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); - } + appendActivationUniforms(activationAttributes, uniforms); const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => { const rank = variable.rank; diff --git 
a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index ad1c0a72c11d3..c734d6db9b92a 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -142,5 +142,149 @@ ] } ] + }, + { + "name": "fused conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 0, 1, 1], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused group-conv with HardSigmoid", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 0, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with HardSigmoid", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "HardSigmoid", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0, 5.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] } ] diff --git 
a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc index d27603e4ab3a1..b7cb3ba488c62 100644 --- a/onnxruntime/core/optimizer/conv_activation_fusion.cc +++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc @@ -111,7 +111,7 @@ class ConvActivationSelector : public NodeSelector { if (!graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "Relu", {6, 13, 14})) { return std::nullopt; } - } else if (node_ep.empty() || node_ep == kCpuExecutionProvider) { + } else if (node_ep.empty() || node_ep == kCpuExecutionProvider || node_ep == kJsExecutionProvider) { if (!is_supported_non_cuda_rocm_ep_activation(*next_node) && !graph_utils::IsSupportedOptypeVersionAndDomain(*next_node, "HardSigmoid", {6})) { return std::nullopt; From 0c38e96bb558295d68fadf33b6f835e4a181b8a0 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 30 Jan 2024 17:19:08 -0800 Subject: [PATCH 004/207] [Quant tool] Ensure MSFT opset for Q/DQ models (#19335) ### Description Updates qdq quantization to ensure the final model has the `com.microsoft` opset import if the model uses Q/DQ ops with the `com.microsoft` domain (e.g., for int16 quantization) ### Motivation and Context Need to ensure the MSFT domain is correctly set for all relevant cases. Otherwise, shape inferencing tools will raise an exception. --- onnxruntime/python/tools/quantization/qdq_quantizer.py | 2 ++ onnxruntime/test/python/quantization/test_qdq.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index b0153aed766ad..123cfe913d6e2 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -270,6 +270,8 @@ def quantize_model(self): self.model.model.producer_name = __producer__ self.model.model.producer_version = __version__ + if self.qdq_op_domain == ms_domain: + self.model.set_opset_import(ms_domain, 1) return self.model.model diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 4de797400836f..223f405e8947a 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -601,6 +601,13 @@ def verify_qdq(self, per_channel, activation_type, weight_type, extra_options=No ) check_model_correctness(self, model_fp32_path, model_qdq_path, data_reader.get_next()) + # If the model uses Q/DQ ops with "com.microsoft" domain (e.g., for int16 support), + # then ensure the model has the appropriate opset import. + if extra_options and extra_options.get("UseQDQContribOps", False): + qdq_model = onnx.load_model(model_qdq_path) + ms_opset = next((opset for opset in qdq_model.opset_import if opset.domain == "com.microsoft"), None) + self.assertIsNot(ms_opset, None) + def verify_qop(self, per_channel, is_quant_type_int8): np.random.seed(1) model_fp32_path = str(Path(self._tmp_model_dir.name) / f"conv_relu_fp32.{per_channel}.onnx") From e74f141338547d7eea9c6cbe23d1c892174163cf Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 31 Jan 2024 09:39:27 +0800 Subject: [PATCH 005/207] Save stablediffusion and open-clip in pipeline cache (#19314) ### Description 1. Save the model to the pipeline cache 2. Lower the similarity threshold to 97 3.
Publish the generated image so that we can check it when the test fails ### Motivation and Context Reduce model downloads --- .../models/stable_diffusion/demo_utils.py | 6 ++- .../models/stable_diffusion/engine_builder.py | 8 +++- .../stable_diffusion/test/check_image.py | 19 +++++--- .../azure-pipelines/bigmodels-ci-pipeline.yml | 47 ++++++++++++++++--- 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py index 32c673416fce2..369f31511faca 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -242,6 +242,8 @@ def parse_arguments(is_xl: bool, parser): parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.") parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + parser.add_argument("--framework-model-dir", default=None, help="framework model directory") + group = parser.add_argument_group("Options for ORT_CUDA engine only") group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.") @@ -406,6 +408,7 @@ def initialize_pipeline( lora_scale=1.0, use_fp16_vae=True, use_vae=True, + framework_model_dir=None, ): pipeline_info = PipelineInfo( version, @@ -425,7 +428,7 @@ def initialize_pipeline( input_engine_dir = engine_dir onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( - work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type + work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir ) pipeline = StableDiffusionPipeline( @@ -558,6 +561,7 @@ def load_pipelines(args, batch_size=None): "lora_scale": args.lora_scale, "use_fp16_vae": "xl" in args.version, "use_vae": True, + "framework_model_dir": args.framework_model_dir, } if "xl" in args.version: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index 46a83f5dc228d..c03c6f0b21cd3 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -5,6 +5,7 @@ import hashlib import os from enum import Enum +from typing import Optional import torch from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL @@ -273,7 +274,9 @@ def vae_decode(self, latents): return self._vae_decode(latents) -def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): +def get_engine_paths( + work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: Optional[str] = None +): root_dir = work_dir or "." short_name = pipeline_info.short_name() @@ -287,6 +290,7 @@ def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: En # Shared among ORT_CUDA, ORT_TRT and TRT engines, and need use load_model(..., always_download_fp16=True) # So that the shared model is always fp16.
- framework_model_dir = os.path.join(root_dir, "torch_model") + if framework_model_dir is None: + framework_model_dir = os.path.join(root_dir, "torch_model") return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py index fcfe8b081fb0a..da7f47b144b9b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py @@ -1,5 +1,6 @@ import argparse import os +from typing import Optional import cv2 import open_clip @@ -12,13 +13,16 @@ def arg_parser(): parser = argparse.ArgumentParser(description="Options for Compare 2 image") parser.add_argument("--image1", type=str, help="Path to image 1") parser.add_argument("--image2", type=str, help="Path to image 2") + parser.add_argument("--cache_dir", type=str, help="Path to model cache directory") args = parser.parse_args() return args -def image_encoder(img: Image.Image): # -> torch.Tensor: +def image_encoder(img: Image.Image, cache_dir: Optional[str] = None): # -> torch.Tensor: device = "cuda" if torch.cuda.is_available() else "cpu" - model, _, preprocess = open_clip.create_model_and_transforms("ViT-B-16-plus-240", pretrained="laion400m_e32") + model, _, preprocess = open_clip.create_model_and_transforms( + "ViT-B-16-plus-240", pretrained="laion400m_e32", cache_dir=cache_dir + ) model.to(device) img1 = Image.fromarray(img).convert("RGB") @@ -41,11 +45,11 @@ def load_image(image_path: str): # -> Image.Image: return img -def generate_score(image1: str, image2: str): # -> float: +def generate_score(image1: str, image2: str, cache_dir: Optional[str] = None): # -> float: test_img = load_image(image1) data_img = load_image(image2) - img1 = image_encoder(test_img) - img2 = image_encoder(data_img) + img1 = image_encoder(test_img, cache_dir) + img2 = image_encoder(data_img, cache_dir) cos_scores = util.pytorch_cos_sim(img1, img2) score = round(float(cos_scores[0][0]) * 100, 2) return score @@ -55,9 +59,10 @@ def main(): args = arg_parser() image1 = args.image1 image2 = args.image2 - score = round(generate_score(image1, image2), 2) + cache_dir = args.cache_dir + score = round(generate_score(image1, image2, cache_dir), 2) print("similarity Score: ", {score}) - if score < 99: + if score < 97: print(f"{image1} and {image2} are different") raise SystemExit(1) else: diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index dd88a4d6d5632..0de2ac44215c4 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -172,6 +172,9 @@ stages: - job: Stable_Diffusion variables: skipComponentGovernanceDetection: true + CLIP_MODEL_CACHE: $(Agent.TempDirectory)/clip_cache + STABLE_DIFFUSION_MODEL_CACHE: $(Agent.TempDirectory)/stablediffusion_cache + GenerateImage_DIR: $(Agent.TempDirectory)/images workspace: clean: all pool: onnxruntime-Linux-GPU-A10-12G @@ -188,9 +191,23 @@ stages: SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} + - task: Cache@2 + inputs: + key: stable_diffusion | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py + restoreKeys: | + stable_diffusion | 
$(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py + stable_diffusion + path: $(STABLE_DIFFUSION_MODEL_CACHE) + displayName: Cache stable diffusion model + - script: | + mkdir -p $(GenerateImage_DIR) + docker run --rm --gpus all -v $PWD:/workspace \ + -v $(Build.BinariesDirectory)/Release:/Release \ + -v $(STABLE_DIFFUSION_MODEL_CACHE):/model_cache:rw \ + -v $(GenerateImage_DIR):/images:rw \ + nvcr.io/nvidia/pytorch:22.11-py3 \ + bash -c ' \ set -ex; \ python3 --version; \ python3 -m pip install --upgrade pip; \ @@ -199,15 +216,33 @@ stages: python3 -m pip install -r requirements-cuda11.txt; \ python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ echo Generate an image guided by a text prompt; \ - python3 demo_txt2img.py --seed 1 --deterministic "astronaut riding a horse on mars" ; \ - find $(pwd) -name "*.png" ; \ + python3 demo_txt2img.py --framework-model-dir /model_cache --seed 1 --deterministic "astronaut riding a horse on mars" ; \ + find $(pwd)/ORT_CUDA -name "*.png" -exec cp {} /images/ \; ; \ popd ; \ ' displayName: 'Run stable diffusion demo' workingDirectory: $(Build.SourcesDirectory) + # For verification we will check that the generated image looks as expected. + - task: PublishPipelineArtifact@0 + displayName: 'Publish generated image' + inputs: + artifactName: "Generated-Image" + targetPath: '$(GenerateImage_DIR)' + + - task: Cache@2 + inputs: + key: clip_model | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py + restoreKeys: | + clip_model | $(Build.SourcesDirectory)/onnxruntime/python/tools/transformers/models/stable_diffusion/test/check_image.py + clip_model + path: $(CLIP_MODEL_CACHE) + displayName: Cache clip model + - script: | - docker run --rm --gpus all -v $PWD:/workspace nvcr.io/nvidia/pytorch:22.11-py3 \ + docker run --rm --gpus all -v $PWD:/workspace \ + -v $(CLIP_MODEL_CACHE):/model_cache:rw \ + nvcr.io/nvidia/pytorch:22.11-py3 \ bash -c ' set -ex; \ python3 --version; \ @@ -217,7 +252,7 @@ stages: pushd test; \ python3 -m pip install -r requirements.txt; \ echo check demo_txt2image.py generate image; \ - python3 -u check_image.py --image1 astronaut_riding_txt2image-DDIM-50.png --image2 $image2; \ + python3 -u check_image.py --image1 astronaut_riding_txt2image-DDIM-50.png --image2 $image2 --cache_dir /model_cache ; \ popd ; \ popd ; \ ' From 1e936bfd6391b50a6cfc96730addd68ac398ef5e Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 31 Jan 2024 10:09:16 +0800 Subject: [PATCH 006/207] [WebNN] Ignore empty optional input tensor (#19235) Empty optional input tensors are indicated by an empty name; such inputs are allowed and we should just ignore them.
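A minimal sketch of the convention this change relies on, under simplified assumptions: the `std::vector<std::string>` container and the `main` driver below are hypothetical stand-ins, while the real helper added in `helper.h` operates on the node's `NodeArg` input container.

```cpp
// Sketch of the "empty name means the optional input was omitted" convention.
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the GetTensorName helper: return the input's name when
// the index exists, otherwise an empty string so callers can treat it as absent.
std::string GetTensorName(const std::vector<std::string>& input_defs, size_t index) {
  return (input_defs.size() > index) ? input_defs[index] : "";
}

int main() {
  // Index 2 ("axes") is optional and was passed with an empty name, so it counts as omitted.
  const std::vector<std::string> input_defs = {"data", "pads", ""};

  const std::string axes_name = GetTensorName(input_defs, 2);
  if (axes_name.empty()) {
    std::cout << "optional input omitted; ignoring it\n";
  } else {
    std::cout << "optional input present: " << axes_name << "\n";
  }
  return 0;
}
```

Returning an empty string for both the out-of-range and the empty-name case lets each op builder use a single `!name.empty()` check instead of repeating the index bookkeeping before the initializer lookup.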
--- .../core/providers/webnn/builders/helper.h | 4 +++ .../webnn/builders/impl/pad_op_builder.cc | 6 ++-- .../builders/impl/reduction_op_builder.cc | 3 +- .../webnn/builders/impl/resize_op_builder.cc | 32 ++++++++++++------- .../webnn/builders/impl/slice_op_builder.cc | 6 ++-- .../webnn/builders/impl/split_op_builder.cc | 8 ++--- .../impl/squeeze_unsqueeze_op_builder.cc | 4 +-- 7 files changed, 40 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 92aa9abc9fdf7..d94729e60d029 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -54,6 +54,10 @@ std::string GetShapeString(std::vector& shape) { return shape_info.str(); } +inline std::string GetTensorName(const ConstPointerContainer>& input_defs, const size_t index) { + return (input_defs.size() > index) ? std::string(input_defs[index]->Name()) : ""; +} + inline std::vector GetVecUint32FromVecInt64(const std::vector& int64_vec) { std::vector uint32_vec; uint32_vec.reserve(int64_vec.size()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index a2a1e2f2e599d..52b5518857773 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -178,8 +178,10 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } for (size_t i = 1; i < input_defs.size(); i++) { - if (!Contains(initializers, input_defs[i]->Name())) { - LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] must be known as initializer"; + // Optional tensors (constant_value, axes) can be indicated by an empty name, just ignore it. + const std::string input_name = GetTensorName(input_defs, i); + if (!input_name.empty() && !Contains(initializers, input_name)) { + LOGS(logger, VERBOSE) << "Input [" << input_name << "] must be known as initializer"; return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 1a702649b7f05..f446a7b81d1c0 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -134,8 +134,9 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return false; const auto& op_type = node.OpType(); + const std::string axes_name = GetTensorName(input_defs, 1); // If the optional input 'axes' is provided, it must be an initializer. 
- if (input_defs.size() > 1 && !Contains(initializers, input_defs[1]->Name())) { + if (!axes_name.empty() && !Contains(initializers, axes_name)) { LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 186d1e7c1035a..9018f8c96f300 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -120,8 +120,9 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector scales_hw; std::vector sizes_hw; std::vector axes; + std::string scales_name = GetTensorName(input_defs, 2); const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; - if (input_defs.size() == 3) { // Use scales. + if (!scales_name.empty()) { // Use scales. ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); if (is_nhwc) { scales_hw = {scales[1], scales[2]}; @@ -129,7 +130,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, scales_hw = {scales[2], scales[3]}; } options.set("scales", emscripten::val::array(scales_hw)); - } else { // We already checked number of inputs in IsOpSupportedImpl. + } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. std::vector output_sizes; ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger), "Error getting resize output_sizes"); @@ -203,26 +204,31 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } { // scales and sizes (if present) must be initializers. - if (input_defs.size() < 3) { - LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known"; - return false; - } + const std::string scales_name = GetTensorName(input_defs, 2); + const std::string sizes_name = GetTensorName(input_defs, 3); - // scales - if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { + // scales (scales may be empty tensor) + bool has_scales = !scales_name.empty(); + if ((has_scales && !Contains(initializers, scales_name)) || (!has_scales && node.SinceVersion() == 11)) { LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; return false; } - // sizes - if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { + // sizes (sizes may be empty tensor) + bool has_sizes = !sizes_name.empty(); + if (has_sizes && !Contains(initializers, sizes_name)) { LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; return false; } + if (has_scales && has_sizes) { + LOGS(logger, VERBOSE) << "Only one of 'scales' and 'sizes' can be specified"; + return false; + } + const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; // We want to check if the scales or sizes are not trying to resize on N/C channels here. - if (input_defs.size() == 3) { // We are using scales. + if (has_scales) { // We are using scales. std::vector scales; if (!GetResizeScales(initializers, node, scales, logger)) return false; @@ -251,7 +257,9 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number"; return false; } - } else { + } + + if (has_sizes) { // We are using sizes. 
std::vector output_sizes; if (!GetResizeOutputSizes(initializers, node, output_sizes, logger)) diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index e48cf35012652..4e0628581abf2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -123,8 +123,10 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, // Inputs: starts, ends, axes, and steps must be constant initializers if present. for (size_t i = 1; i < input_defs.size(); i++) { - if (!Contains(initializers, input_defs[i]->Name())) { - LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] of " << op_type + // Optional tensors (axes, steps) can be indicated by an empty name, just ignore it. + const std::string input_name = GetTensorName(input_defs, i); + if (!input_name.empty() && !Contains(initializers, input_name)) { + LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type << " [" << name << "] must be known as initializer"; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index d568d4e625077..e9a600a5933af 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -136,9 +136,9 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, int32_t axis = helper.Get("axis", 0); axis = SafeInt(HandleNegativeAxis(axis, rank)); - if (input_defs.size() == 2) { - // Inputs contains optional 'split' input - const auto& split_name = input_defs[1]->Name(); + const std::string split_name = GetTensorName(input_defs, 1); + // Inputs contain optional 'split' input. + if (!split_name.empty()) { if (!Contains(initializers, split_name)) { LOGS(logger, VERBOSE) << "The split must be a constant initializer."; return false; @@ -166,7 +166,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified."; return false; } - } else if (input_defs.size() == 1) { + } else { if (helper.HasAttr("num_outputs")) { // Split has 'num_outputs' attribute when opset is 18. const int32_t num_outputs = helper.Get("num_outputs", 1); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 2a1672c001b0e..9f6ccb98f79dd 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -138,8 +138,8 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& in // Squeeze/Unsqueeze opset 13 uses input 1 as axes, it needs to be an initializer. 
if (node.SinceVersion() >= 13) { - if (input_defs.size() > 1) { - const auto& axes_name = input_defs[1]->Name(); + const std::string axes_name = GetTensorName(input_defs, 1); + if (!axes_name.empty()) { if (!Contains(initializers, axes_name)) { LOGS(logger, ERROR) << "Input axes of " << op_type << " is not present and constant"; return false; From 6dd0079d133da09c286cb097b791f04ef387f5b2 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 31 Jan 2024 12:25:34 +1000 Subject: [PATCH 007/207] Exclude more code from custom_ops.cc when not required in minimal build (#19142) ### Description - Split out the code that implements the OrtKernelContext API (used by compiled nodes and custom ops) and the code that implements the custom ops API. - Exclude based on minimal build settings using helpers - the main change is to simply wrap the implementation into a lambda so it can be easily enabled/disabled - actual implementation of all functions are unchanged - Re-organize so the related implementations are together - most diffs are from this, but without the reorg it would be much harder to know which helper to use - General cleanup of lines that were too long. ### Motivation and Context Saves ~10KB in a minimal build. Build command used for comparison ``` ./build --android --android_api=29 --android_sdk="d:\Android" --android_abi=arm64-v8a --parallel --android_ndk_path="D:\Android\ndk\26.0.10792818\" --build_shared_lib --cmake_generator Ninja --skip_tests --minimal_build --disable_rtti --disable_ml_ops --disable_exceptions --cmake_extra_defines=onnxruntime_BUILD_UNIT_TESTS=OFF --include_ops_by_config .\no_ops.config --config MinSizeRel ``` Main: 1,218,480 bytes With changes: 1,208,320 bytes --- onnxruntime/core/session/custom_ops.cc | 958 +++++++++++++------------ 1 file changed, 508 insertions(+), 450 deletions(-) diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 4bae42f4b80ad..7a233c57cfdf3 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -26,8 +26,15 @@ #include "core/session/ort_apis.h" #include "core/platform/threadpool.h" +// NOTE: OrtKernelContext is used by both custom ops and compiled kernels. +// In a minimal build, ORT_EXTENDED_MINIMAL_BUILD is used to enable EPs like CoreML/NNAPI which use compiled kernels, +// and ORT_MINIMAL_BUILD_CUSTOM_OPS is used to allow external custom op libraries to be used. 
+#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +#define ENABLE_ORT_KERNEL_CONTEXT_API 1 +#endif + #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -#define ENABLE_CUSTOM_OP_API +#define ENABLE_CUSTOM_OP_API 1 #endif #if !defined(ORT_MINIMAL_BUILD) @@ -36,7 +43,7 @@ static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; static constexpr uint32_t min_ort_version_with_custom_version = 17; #endif -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +#if ENABLE_CUSTOM_OP_API static constexpr uint32_t min_ort_version_with_compute_v2_support = 16; static constexpr uint32_t min_ort_version_with_shape_inference = 17; #endif @@ -52,7 +59,8 @@ struct OrtShapeInferContext { size_t GetInputCount() const { return 0; } OrtTensorTypeAndShapeInfo* GetInputTypeShape(size_t) const { return {}; } onnxruntime::Status SetOutputTypeShape(size_t, const OrtTensorTypeAndShapeInfo*) const { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtShapeInferContext::SetOutputTypeShape not implemented for minimal build"); + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtShapeInferContext::SetOutputTypeShape not implemented for minimal build"); } const ONNX_NAMESPACE::AttributeProto* GetAttr(const char*) const { return {}; } }; @@ -63,13 +71,15 @@ struct OrtShapeInferContext { for (size_t ith_input = 0; ith_input < num_inputs; ++ith_input) { const auto* input_type = ctx_.getInputType(ith_input); const auto& value_case = input_type->value_case(); - ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType, "shape inference not yet supported for non-tensor types"); + ORT_ENFORCE(value_case == ONNX_NAMESPACE::TypeProto::kTensorType, + "shape inference not yet supported for non-tensor types"); const auto& shape_proto = input_type->tensor_type().shape(); const auto& type_proto = input_type->tensor_type(); auto elem_type = ::onnxruntime::utils::CApiElementTypeFromProtoType(type_proto.elem_type()); auto tensor_shape = ::onnxruntime::utils::GetTensorShapeFromTensorShapeProto(shape_proto); auto symbolic_dims = GetSymbolicDims(shape_proto); - input_type_shapes_.emplace_back(OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release()); + input_type_shapes_.emplace_back( + OrtTensorTypeAndShapeInfo::GetTensorShapeAndTypeHelper(elem_type, tensor_shape, &symbolic_dims).release()); } } @@ -121,304 +131,392 @@ struct OrtShapeInferContext { }; #endif -ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out) { +#if ENABLE_ORT_KERNEL_CONTEXT_API +template +static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const T& fn) { API_IMPL_BEGIN - *out = context->GetInputCount(); - return nullptr; + return fn(); API_IMPL_END } +#else +template +static OrtStatusPtr ExecuteIfKernelContextApiEnabled(const T&) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "OrtKernelContext API is not enabled in this build"); +} +#endif -ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info) { - API_IMPL_BEGIN - *info = context->GetInputTypeShape(index); - if (*info) { +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + *out = 
reinterpret_cast(context)->InputCount(); return nullptr; - } else { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Failed to fetch type shape info for the index."); - } - API_IMPL_END -} + }); +}; -ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr) { - API_IMPL_BEGIN - *attr = reinterpret_cast(context->GetAttr(attr_name)); - if (*attr) { +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + *out = reinterpret_cast(context)->OutputCount(); return nullptr; - } else { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); - } - API_IMPL_END -} + }); +}; -ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, - _In_ const OrtOpAttr* op_attr, - _In_ OrtOpAttrType type, - _Inout_ void* data, - _In_ size_t len, - _Out_ size_t* out) { - API_IMPL_BEGIN +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, + _Out_ const OrtValue** out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + const auto* ctx = reinterpret_cast(context); + *out = reinterpret_cast(ctx->GetInputMLValue(onnxruntime::narrow(index))); + return nullptr; + }); +}; - if (!op_attr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Invalid attribute."); - } +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + onnxruntime::TensorShape shape(dim_values, dim_count); + auto* ctx = reinterpret_cast(context); + *out = reinterpret_cast(ctx->OutputMLValue(onnxruntime::narrow(index), shape)); + return nullptr; + }); +}; - auto attr = reinterpret_cast(op_attr); - OrtStatusPtr ret = nullptr; - *out = 0; +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, + _Outptr_ void** out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + auto* stream = reinterpret_cast(context)->GetComputeStream(); + if (stream) + *out = stream->GetHandle(); + else + *out = nullptr; + return nullptr; + }); +}; - if (type == OrtOpAttrType::ORT_OP_ATTR_FLOAT) { - if (len < sizeof(float)) { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold a float."); - } else { - if (attr->has_f()) { - auto output_f = reinterpret_cast(data); - *output_f = attr->f(); - } else { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no float value."); - } +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, + _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + const auto* ctx = reinterpret_cast(context); + onnxruntime::AllocatorPtr allocator = ctx->GetAllocator(mem_info->device); + if (!allocator) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); } - *out = sizeof(float); - } else if (type == OrtOpAttrType::ORT_OP_ATTR_FLOATS) { - const auto& floats = attr->floats(); - auto num_floats = floats.size(); + auto p = std::make_unique(std::move(allocator)); + *out = p.release(); + return nullptr; + }); +}; 
- if (len < sizeof(float) * num_floats) { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the array of floats."); - } else { - auto output_f = reinterpret_cast(data); - for (auto f : floats) { - *output_f = f; - output_f++; - } +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelContext* context, + _In_ int resource_version, _In_ int resource_id, _Outptr_ void** resource) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + *resource = {}; + const auto* ctx = reinterpret_cast(context); + auto* stream = reinterpret_cast(ctx->GetComputeStream()); + if (!stream) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource"); } - *out = num_floats * sizeof(float); + *resource = stream->GetResource(resource_version, resource_id); + return nullptr; + }); +}; - } else if (type == OrtOpAttrType::ORT_OP_ATTR_INT) { - if (len < sizeof(int)) { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold an int64."); - } else { - if (attr->has_i()) { - auto output_i = reinterpret_cast(data); - *output_i = attr->i(); +ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context, + _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + if (!context) { + return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context"); + } + if (fn && total) { + const auto* ctx = reinterpret_cast(context); + auto* tp = ctx->GetOperatorThreadPool(); + if (num_batch) { + onnxruntime::concurrency::ThreadPool::TryBatchParallelFor( + tp, + static_cast(total), + [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }, + static_cast(num_batch)); } else { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no int64 value."); + onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( + tp, + static_cast(total), + [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }); } } - *out = sizeof(int64_t); + return nullptr; + }); +}; - } else if (type == OrtOpAttrType::ORT_OP_ATTR_INTS) { - const auto& ints = attr->ints(); - auto num_ints = ints.size(); +ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContext* context, + _Outptr_ const OrtLogger** logger) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + const auto& kernel_ctx_logger = reinterpret_cast(context)->Logger(); - if (len < sizeof(int64_t) * num_ints) { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the array of int64."); - } else { - auto output_i = reinterpret_cast(data); - for (auto i : ints) { - *output_i = i; - output_i++; - } - } - *out = num_ints * sizeof(int64_t); + *logger = reinterpret_cast(&kernel_ctx_logger); + return nullptr; + }); +} - } else if (type == OrtOpAttrType::ORT_OP_ATTR_STRING) { - const auto& s = attr->s(); - if (len < s.size() + 1) { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the string."); - } else { - char* output_c = reinterpret_cast(data); - for (char c : s) { - *output_c++ = c; - } - *output_c = '\0'; - } - *out = s.size() + 1; +// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger +ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, 
OrtLoggingLevel log_severity_level, + _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, + _In_z_ const char* func_name) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + const auto& actual_logger = *reinterpret_cast(logger); + const auto severity = static_cast(log_severity_level); + const auto log_data_type = onnxruntime::logging::DataType::SYSTEM; - } else if (type == OrtOpAttrType::ORT_OP_ATTR_STRINGS) { - const auto& ss = attr->strings(); - size_t num_bytes = 0; - for_each(ss.begin(), ss.end(), [&num_bytes](const std::string& s) { num_bytes += s.size() + 1; }); + if (actual_logger.OutputIsEnabled(severity, log_data_type)) { +#ifdef _WIN32 + const std::string file_path_str = onnxruntime::ToUTF8String(file_path); + onnxruntime::CodeLocation location(file_path_str.c_str(), line_number, func_name); +#else + onnxruntime::CodeLocation location(file_path, line_number, func_name); +#endif - if (len < num_bytes) { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the array of strings."); - } else { - char* output_c = reinterpret_cast(data); - for (const auto& s : ss) { - for (char c : s) { - *output_c++ = c; - } - *output_c++ = '\0'; - } + onnxruntime::logging::Capture( + actual_logger, + severity, + onnxruntime::logging::Category::onnxruntime, + log_data_type, + location) + .Stream() + << message; } - *out = num_bytes; - } else { - ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unknown attribute type."); - } - - return ret; - API_IMPL_END + return nullptr; + }); } -ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info) { - API_IMPL_BEGIN - auto status = context->SetOutputTypeShape(index, info); - if (status.IsOK()) { +// Enabled via ExecuteIfKernelContextApiEnabled due to KernelContext_GetLogger +ORT_API_STATUS_IMPL(OrtApis::Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, + _Out_ OrtLoggingLevel* out) { + return ExecuteIfKernelContextApiEnabled([&]() -> OrtStatusPtr { + const auto& actual_logger = *reinterpret_cast(logger); + *out = static_cast(actual_logger.GetSeverity()); return nullptr; - } else { - return OrtApis::CreateStatus(static_cast(status.Code()), status.ErrorMessage().c_str()); - } - API_IMPL_END + }); } -ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out) { +#if ENABLE_CUSTOM_OP_API +template +static OrtStatusPtr ExecuteIfCustomOpsApiEnabled(const T& fn) { API_IMPL_BEGIN - auto status = reinterpret_cast(info)->GetAttr(name, out); - if (status.IsOK()) - return nullptr; - return onnxruntime::ToOrtStatus(status); + return fn(); API_IMPL_END } +#else +template +static OrtStatusPtr ExecuteIfCustomOpsApiEnabled(const T&) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Custom operator API is not enabled in this build"); +} +#endif -ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out) { - API_IMPL_BEGIN - auto status = reinterpret_cast(info)->GetAttr(name, out); - if (status.IsOK()) +ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, + _Out_ size_t* out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + *out = context->GetInputCount(); return nullptr; - return onnxruntime::ToOrtStatus(status); - API_IMPL_END 
+ }); } -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { - API_IMPL_BEGIN - *out = reinterpret_cast(context)->InputCount(); - return nullptr; - API_IMPL_END -}; - -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out) { - API_IMPL_BEGIN - *out = reinterpret_cast(context)->OutputCount(); - return nullptr; - API_IMPL_END -}; - -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, _Out_ const OrtValue** out) { - API_IMPL_BEGIN - *out = reinterpret_cast(reinterpret_cast(context)->GetInputMLValue(gsl::narrow_cast(index))); - return nullptr; - API_IMPL_END -}; - -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count, _Out_ OrtValue** out) { - API_IMPL_BEGIN - onnxruntime::TensorShape shape(dim_values, dim_count); - *out = reinterpret_cast(reinterpret_cast(context)->OutputMLValue(gsl::narrow_cast(index), shape)); - return nullptr; - API_IMPL_END -}; +ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, + _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + *info = context->GetInputTypeShape(index); + if (*info) { + return nullptr; + } else { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Failed to fetch type shape info for the index."); + } + }); +} -ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, _Inout_ size_t* size) { - API_IMPL_BEGIN - std::string value; - auto status = reinterpret_cast(info)->GetAttr(name, &value); - if (status.IsOK()) { - if (out == nullptr) { // User is querying the true size of the attribute - *size = value.size() + 1; +ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, + _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + *attr = reinterpret_cast(context->GetAttr(attr_name)); + if (*attr) { return nullptr; - } else if (*size >= value.size() + 1) { - std::memcpy(out, value.data(), value.size()); - out[value.size()] = '\0'; - *size = value.size() + 1; + } else { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + } + }); +} + +ORT_API_STATUS_IMPL(OrtApis::ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, + _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + auto status = context->SetOutputTypeShape(index, info); + if (status.IsOK()) { return nullptr; - } else { // User has provided a buffer that is not large enough - *size = value.size() + 1; - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough"); + } else { + return OrtApis::CreateStatus(static_cast(status.Code()), status.ErrorMessage().c_str()); } - } - return onnxruntime::ToOrtStatus(status); - API_IMPL_END + }); } -#ifdef _WIN32 -#pragma warning(push) -#pragma warning(disable : 28196 6387) -#endif +ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, + _In_ size_t len, _Out_ size_t* out) { + return 
ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + if (!op_attr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Invalid attribute."); + } -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out) { - API_IMPL_BEGIN - auto* stream = reinterpret_cast(context)->GetComputeStream(); - if (stream) - *out = stream->GetHandle(); - else - *out = nullptr; - return nullptr; - API_IMPL_END -}; + auto attr = reinterpret_cast(op_attr); + OrtStatusPtr ret = nullptr; + *out = 0; -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out) { - API_IMPL_BEGIN - onnxruntime::AllocatorPtr allocator = reinterpret_cast(context)->GetAllocator(mem_info->device); - if (!allocator) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); - } - std::unique_ptr p = std::make_unique(std::move(allocator)); - *out = p.release(); - return nullptr; - API_IMPL_END -}; + switch (type) { + case OrtOpAttrType::ORT_OP_ATTR_FLOAT: { + if (len < sizeof(float)) { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Size of data not large enough to hold a float."); + } else { + if (attr->has_f()) { + auto output_f = reinterpret_cast(data); + *output_f = attr->f(); + } else { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no float value."); + } + } + *out = sizeof(float); -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, _In_ int resource_id, _Outptr_ void** resource) { - API_IMPL_BEGIN - *resource = {}; - const auto* ctx = reinterpret_cast(context); - auto* stream = reinterpret_cast(ctx->GetComputeStream()); - if (!stream) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource"); - } - *resource = stream->GetResource(resource_version, resource_id); - return nullptr; - API_IMPL_END -}; + break; + } + case OrtOpAttrType::ORT_OP_ATTR_FLOATS: { + const auto& floats = attr->floats(); + auto num_floats = floats.size(); -ORT_API_STATUS_IMPL(OrtApis::KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data) { -#ifdef ENABLE_CUSTOM_OP_API - API_IMPL_BEGIN - if (!context) { - return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, "Invalid context"); - } - if (fn && total) { - const auto* ctx = reinterpret_cast(context); - auto* tp = ctx->GetOperatorThreadPool(); - if (num_batch) { - onnxruntime::concurrency::ThreadPool::TryBatchParallelFor( - tp, - static_cast(total), - [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }, - static_cast(num_batch)); - } else { - onnxruntime::concurrency::ThreadPool::TrySimpleParallelFor( - tp, - static_cast(total), - [fn, usr_data](std::ptrdiff_t ith) { fn(usr_data, static_cast(ith)); }); + if (len < sizeof(float) * num_floats) { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Size of data not large enough to hold the array of floats."); + } else { + auto output_f = reinterpret_cast(data); + for (auto f : floats) { + *output_f = f; + output_f++; + } + } + *out = num_floats * sizeof(float); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INT: { + if (len < sizeof(int)) { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Size of data not large enough to 
hold an int64."); + } else { + if (attr->has_i()) { + auto output_i = reinterpret_cast(data); + *output_i = attr->i(); + } else { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute has no int64 value."); + } + } + *out = sizeof(int64_t); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_INTS: { + const auto& ints = attr->ints(); + auto num_ints = ints.size(); + + if (len < sizeof(int64_t) * num_ints) { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Size of data not large enough to hold the array of int64."); + } else { + auto output_i = reinterpret_cast(data); + for (auto i : ints) { + *output_i = i; + output_i++; + } + } + *out = num_ints * sizeof(int64_t); + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRING: { + const auto& s = attr->s(); + if (len < s.size() + 1) { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Size of data not large enough to hold the string."); + } else { + char* output_c = reinterpret_cast(data); + for (char c : s) { + *output_c++ = c; + } + *output_c = '\0'; + } + *out = s.size() + 1; + break; + } + case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { + const auto& ss = attr->strings(); + size_t num_bytes = 0; + for_each(ss.begin(), ss.end(), [&num_bytes](const std::string& s) { num_bytes += s.size() + 1; }); + + if (len < num_bytes) { + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Size of data not large enough to hold the array of strings."); + } else { + char* output_c = reinterpret_cast(data); + for (const auto& s : ss) { + for (char c : s) { + *output_c++ = c; + } + *output_c++ = '\0'; + } + } + *out = num_bytes; + break; + } + default: + ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type. "); } - } - return nullptr; - API_IMPL_END -#else - ORT_UNUSED_PARAMETER(context); - ORT_UNUSED_PARAMETER(fn); - ORT_UNUSED_PARAMETER(total); - ORT_UNUSED_PARAMETER(num_batch); - ORT_UNUSED_PARAMETER(usr_data); - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "ParallelFor API not implemented for this build"); -#endif -}; -#ifdef _WIN32 -#pragma warning(pop) -#endif + return ret; + }); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ float* out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + auto status = reinterpret_cast(info)->GetAttr(name, out); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); + }); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + auto status = reinterpret_cast(info)->GetAttr(name, out); + if (status.IsOK()) + return nullptr; + return onnxruntime::ToOrtStatus(status); + }); +} + +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ char* out, _Inout_ size_t* size) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + std::string value; + auto status = reinterpret_cast(info)->GetAttr(name, &value); + if (status.IsOK()) { + if (out == nullptr) { // User is querying the true size of the attribute + *size = value.size() + 1; + return nullptr; + } else if (*size >= value.size() + 1) { + std::memcpy(out, value.data(), value.size()); + out[value.size()] = '\0'; + *size = value.size() + 1; + return nullptr; + } else { // User has provided a buffer that is not large enough + *size 
= value.size() + 1; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Result buffer is not large enough"); + } + } + return onnxruntime::ToOrtStatus(status); + }); +} template ::value, int>::type = 0> static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, size_t* size) { @@ -438,256 +536,209 @@ static Status CopyDataFromVectorToMemory(const std::vector& values, T* out, s ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size) { - API_IMPL_BEGIN - std::vector values; - auto status = reinterpret_cast(info)->GetAttrs(name, values); - if (status.IsOK()) { - status = CopyDataFromVectorToMemory(values, out, size); - } - return onnxruntime::ToOrtStatus(status); - API_IMPL_END + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + std::vector values; + auto status = reinterpret_cast(info)->GetAttrs(name, values); + if (status.IsOK()) { + status = CopyDataFromVectorToMemory(values, out, size); + } + return onnxruntime::ToOrtStatus(status); + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ int64_t* out, _Inout_ size_t* size) { - API_IMPL_BEGIN - std::vector values; - auto status = reinterpret_cast(info)->GetAttrs(name, values); - if (status.IsOK()) { - status = CopyDataFromVectorToMemory(values, out, size); - } - return onnxruntime::ToOrtStatus(status); - API_IMPL_END + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + std::vector values; + auto status = reinterpret_cast(info)->GetAttrs(name, values); + if (status.IsOK()) { + status = CopyDataFromVectorToMemory(values, out, size); + } + return onnxruntime::ToOrtStatus(status); + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out) { - API_IMPL_BEGIN - const auto* op_kinfo = reinterpret_cast(info); - - // Get TensorProto attribute - onnx::TensorProto tensor_proto; - auto status = op_kinfo->GetAttr(name, &tensor_proto); - if (!status.IsOK()) { - return onnxruntime::ToOrtStatus(status); - } + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_kinfo = reinterpret_cast(info); + + // Get TensorProto attribute + onnx::TensorProto tensor_proto; + auto status = op_kinfo->GetAttr(name, &tensor_proto); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } - // Determine the tensor's size in bytes. - size_t req_size = 0; - status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &req_size); - if (!status.IsOK()) { - return onnxruntime::ToOrtStatus(status); - } + // Determine the tensor's size in bytes. + size_t req_size = 0; + status = onnxruntime::utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &req_size); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } - // Create Tensor that owns buffer memory that will be allocated with the provided OrtAllocator. - onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); - const auto* const type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); - auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); + // Create Tensor that owns buffer memory that will be allocated with the provided OrtAllocator. 
+ onnxruntime::TensorShape tensor_shape = onnxruntime::utils::GetTensorShapeFromTensorProto(tensor_proto); + const auto* type = onnxruntime::DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + onnxruntime::AllocatorPtr alloc_ptr = std::make_shared(allocator); + auto tensorp = std::make_unique(type, tensor_shape, std::move(alloc_ptr)); - // Deserialize TensorProto into pre-allocated, empty Tensor. - status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp); - if (!status.IsOK()) { - return onnxruntime::ToOrtStatus(status); - } + // Deserialize TensorProto into pre-allocated, empty Tensor. + status = onnxruntime::utils::TensorProtoToTensor(onnxruntime::Env::Default(), nullptr, tensor_proto, *tensorp); + if (!status.IsOK()) { + return onnxruntime::ToOrtStatus(status); + } - // Initialize OrtValue from Tensor. - auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); - auto value = std::make_unique(); - value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + // Initialize OrtValue from Tensor. + auto ml_tensor = onnxruntime::DataTypeImpl::GetType(); + auto value = std::make_unique(); + value->Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - *out = value.release(); - return nullptr; - API_IMPL_END + *out = value.release(); + return nullptr; + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { - API_IMPL_BEGIN - *out = reinterpret_cast(info)->GetInputCount(); - return nullptr; - API_IMPL_END + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + *out = reinterpret_cast(info)->GetInputCount(); + return nullptr; + }); }; ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out) { - API_IMPL_BEGIN - *out = reinterpret_cast(info)->GetOutputCount(); - return nullptr; - API_IMPL_END + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + *out = reinterpret_cast(info)->GetOutputCount(); + return nullptr; + }); }; -ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, - _Inout_ size_t* size) { - API_IMPL_BEGIN - const auto* op_info = reinterpret_cast(info); - const auto input_defs = op_info->node().InputDefs(); +ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, + _Out_ char* out, _Inout_ size_t* size) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto input_defs = op_info->node().InputDefs(); - if (index >= input_defs.size()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); - } + if (index >= input_defs.size()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); + } - auto status = CopyStringToOutputArg(input_defs[index]->Name(), - "Output buffer is not large enough for ::OrtKernelInfo input name", out, size); + auto status = CopyStringToOutputArg(input_defs[index]->Name(), + "Output buffer is not large enough for ::OrtKernelInfo input name", out, size); - return onnxruntime::ToOrtStatus(status); - API_IMPL_END + return onnxruntime::ToOrtStatus(status); + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, _Inout_ size_t* size) { - API_IMPL_BEGIN - const auto* op_info = reinterpret_cast(info); - const 
auto output_defs = op_info->node().OutputDefs(); + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto output_defs = op_info->node().OutputDefs(); - if (index >= output_defs.size()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); - } + if (index >= output_defs.size()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); + } - auto status = CopyStringToOutputArg(output_defs[index]->Name(), - "Output buffer is not large enough for ::OrtKernelInfo output name", out, size); + auto status = CopyStringToOutputArg(output_defs[index]->Name(), + "Output buffer is not large enough for ::OrtKernelInfo output name", + out, size); - return onnxruntime::ToOrtStatus(status); - API_IMPL_END + return onnxruntime::ToOrtStatus(status); + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, _Outptr_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - const auto* op_info = reinterpret_cast(info); - const auto input_defs = op_info->node().InputDefs(); + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto input_defs = op_info->node().InputDefs(); - if (index >= input_defs.size()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); - } + if (index >= input_defs.size()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo input index is out of bounds"); + } - const onnxruntime::NodeArg* node_arg = input_defs[index]; - const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); + const onnxruntime::NodeArg* node_arg = input_defs[index]; + const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); - if (type_proto == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type"); - } + if (type_proto == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo input does not have a type"); + } - auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); - *type_info = type_info_ret.release(); - return nullptr; - API_IMPL_END + auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); + *type_info = type_info_ret.release(); + return nullptr; + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, _Outptr_ OrtTypeInfo** type_info) { - API_IMPL_BEGIN - const auto* op_info = reinterpret_cast(info); - const auto output_defs = op_info->node().OutputDefs(); + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + const auto output_defs = op_info->node().OutputDefs(); - if (index >= output_defs.size()) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); - } + if (index >= output_defs.size()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "::OrtKernelInfo output index is out of bounds"); + } - const onnxruntime::NodeArg* node_arg = output_defs[index]; - const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); + const onnxruntime::NodeArg* node_arg = output_defs[index]; + const ONNX_NAMESPACE::TypeProto* type_proto = node_arg->TypeAsProto(); - if (type_proto == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type"); - } + if (type_proto 
== nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo output does not have a type"); + } - auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); - *type_info = type_info_ret.release(); - return nullptr; - API_IMPL_END + auto type_info_ret = OrtTypeInfo::FromTypeProto(*type_proto); + *type_info = type_info_ret.release(); + return nullptr; + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, _In_ size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out) { - API_IMPL_BEGIN - const auto* op_info = reinterpret_cast(info); - *is_constant = static_cast(op_info->TryGetConstantInput(gsl::narrow_cast(index), out)); - return nullptr; - API_IMPL_END + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); + *is_constant = static_cast(op_info->TryGetConstantInput(gsl::narrow_cast(index), out)); + return nullptr; + }); }; ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size) { - API_IMPL_BEGIN - const auto* op_info = reinterpret_cast(info); + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* op_info = reinterpret_cast(info); - auto status = CopyStringToOutputArg(op_info->node().Name(), - "Output buffer is not large enough for ::OrtKernelInfo node name", out, size); + auto status = CopyStringToOutputArg(op_info->node().Name(), + "Output buffer is not large enough for ::OrtKernelInfo node name", out, size); - return onnxruntime::ToOrtStatus(status); - API_IMPL_END + return onnxruntime::ToOrtStatus(status); + }); } ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger) { - API_IMPL_BEGIN - const auto* ep = reinterpret_cast(info)->GetExecutionProvider(); + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + const auto* ep = reinterpret_cast(info)->GetExecutionProvider(); - if (ep == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo does not have an execution provider"); - } - - const auto* ep_logger = ep->GetLogger(); - - if (ep_logger == nullptr) { - return OrtApis::CreateStatus(ORT_INVALID_GRAPH, - "::OrtKernelInfo cannot get a valid logger from " - "its execution provider"); - } - - *logger = reinterpret_cast(ep_logger); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger) { - API_IMPL_BEGIN - const auto& kernel_ctx_logger = reinterpret_cast(context)->Logger(); - - *logger = reinterpret_cast(&kernel_ctx_logger); - return nullptr; - API_IMPL_END -} - -ORT_API_STATUS_IMPL(OrtApis::Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, - _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, - _In_z_ const char* func_name) { - API_IMPL_BEGIN - const auto& actual_logger = *reinterpret_cast(logger); - const auto severity = static_cast(log_severity_level); - const auto log_data_type = onnxruntime::logging::DataType::SYSTEM; - - if (actual_logger.OutputIsEnabled(severity, log_data_type)) { -#ifdef _WIN32 - const std::string file_path_str = onnxruntime::ToUTF8String(file_path); - onnxruntime::CodeLocation location(file_path_str.c_str(), line_number, func_name); -#else - onnxruntime::CodeLocation location(file_path, line_number, func_name); -#endif + if (ep == nullptr) { + return 
OrtApis::CreateStatus(ORT_INVALID_GRAPH, "::OrtKernelInfo does not have an execution provider"); + } - onnxruntime::logging::Capture( - actual_logger, - severity, - onnxruntime::logging::Category::onnxruntime, - log_data_type, - location) - .Stream() - << message; - } + const auto* ep_logger = ep->GetLogger(); - return nullptr; - API_IMPL_END -} + if (ep_logger == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, + "::OrtKernelInfo cannot get a valid logger from " + "its execution provider"); + } -ORT_API_STATUS_IMPL(OrtApis::Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out) { - API_IMPL_BEGIN - const auto& actual_logger = *reinterpret_cast(logger); - *out = static_cast(actual_logger.GetSeverity()); - return nullptr; - API_IMPL_END + *logger = reinterpret_cast(ep_logger); + return nullptr; + }); } -#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +#if ENABLE_CUSTOM_OP_API #include "core/framework/customregistry.h" namespace onnxruntime { - struct CustomOpKernel : OpKernel { CustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { if (op_.version > ORT_API_VERSION) { @@ -766,7 +817,8 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust if (input_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { def_builder.TypeConstraint(input_name, SUPPORTED_TENSOR_TYPES); } else { - def_builder.TypeConstraint(input_name, DataTypeImpl::TensorTypeFromONNXEnum(static_cast(input_type))->AsTensorType()); + def_builder.TypeConstraint(input_name, + DataTypeImpl::TensorTypeFromONNXEnum(static_cast(input_type))->AsTensorType()); } } @@ -776,7 +828,8 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { def_builder.TypeConstraint(output_name, SUPPORTED_TENSOR_TYPES); } else { - def_builder.TypeConstraint(output_name, DataTypeImpl::TensorTypeFromONNXEnum(static_cast(output_type))->AsTensorType()); + def_builder.TypeConstraint(output_name, + DataTypeImpl::TensorTypeFromONNXEnum(static_cast(output_type))->AsTensorType()); } } @@ -786,7 +839,8 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust def_builder.Provider(onnxruntime::kCpuExecutionProvider); } - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, + std::unique_ptr& out) -> Status { out = std::make_unique(info, *op); return Status::OK(); }; @@ -891,8 +945,8 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vect "There must be one (and only one) dynamic typed input to the custom op. " "Its type info at runtime will be used to infer the type info of this dynamic typed output " "which is required for the success of the model loading step. 
" - "More than one dynamic typed inputs are currently not supported as differing types at runtime means the output type " - "cannot be inferred without which model loading cannot proceed."); + "More than one dynamic typed inputs are currently not supported as differing types at runtime " + "means the output type cannot be inferred without which model loading cannot proceed."); } } create_type_constraint(op, static_cast(output_count), static_cast(i), false); @@ -1003,7 +1057,8 @@ void InferOutputTypes(const InlinedVector& kernel_defs, if (tc_iter != type_constraints.end()) { if (tc_iter->second.size() > 1) { undef = elem_type; - } else if (tc_iter->second.size() != 1 || tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) { + } else if (tc_iter->second.size() != 1 || + tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) { matched = false; } } else { @@ -1030,7 +1085,8 @@ void InferOutputTypes(const InlinedVector& kernel_defs, if (tc_iter->second.size() > 1) { output_type->mutable_tensor_type()->set_elem_type(undef); } else { - output_type->mutable_tensor_type()->set_elem_type(tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type()); + output_type->mutable_tensor_type()->set_elem_type( + tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type()); } } break; @@ -1052,7 +1108,8 @@ common::Status CreateCustomRegistry(gsl::span op_domai // If domain is empty, it is assumed to be part of the ONNX domain if (!domain->domain_.empty()) { // Add it to the DomainToVersion ONNX map if it doesn't already exist - // For example, two sessions using the same session_options should not add the same custom op domain to the version map twice + // For example, two sessions using the same session_options should not add the same custom op domain + // to the version map twice auto& domain_to_version_range_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); const auto& domain_to_version_map = domain_to_version_range_instance.Map(); @@ -1099,12 +1156,13 @@ common::Status CreateCustomRegistry(gsl::span op_domai schemas.push_back(schema_iter.second); InlinedVector kernel_defs = std::move(kernel_def_map[schema_iter.first]); auto infer_fn = schemas.back().GetTypeAndShapeInferenceFunction(); - ONNX_NAMESPACE::InferenceFunction extended_infer_fn = [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) { - InferOutputTypes(kernel_defs, infer_ctx); - if (infer_fn) { - infer_fn(infer_ctx); - } - }; + ONNX_NAMESPACE::InferenceFunction extended_infer_fn = + [infer_fn, kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) { + InferOutputTypes(kernel_defs, infer_ctx); + if (infer_fn) { + infer_fn(infer_ctx); + } + }; schemas.back().TypeAndShapeInferenceFunction(extended_infer_fn); } @@ -1163,4 +1221,4 @@ common::Status CreateCustomRegistry(gsl::span op_domai } } // namespace onnxruntime -#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +#endif // ENABLE_CUSTOM_OP_API From 85cef0af8cdd99687402fe631bdba3d63926de6b Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Wed, 31 Jan 2024 10:28:03 +0800 Subject: [PATCH 008/207] [js/webgpu] Support capture and replay for jsep (#18989) ### Description This PR expands the graph capture capability to JS EP, which is similar to #16081. But for JS EP, we don't use the CUDA Graph, instead, we records all gpu commands and replay them, which removes most of the cpu overhead to avoid the the situation that gpu waiting for cpu. 
mobilenetv2-12 improves from 6ms to 3.7ms on an NV 3090 and from 4.58ms to 3.38ms on an Intel A770. All limitations are similar to the CUDA EP: 1. Models with control-flow ops (i.e. If, Loop and Scan ops) are not supported. 2. Usage of graph capture is limited to models in which all ops can be partitioned to the JS EP or CPU EP with no memory copy between them. 3. Shapes of inputs/outputs cannot change across inference calls. 4. IOBinding is required. Usage is as follows: Method 1: specify output buffers explicitly. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer/outputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; const fetches = { 'output': ort.Tensor.fromGpuBuffer(outputBuffer, { dataType: 'float32', dims: [1, 1000] }) }; let results = await session.run(feeds, fetches); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = await session.run(feeds, fetches); // The 2nd run and onward will directly call replay to execute the graph. ... ... session.release(); ``` Method 2: don't specify output buffers explicitly. Internally, when graph capture is enabled, all output locations are set to 'gpu-buffer'. ``` const sessionOptions = { executionProviders: [ { name: "webgpu", }, ], enableGraphCapture: true, }; const session = await ort.InferenceSession.create('./models/mobilenetv2-12.onnx', sessionOptions); // prepare the inputBuffer ... ... const feeds = { 'input': ort.Tensor.fromGpuBuffer(inputBuffer, { dataType: 'float32', dims }) }; let results = await session.run(feeds); // The first run will begin to capture the graph. // update inputBuffer content ... ... results = await session.run(feeds); // The 2nd run and onward will directly call replay to execute the graph. ... ... session.release(); ``` --- js/common/lib/inference-session.ts | 8 +- js/web/lib/wasm/binding/ort-wasm.d.ts | 25 ++- js/web/lib/wasm/jsep/backend-webgpu.ts | 100 ++++++++++- js/web/lib/wasm/jsep/init.ts | 8 +- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 74 ++++++-- .../lib/wasm/jsep/webgpu/program-manager.ts | 15 +- js/web/lib/wasm/jsep/webgpu/types.ts | 2 + js/web/lib/wasm/session-options.ts | 12 ++ js/web/lib/wasm/wasm-core-impl.ts | 166 +++++++++++------- .../providers/js/js_execution_provider.cc | 49 +++++- .../core/providers/js/js_execution_provider.h | 18 +- .../core/providers/js/js_provider_factory.cc | 11 +- .../js/js_provider_factory_creator.h | 4 +- onnxruntime/core/session/inference_session.cc | 63 ++++--- .../core/session/provider_registration.cc | 2 +- onnxruntime/wasm/js_internal_api.js | 15 +- 16 files changed, 436 insertions(+), 136 deletions(-) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 1221b52cd4985..4f85c3b46e253 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -111,7 +111,7 @@ export declare namespace InferenceSession { optimizedModelFilePath?: string; /** - * Wether enable profiling. + * Whether enable profiling. * * This setting is a placeholder for a future use. */ @@ -154,6 +154,12 @@ export declare namespace InferenceSession { */ preferredOutputLocation?: OnnxValueDataLocation|{readonly [outputName: string]: OnnxValueDataLocation}; + /** + * Whether enable graph capture.
+ * This setting is available only in ONNXRuntime Web for WebGPU EP. + */ + enableGraphCapture?: boolean; + /** * Store configurations for a session. See * https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/session/ diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 24d7062c85fcb..5dd715191c830 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -13,6 +13,9 @@ export declare namespace JSEP { type ReleaseKernelFunction = (kernel: number) => void; type RunFunction = (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; + type CaptureBeginFunction = () => void; + type CaptureEndFunction = () => void; + type ReplayFunction = () => void; } export interface OrtWasmModule extends EmscriptenModule { @@ -128,7 +131,8 @@ export interface OrtWasmModule extends EmscriptenModule { jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, - releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction, captureBegin: JSEP.CaptureBeginFunction, + captureEnd: JSEP.CaptureEndFunction, replay: JSEP.ReplayFunction): void; /** * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). @@ -158,12 +162,6 @@ export interface OrtWasmModule extends EmscriptenModule { * @returns the GPU data ID for the registered GPU buffer. */ jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; - /** - * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. - * - * @param sessionId - specify the session ID. - */ - jsepUnregisterBuffers?: (sessionId: number) => void; /** * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. * @@ -183,9 +181,18 @@ export interface OrtWasmModule extends EmscriptenModule { (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; /** - * [exported from js_internal_api.js] Called when InferenceSession.run started. + * [exported from js_internal_api.js] Called when InferenceSession.run started. This function will be called before + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Release a session. This function will be called before _OrtReleaseSession() is + * called. + * @param sessionId - specify the session ID. 
+ * @returns */ - jsepOnRunStart: () => void; + jsepOnReleaseSession: (sessionId: number) => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index a48fe99570abf..e1faecfc046e3 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -10,7 +10,14 @@ import {createView, TensorView} from './tensor-view'; import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; -import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, TimestampQuery} from './webgpu/types'; +import {ComputeContext, GpuData, ProgramInfo, ProgramInputTensorInfoDependency, SessionState, TimestampQuery} from './webgpu/types'; + +interface CommandInfo { + readonly kernelId: number; + readonly computePipeline: GPUComputePipeline; + readonly bindGroup: GPUBindGroup; + readonly dispatchGroup: [number, number, number]; +} interface KernelInfo { readonly kernelType: string; @@ -103,6 +110,13 @@ export class WebGpuBackend { */ programManager: ProgramManager; + /** + * representing the session ID of which is currently being run. + * `null` means no session is being run. + * only valid when session.run is executed. + */ + currentSessionId: number|null = null; + /** * representing the kernel ID of which is currently being computed (CPU code perspective). * `null` means no kernel is being computed. @@ -155,6 +169,16 @@ export class WebGpuBackend { queryType: TimestampQuery; env: Env; + sessionStatus: SessionState = 'default'; + /** + * a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session. + */ + capturedCommandList: Map = new Map(); + + /** + * a SessionID -> PendingKernelInfo[] mapping for profiling. + */ + private capturedPendingKernels: Map = new Map(); /** * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. 
@@ -228,6 +252,7 @@ export class WebGpuBackend { getComputePassEncoder(): GPUComputePassEncoder { if (!this.computePassEncoder) { + const commandEncoder = this.getCommandEncoder(); const computePassDescriptor: GPUComputePassDescriptor = {}; if (this.queryType === 'at-passes') { @@ -238,7 +263,7 @@ export class WebGpuBackend { }; } - this.computePassEncoder = this.getCommandEncoder().beginComputePass(computePassDescriptor); + this.computePassEncoder = commandEncoder.beginComputePass(computePassDescriptor); } return this.computePassEncoder; } @@ -494,7 +519,7 @@ export class WebGpuBackend { () => `[ProgramManager] run "${program.name}" (key=${key}) with ${normalizedDispatchGroup[0]}x${ normalizedDispatchGroup[1]}x${normalizedDispatchGroup[2]}`); - if (this.queryType !== 'none') { + if (this.queryType !== 'none' || this.sessionStatus === 'capturing') { const pendingKernelInfo: PendingKernelInfo = { kernelId: this.currentKernelId!, programName: artifact.programInfo.name, @@ -502,6 +527,9 @@ export class WebGpuBackend { outputTensorViews, }; this.pendingKernels.push(pendingKernelInfo); + + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); @@ -672,7 +700,71 @@ export class WebGpuBackend { } } } - onRunStart(): void { + + captureBegin(): void { + LOG_DEBUG('info', 'captureBegin'); + let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + if (!sessionCommandList) { + sessionCommandList = []; + this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); + sessionPendingKernels = []; + this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); + } + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'capturing'; + } + captureEnd(): void { + LOG_DEBUG('info', 'captureEnd'); + // flush the left commands before we change the status. + this.flush(); + this.sessionStatus = 'default'; + } + replay(): void { + LOG_DEBUG('info', 'replay'); + this.sessionStatus = 'replaying'; + const sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + const length = sessionCommandList!.length; + this.pendingKernels = []; + for (let i = 0; i < length; i++) { + const computePassEncoder = this.getComputePassEncoder(); + const command = sessionCommandList![i]; + this.writeTimestamp(this.pendingDispatchNumber * 2); + computePassEncoder.setPipeline(command.computePipeline); + computePassEncoder.setBindGroup(0, command.bindGroup); + computePassEncoder.dispatchWorkgroups(...command.dispatchGroup); + this.writeTimestamp(this.pendingDispatchNumber * 2 + 1); + this.pendingDispatchNumber++; + if (this.queryType !== 'none') { + this.pendingKernels.push(sessionPendingKernels![i]); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') { + this.endComputePass(); + } + if (this.pendingDispatchNumber >= this.maxDispatchNumber) { + this.flush(); + } + } + // flush the left commands before we change the status. 
+ this.flush(); + this.sessionStatus = 'default'; + } + + onReleaseSession(sessionId: number): void { + this.unregisterBuffers(sessionId); + if (this.capturedCommandList.has(sessionId)) { + this.capturedCommandList.delete(sessionId); + } + if (this.capturedPendingKernels.has(sessionId)) { + this.capturedPendingKernels.delete(sessionId); + } + this.gpuDataManager.onReleaseSession(sessionId); + } + + onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; this.setQueryType(); } } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index f1794d71579bf..786ae41646554 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -201,5 +201,11 @@ export const init = async(module: OrtWasmModule, env: Env, gpuAdapter: GPUAdapte contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); return backend.computeKernel(kernel, context, errors); - }); + }, + // jsepCaptureBegin + () => backend.captureBegin(), + // jsepCaptureEnd + () => backend.captureEnd(), + // jsepReplay + () => backend.replay()); }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 6f3d9a52d9f5d..c17bd1e1477ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -60,9 +60,15 @@ export interface GpuDataManager { unregisterExternalBuffer(buffer: GPUBuffer): void; /** - * destroy all gpu buffers. Call this when the session.release is called. + * destroy all gpu buffers. */ dispose(): void; + + /** + * release session related data. + * @param sessionId - specify the session ID. + */ + onReleaseSession(sessionId: number): void; } interface StorageCacheValue { @@ -139,6 +145,10 @@ class GpuDataManagerImpl implements GpuDataManager { // The external buffers registered users for IO Binding. private externalBuffers: Map; + // The pendingBuffers for capture graph. + // a SessionID -> GPUBuffer[] mapping. + private capturedPendingBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); @@ -146,6 +156,7 @@ class GpuDataManagerImpl implements GpuDataManager { this.buffersForUploadingPending = []; this.buffersPending = []; this.externalBuffers = new Map(); + this.capturedPendingBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -220,6 +231,9 @@ class GpuDataManagerImpl implements GpuDataManager { () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ id}, buffer is the same, skip.`); return id; + } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { + throw new Error(`Registering a different external buffer under graph capture mode is not supported yet. + Please use the previous external buffer!`); } this.externalBuffers.delete(previousBuffer); } else { @@ -312,20 +326,39 @@ class GpuDataManagerImpl implements GpuDataManager { buffer.destroy(); } this.buffersForUploadingPending = []; - for (const buffer of this.buffersPending) { - // eslint-disable-next-line no-bitwise - if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { - // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. 
- this.freeBuffers.get(buffer.size)!.push(buffer); + + if (this.buffersPending.length === 0) { + return; + } + + if (this.backend.sessionStatus === 'default') { + for (const buffer of this.buffersPending) { // eslint-disable-next-line no-bitwise - } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { - // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. - this.freeUniformBuffers.get(buffer.size)!.push(buffer); - } else { - buffer.destroy(); + if ((buffer.usage & GPUBufferUsage.STORAGE) === GPUBufferUsage.STORAGE) { + // Put the pending buffer to freeBuffers list instead of really destroying it for buffer reusing. + this.freeBuffers.get(buffer.size)!.push(buffer); + // eslint-disable-next-line no-bitwise + } else if ((buffer.usage & GPUBufferUsage.UNIFORM) === GPUBufferUsage.UNIFORM) { + // Put the pending buffer to freeUniformBuffers list instead of really destroying it for buffer reusing. + this.freeUniformBuffers.get(buffer.size)!.push(buffer); + } else { + buffer.destroy(); + } + } + this.buffersPending = []; + } else { + // Don't release intermediate tensors in non-default mode. + // TODO: reuse the storage buffers in non-default mode. + let capturedBuffers = this.capturedPendingBuffers.get(this.backend.currentSessionId!); + if (!capturedBuffers) { + capturedBuffers = []; + this.capturedPendingBuffers.set(this.backend.currentSessionId!, capturedBuffers); } + for (const buffer of this.buffersPending) { + capturedBuffers.push(buffer); + } + this.buffersPending = []; } - this.buffersPending = []; } dispose() { @@ -344,9 +377,26 @@ class GpuDataManagerImpl implements GpuDataManager { storage.gpuData.buffer.destroy(); }); + this.capturedPendingBuffers.forEach((buffers) => { + buffers.forEach(buffer => { + buffer.destroy(); + }); + }); this.storageCache = new Map(); this.freeBuffers = new Map(); this.freeUniformBuffers = new Map(); + this.capturedPendingBuffers = new Map(); + } + + onReleaseSession(sessionId: number) { + // release the captured pending buffers. 
+ const pendingBuffers = this.capturedPendingBuffers.get(sessionId); + if (pendingBuffers) { + pendingBuffers.forEach(buffer => { + buffer.destroy(); + }); + this.capturedPendingBuffers.delete(sessionId); + } } } diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index 72eb9713e26a8..9d05f607f817f 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -38,7 +38,6 @@ export class ProgramManager { const device = this.backend.device; const computePassEncoder = this.backend.getComputePassEncoder(); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2); - computePassEncoder.setPipeline(buildArtifact.computePipeline); const entries = []; for (const input of inputs) { entries.push({binding: entries.length, resource: {buffer: input.buffer}}); @@ -51,8 +50,20 @@ export class ProgramManager { } const bindGroup = device.createBindGroup( {layout: buildArtifact.computePipeline.getBindGroupLayout(0), entries, label: buildArtifact.programInfo.name}); - computePassEncoder.setBindGroup(0, bindGroup); + if (this.backend.sessionStatus === 'capturing') { + const commandInfo = { + kernelId: this.backend.currentKernelId!, + computePipeline: buildArtifact.computePipeline, + bindGroup, + dispatchGroup + }; + const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!); + sessionCommandList!.push(commandInfo); + } + + computePassEncoder.setPipeline(buildArtifact.computePipeline); + computePassEncoder.setBindGroup(0, bindGroup); computePassEncoder.dispatchWorkgroups(...dispatchGroup); this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1); this.backend.pendingDispatchNumber++; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 789ac70a6913a..a34b6190b7244 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -5,6 +5,8 @@ import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; +export type SessionState = 'default'|'capturing'|'replaying'; + export enum GpuDataType { default = 0, upload = 1, diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 41ab2d52ca209..48eac57494726 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -168,6 +168,18 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n setExecutionProviders(sessionOptionsHandle, sessionOptions.executionProviders, allocs); } + if (sessionOptions.enableGraphCapture !== undefined) { + if (typeof sessionOptions.enableGraphCapture !== 'boolean') { + throw new Error(`enableGraphCapture must be a boolean value: ${sessionOptions.enableGraphCapture}`); + } + const keyDataOffset = allocWasmString('enableGraphCapture', allocs); + const valueDataOffset = allocWasmString(sessionOptions.enableGraphCapture.toString(), allocs); + if (wasm._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== 0) { + checkLastError( + `Can't set a session config entry: 'enableGraphCapture' - ${sessionOptions.enableGraphCapture}.`); + } + } + if (sessionOptions.freeDimensionOverrides) { for (const [name, value] of Object.entries(sessionOptions.freeDimensionOverrides)) { if (typeof name !== 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 046336dc9cac0..37b9ed6a1002f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ 
b/js/web/lib/wasm/wasm-core-impl.ts @@ -139,7 +139,7 @@ type IOBindingState = { */ type SessionMetadata = [ inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], - bindingState: IOBindingState|null + bindingState: IOBindingState|null, enableGraphCapture: boolean, inputOutputBound: boolean ]; const activeSessions = new Map(); @@ -235,6 +235,8 @@ export const createSession = async( const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); + const enableGraphCapture = !!options?.enableGraphCapture; + const inputNames = []; const outputNames = []; const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; @@ -256,12 +258,20 @@ export const createSession = async( outputNames.push(nameString); if (!BUILD_DEFS.DISABLE_WEBGPU) { + if (enableGraphCapture && options?.preferredOutputLocation === undefined) { + outputPreferredLocations.push('gpu-buffer'); + continue; + } const location = typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : options?.preferredOutputLocation?.[nameString] ?? 'cpu'; if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${ + location}. Only 'gpu-buffer' location is supported when enableGraphCapture is true.`); + } outputPreferredLocations.push(location); } } @@ -281,7 +291,9 @@ export const createSession = async( }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); + activeSessions.set( + sessionHandle, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState, enableGraphCapture, false]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -313,13 +325,16 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture] = session; if (ioBindingState) { + if (enableGraphCapture) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); + } wasm._OrtReleaseBinding(ioBindingState.handle); } - wasm.jsepUnregisterBuffers?.(sessionId); + wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -328,70 +343,75 @@ export const releaseSession = (sessionId: number): void => { }; export const prepareInputOutputTensor = - (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): - void => { - if (!tensor) { - tensorHandles.push(0); - return; - } + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, + enableGraphCapture = false): void => { + if (!tensor) { + tensorHandles.push(0); + return; + } - const wasm = getInstance(); + const wasm = getInstance(); - const dataType = tensor[0]; - const dims = tensor[1]; - const location = tensor[3]; + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; - let rawData: number; - let dataByteLength: number; + let rawData: number; + let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { - throw new Error('String tensor is not supported on GPU.'); - } + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } - if (location === 'gpu-buffer') { - const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; - const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; - dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; - rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else { - const data = tensor[2]; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - let dataIndex = rawData / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); - } - } + if (enableGraphCapture && location !== 'gpu-buffer') { + throw new Error( + `External buffer must be provided for input/output index ${index} when enableGraphCapture is true.`); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. 
session=${sessionId}, index=${index}.`); + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); } - }; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; /** * perform inference run @@ -404,7 +424,12 @@ export const run = async( if (!session) { throw new Error(`cannot run inference. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + const sessionHandle = session[0]; + const inputNamesUTF8Encoded = session[1]; + const outputNamesUTF8Encoded = session[2]; + const ioBindingState = session[3]; + const enableGraphCapture = session[4]; + const inputOutputBound = session[5]; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -427,13 +452,15 @@ export const run = async( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + prepareInputOutputTensor( + inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i], enableGraphCapture); } // create output tensors for (let i = 0; i < outputCount; i++) { prepareInputOutputTensor( - outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i], + enableGraphCapture); } let inputValuesIndex = inputValuesOffset / 4; @@ -449,7 +476,7 @@ export const run = async( wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState && !inputOutputBound) { const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; if (inputNamesUTF8Encoded.length !== inputCount) { @@ -486,9 +513,12 @@ export const run = async( } } } + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, true]); } - wasm.jsepOnRunStart?.(); + wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -595,10 +625,12 @@ export const run = async( } } - if (ioBindingState) { + if (ioBindingState && !enableGraphCapture) { wasm._OrtClearBoundOutputs(ioBindingState.handle); + activeSessions.set( + sessionId, + [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); } - return output; } finally { wasm.stackRestore(beforeRunStack); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 0448487e6faec..799d4172f2b64 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -3,6 +3,7 @@ #include "js_execution_provider.h" +#include #include #include #include @@ -681,9 +682,13 @@ std::unique_ptr RegisterKernels() { using namespace js; -JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) +JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options) : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, preferred_data_layout_{info.data_layout} { + if (session_options) { + enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; + LOGS_DEFAULT(VERBOSE) << "Graph capture enable: " << enable_graph_capture_; + } } std::vector JsExecutionProvider::CreatePreferredAllocators() { @@ -751,4 +756,46 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer 
JsExecutionProvider::~JsExecutionProvider() { } +Status JsExecutionProvider::OnRunStart() { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; + EM_ASM({ Module.jsepCaptureBegin(); }); + } + return Status::OK(); +} + +Status JsExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureAllowed()) { + EM_ASM({ Module.jsepCaptureEnd(); }); + is_graph_captured_ = true; + } else { + IncrementRegularRunCountBeforeGraphCapture(); + } + } + + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureEnabled() const { + return enable_graph_capture_; +} + +bool JsExecutionProvider::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status JsExecutionProvider::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + EM_ASM({ Module.jsepReplay(); }); + return Status::OK(); +} + +bool JsExecutionProvider::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +} + +void JsExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 39d43498c0717..91a3256ec2bd5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -5,6 +5,7 @@ #pragma once #include "core/framework/execution_provider.h" +#include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" @@ -38,7 +39,7 @@ struct JsExecutionProviderInfo { class JsExecutionProvider : public IExecutionProvider { public: - JsExecutionProvider(const JsExecutionProviderInfo& info); + JsExecutionProvider(const JsExecutionProviderInfo& info, const SessionOptions* session_options); ~JsExecutionProvider() override; std::vector> GetCapability( @@ -57,7 +58,22 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + + Status OnRunStart() override; + Status OnRunEnd(bool sync_stream) override; + + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; + + private: + bool IsGraphCaptureAllowed() const; + void IncrementRegularRunCountBeforeGraphCapture(); DataLayout preferred_data_layout_; + bool enable_graph_capture_ = false; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. 
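For orientation, the session layer drives these new hooks roughly as sketched below. This is a simplified illustration under assumed names (`RunOnce`, `ep`), not the actual InferenceSession code path:
```
#include "core/common/common.h"                 // ORT_RETURN_IF_ERROR
#include "core/framework/execution_provider.h"  // IExecutionProvider

// Simplified sketch of how a graph-capturing EP is driven across runs.
// `ep` stands for the execution provider cached for replay; the function and
// parameter names are illustrative, not real InferenceSession members.
onnxruntime::common::Status RunOnce(onnxruntime::IExecutionProvider& ep /*, feeds, fetches */) {
  if (ep.IsGraphCaptureEnabled() && ep.IsGraphCaptured()) {
    // Later runs: skip normal kernel dispatch and replay the recorded commands.
    return ep.ReplayGraph();
  }
  ORT_RETURN_IF_ERROR(ep.OnRunStart());      // first eligible run calls jsepCaptureBegin()
  // ... execute the partitioned graph as usual ...
  return ep.OnRunEnd(/*sync_stream*/ true);  // a capture begun this run is finalized here
}
```
With min_num_runs_before_cuda_graph_capture_ set to 1, the first run executes normally so the necessary allocations happen, the second run records the WebGPU commands, and later runs replay them.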
}; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory.cc b/onnxruntime/core/providers/js/js_provider_factory.cc index 5b7329a87cf6a..cbdf99f702150 100644 --- a/onnxruntime/core/providers/js/js_provider_factory.cc +++ b/onnxruntime/core/providers/js/js_provider_factory.cc @@ -10,21 +10,22 @@ namespace onnxruntime { struct JsProviderFactory : IExecutionProviderFactory { - JsProviderFactory(const ProviderOptions& provider_options) - : info_{provider_options} { + JsProviderFactory(const ProviderOptions& provider_options, const SessionOptions* session_options) + : info_{provider_options}, session_options_(session_options) { } std::unique_ptr CreateProvider() override { - return std::make_unique(info_); + return std::make_unique(info_, session_options_); } private: JsExecutionProviderInfo info_; + const SessionOptions* session_options_; }; std::shared_ptr JsProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - return std::make_shared(provider_options); + const ProviderOptions& provider_options, const SessionOptions* session_options) { + return std::make_shared(provider_options, session_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_provider_factory_creator.h b/onnxruntime/core/providers/js/js_provider_factory_creator.h index dbabe255c2d7b..510b0fb4248ca 100644 --- a/onnxruntime/core/providers/js/js_provider_factory_creator.h +++ b/onnxruntime/core/providers/js/js_provider_factory_creator.h @@ -9,9 +9,11 @@ #include "core/providers/providers.h" namespace onnxruntime { +struct SessionOptions; struct JsProviderFactoryCreator { - static std::shared_ptr Create(const ProviderOptions& provider_options); + static std::shared_ptr Create(const ProviderOptions& provider_options, + const SessionOptions* session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 39f47c09f2402..cae714954f72f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -145,28 +145,30 @@ static bool HasMemcpyNodes(const Graph& graph) { return false; } -static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { - bool nodes_on_cpu_and_cuda_eps_only = true; +static bool AreAllComputeNodesAssignedToCudaOrJsEp(const Graph& graph) { + bool nodes_on_cpu_and_cuda_and_js_eps_only = true; for (const auto& node : graph.Nodes()) { const auto& node_provider = node.GetExecutionProviderType(); // Empty node provider means CPU EP if (!node_provider.empty() && - !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) && + !(node_provider == kCudaExecutionProvider || + node_provider == kRocmExecutionProvider || + node_provider == kJsExecutionProvider) && node_provider != kCpuExecutionProvider) { - nodes_on_cpu_and_cuda_eps_only = false; + nodes_on_cpu_and_cuda_and_js_eps_only = false; break; } } - // If we see nodes assigned to EPs other than CPU or CUDA + // If we see nodes assigned to EPs other than CPU, or CUDA/JS // (or) if there are Memcpy nodes, then all compute nodes have - // not been parititoned to the CUDA EP. + // not been parititoned to the CUDA/JS EP. // We allow CPU EPs to show up in the EP list as long as thre is no Memcpy // involved as shape subgraphs will be forced onto CPU and these will not have // Memcpy nodes involved. 
- return nodes_on_cpu_and_cuda_eps_only && !HasMemcpyNodes(graph); + return nodes_on_cpu_and_cuda_and_js_eps_only && !HasMemcpyNodes(graph); } static bool AreAllNodesInMainGraphAssignedToOneEp(const Graph& graph, ProviderType provider) { @@ -1715,8 +1717,7 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP, and - // HIP graph is only considered by ROCM EP. + // Currently graph capture is only considered by CUDA EP, TRT EP, ROCM EP and JS EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1730,6 +1731,12 @@ common::Status InferenceSession::Initialize() { // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). // + // Check for JS EP: + // If the JS EP is part of the providers list for this session AND + // The JS EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the JS EP, + // Then the JS EP is cached for triggering a ReplayGraph() in Run(). + // // Check for ROCM EP: // If the ROCM EP is part of the providers list for this session AND // The ROCM EP is configured to do a graph capture AND @@ -1739,48 +1746,54 @@ common::Status InferenceSession::Initialize() { std::vector graph_support_ep_list = { onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider, - onnxruntime::kRocmExecutionProvider}; + onnxruntime::kRocmExecutionProvider, + onnxruntime::kJsExecutionProvider}; for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA/HIP Graphs can't work with control flow nodes + // Graphs capture can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << "as the model has control flow nodes which can't be supported by " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA/HIP Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.")); + "This session cannot use the graph capture feature as requested by the user " + "as the model has control flow nodes which can't be supported by" + + target_ep->Type())); } if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || - strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) { - // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes + strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kJsExecutionProvider) == 0) { + // Ensure that all nodes have been partitioned to CUDA/JS or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which 
is all we care about. - if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA/HIP EP."; + if (!AreAllComputeNodesAssignedToCudaOrJsEp(graph)) { + LOGS(*session_logger_, ERROR) << "This session cannot use the graph capture feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the " + << target_ep->Type(); ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA/HIP Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.")); + "This session cannot use the graph capture feature as requested by the user " + " as all compute graph nodes have not been partitioned to the " + + target_ep->Type())); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA/HIP Graph feature with caution. " + << "Use the graph capture feature with caution. " << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA/HIP graph, " + << "using the representative input used to capture the graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA/HIP Graph feature."; + << "it is safe to use the graph capture feature."; } } else { // Following code path is for TRT EP currently. diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 86b3d01c640a3..964355956b4ab 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -145,7 +145,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { provider_options["preferred_layout"] = preferred_layout; } - options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); + options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); #endif diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7e9c0a6f99c32..cbc60c70b57aa 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -24,7 +24,7 @@ Module['unmountExternalData'] = () => { /** * init JSEP */ -Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel) => { +Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, releaseKernel, runKernel, captureBegin, captureEnd, replay) => { Module.jsepBackend = backend; Module.jsepAlloc = alloc; Module.jsepFree = free; @@ -33,6 +33,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepCreateKernel = createKernel; Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; + Module.jsepCaptureBegin = captureBegin; + Module.jsepCaptureEnd = captureEnd; + Module.jsepReplay = replay; // This is a simplified version of cwrap() with options.async === true 
(-sASYNCIFY=1) // It removes some overhead in cwarp() and ccall() that we don't need. @@ -181,16 +184,16 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { return backend['registerBuffer'](sessionId, index, buffer, size); }; - Module['jsepUnregisterBuffers'] = sessionId => { - backend['unregisterBuffers'](sessionId); - }; Module['jsepGetBuffer'] = (dataId) => { return backend['getBuffer'](dataId); }; Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; - Module['jsepOnRunStart'] = () => { - return backend['onRunStart'](); + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); }; }; From d73131cf0f4eecd5f639c40d8fb6fad4efeaf4ef Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Wed, 31 Jan 2024 13:05:08 +0800 Subject: [PATCH 009/207] [js/webgpu] Use DataType as uniform cpu type (#19281) This saves turning data type to string by tensorDataTypeEnumToString. --- js/web/lib/wasm/jsep/backend-webgpu.ts | 18 ++++++----- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 7 +++-- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 8 +++-- .../ops/3rd-party/conv_backprop_webgpu.ts | 8 +++-- .../ops/3rd-party/matmul_packed_webgpu.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 30 +++++++++---------- js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 5 ++-- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 13 ++++---- js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/einsum.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 7 +++-- .../wasm/jsep/webgpu/ops/gather-elements.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 6 ++-- js/web/lib/wasm/jsep/webgpu/ops/gemm.ts | 6 ++-- .../lib/wasm/jsep/webgpu/ops/instance-norm.ts | 14 +++++---- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 5 ++-- .../jsep/webgpu/ops/multi-head-attentiion.ts | 7 +++-- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 7 ++--- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 20 +++++++------ js/web/lib/wasm/jsep/webgpu/ops/range.ts | 5 ++-- .../lib/wasm/jsep/webgpu/ops/reduce-shared.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 7 +++-- .../wasm/jsep/webgpu/ops/skip-layer-norm.ts | 8 ++--- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 6 ++-- js/web/lib/wasm/jsep/webgpu/ops/softmax.ts | 3 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 3 +- js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 5 ++-- js/web/lib/wasm/jsep/webgpu/types.ts | 3 +- 37 files changed, 148 insertions(+), 108 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index e1faecfc046e3..58efa795dba48 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -3,7 +3,7 @@ import {Env, Tensor, TRACE, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common'; -import {tensorDataTypeEnumToString} from 
'../wasm-common'; +import {DataType, tensorDataTypeEnumToString} from '../wasm-common'; import {configureLogger, LOG_DEBUG} from './log'; import {createView, TensorView} from './tensor-view'; @@ -453,10 +453,10 @@ export class WebGpuBackend { return; } // https://www.w3.org/TR/WGSL/#alignof - const sizeOfElement = v.type === 'float16' ? 2 : 4; + const sizeOfElement = v.type === DataType.float16 ? 2 : 4; let sizeOfVecOrMat; let baseAlignment; - if (v.type === 'float16') { + if (v.type === DataType.float16) { baseAlignment = data.length > 4 ? 16 : (data.length > 2 ? 8 : data.length * sizeOfElement); sizeOfVecOrMat = data.length > 4 ? 16 : sizeOfElement * data.length; } else { @@ -470,7 +470,7 @@ export class WebGpuBackend { // SizeOf(vec4). For float16 type, when data.length > 4, the uniform variable is of type // array,N>, where N = Math.ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte // length is N * SizeOf(mat2x4). - const elementPerVecOrMat = v.type === 'float16' ? 8 : 4; + const elementPerVecOrMat = v.type === DataType.float16 ? 8 : 4; currentOffset += data.length > 4 ? Math.ceil(data.length / elementPerVecOrMat) * sizeOfVecOrMat : data.length * sizeOfElement; }); @@ -483,15 +483,17 @@ export class WebGpuBackend { programUniforms.forEach((v, i) => { const offset = offsets[i]; const data = typeof v.data === 'number' ? [v.data] : v.data; - if (v.type === 'int32') { + if (v.type === DataType.int32) { new Int32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'uint32') { + } else if (v.type === DataType.uint32) { new Uint32Array(arrayBuffer, offset, data.length).set(data); - } else if (v.type === 'float16') { + } else if (v.type === DataType.float16) { // TODO: use Float16Array. new Uint16Array(arrayBuffer, offset, data.length).set(data); - } else { + } else if (v.type === DataType.float) { new Float32Array(arrayBuffer, offset, data.length).set(data); + } else { + throw new Error(`Unsupported uniform type: ${tensorDataTypeEnumToString(v.type)}`); } }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index e5ca3204d4433..bc39bd94e3072 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -189,9 +190,9 @@ export const createConv2DMatMulProgramInfo = const elementsSize = isVec4 ? 
[innerElementSize, 4, 4] : [1, 1, 1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'int32', data: attributes.strides}, - {type: 'int32', data: attributes.dilations} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index e50733559dbe9..d18f8586dd071 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -197,9 +198,10 @@ export const createConv2DTransposeMatMulProgramInfo = ]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}, - {type: 'int32', data: attributes.strides}, {type: 'int32', data: attributes.dilations}, - {type: 'int32', data: filterDims}, {type: 'int32', data: pads} + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner}, {type: DataType.int32, data: attributes.strides}, + {type: DataType.int32, data: attributes.dilations}, {type: DataType.int32, data: filterDims}, + {type: DataType.int32, data: pads} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 380efc8bc577a..ba6776e9d8c94 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -17,6 +17,7 @@ // sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_webgpu.ts +import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; @@ -264,9 +265,10 @@ export const createConvTranspose2DProgramInfo = const outputChannelsPerGroup = wShape[1]; const programUniforms: ProgramUniform[] = [ - {type: 'int32', data: outputSize}, {type: 'uint32', data: strides}, {type: 'uint32', data: filterDims}, - {type: 'uint32', data: dilations}, {type: 'uint32', data: effectiveFilterDims}, {type: 'int32', data: pads}, - {type: 'uint32', data: inputChannelsPerGroup}, {type: 'uint32', data: outputChannelsPerGroup}, + {type: DataType.int32, data: outputSize}, {type: DataType.uint32, data: strides}, + {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, + {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, + {type: DataType.uint32, data: inputChannelsPerGroup}, {type: 
DataType.uint32, data: outputChannelsPerGroup}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) ]; if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 00c1f86d67419..d9a8d59f731de 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -19,6 +19,7 @@ // // modified to fit the needs of the project +import {DataType} from '../../../../wasm-common'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; @@ -447,8 +448,10 @@ export const createMatmulProgramInfo = const bShapeTemp = [...outerDimsB, dimInner, dimBOuter / components]; const bRank = bShapeTemp.length; const outputShapeTemp = [batchSize, dimAOuter, dimBOuter / components]; - const programUniforms: ProgramUniform[] = - [{type: 'int32', data: dimAOuter}, {type: 'int32', data: dimBOuter}, {type: 'int32', data: dimInner}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.int32, data: dimAOuter}, {type: DataType.int32, data: dimBOuter}, + {type: DataType.int32, data: dimInner} + ]; appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index f07a21a343fa8..2cfe6356dd6e7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-import {tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ComputeContext, GpuDataType, ProgramUniform} from '../types'; @@ -241,9 +241,10 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView WG = Math.ceil(dComp / 8); } const elementsPerWG = Math.ceil(d / components / WG); - const tensorDataType = tensorDataTypeEnumToString(input.dataType) as ProgramUniform['type']; - const programUniforms: ProgramUniform[] = - [{type: tensorDataType, data: 1 / d}, {type: 'uint32', data: dComp}, {type: 'uint32', data: elementsPerWG}]; + const programUniforms: ProgramUniform[] = [ + {type: input.dataType, data: 1 / d}, {type: DataType.uint32, data: dComp}, + {type: DataType.uint32, data: elementsPerWG} + ]; const dataType = tensorTypeToWsglStorageType(input.dataType, components); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -336,11 +337,10 @@ const computeAttentionProbs = y: Math.ceil(parameters.sequenceLength / TILE_SIZE), z: parameters.batchSize * parameters.numHeads }; - const tensorDataType = tensorDataTypeEnumToString(q.dataType) as ProgramUniform['type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: parameters.sequenceLength}, {type: 'uint32', data: vectorizedHeadSize}, - {type: 'uint32', data: parameters.totalSequenceLength}, {type: 'uint32', data: parameters.kvSequenceLength}, - {type: tensorDataType, data: alpha} + {type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize}, + {type: DataType.uint32, data: parameters.totalSequenceLength}, + {type: DataType.uint32, data: parameters.kvSequenceLength}, {type: q.dataType, data: alpha} ]; const inputs = [q, key]; @@ -430,9 +430,9 @@ const computeVxAttentionScore = z: params.batchSize * params.numHeads }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: params.sequenceLength}, {type: 'uint32', data: params.totalSequenceLength}, - {type: 'uint32', data: params.vHeadSize}, {type: 'uint32', data: params.numHeads}, - {type: 'uint32', data: params.vHiddenSize} + {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength}, + {type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads}, + {type: DataType.uint32, data: params.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -526,10 +526,10 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => { }; const inputs = [context.inputs[0], context.inputs[1], context.inputs[2]]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: M}, {type: 'uint32', data: K}, {type: 'uint32', data: N}, - {type: 'uint32', data: parameters.numHeads}, {type: 'uint32', data: parameters.headSize}, - {type: 'uint32', data: parameters.hiddenSize}, - {type: 'uint32', data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} + {type: DataType.uint32, data: M}, {type: DataType.uint32, data: K}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.headSize}, + {type: DataType.uint32, data: parameters.hiddenSize}, + {type: DataType.uint32, data: parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize} ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts 
b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts index 159b971636765..39b932375891b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/batch-norm.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -123,11 +124,11 @@ const createBatchNormInferenceProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: useShapesUniforms ? [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(yShape), ] : [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 8e144a36dc1b0..51f0c76ed8824 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -179,7 +179,7 @@ const createBinaryOpProgramInfo = outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, ...createTensorShapeVariables(a.dims), ...createTensorShapeVariables(b.dims), ...createTensorShapeVariables(outputShape), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 1bedf31ee4e38..3de57d5ac7f7c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -259,8 +259,9 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => - dims.length === 0 ? [] : [{type: 'uint32', data: dims}, {type: 'uint32', data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ? + [] : + [{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}]; /** * A helper function to get maximum vector size for specified data length diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index daa326b1a34e2..b06c9fb496d15 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -95,14 +96,14 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P let previousSum = 0; const inputDependencies: ProgramInputTensorInfoDependency[] = []; const inputRanks = []; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: outputSize}]; for (let i = 0; i < inputs.length; ++i) { previousSum += inputs[i].dims[adjustedAxis]; sizeInConcatAxis[i] = previousSum; inputRanks.push(inputs[i].dims.length); inputVars[i] = inputVariable(`input${i}`, dataType, inputRanks[i]); inputDependencies.push('rank'); - programUniforms.push({type: 'uint32', data: sizeInConcatAxis[i]}); + programUniforms.push({type: DataType.uint32, data: sizeInConcatAxis[i]}); } for (let i = 0; i < inputs.length; ++i) { programUniforms.push(...createTensorShapeVariables(inputs[i].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index c0aaaa7ce134b..3c2c3cc4e046c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -28,9 +29,10 @@ export const createGroupedConvProgramInfo = const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.dilations}, - {type: 'uint32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'uint32', data: [attributes.pads[0], attributes.pads[1]]}, {type: 'uint32', data: outputChannelsPerGroup} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.dilations}, + {type: DataType.uint32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.uint32, data: [attributes.pads[0], attributes.pads[1]]}, + {type: DataType.uint32, data: outputChannelsPerGroup} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( @@ -127,8 +129,9 @@ export const createGroupedConvVectorizeProgramInfo = const outputShapeInShader = [outputShape[0], outputShape[1], outputShape[2], outputShape[3] / components]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: [attributes.strides[0], attributes.strides[1]]}, - {type: 'int32', data: [attributes.pads[0], attributes.pads[1]]} + {type: DataType.uint32, data: outputSize}, + {type: DataType.int32, data: [attributes.strides[0], attributes.strides[1]]}, + {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ]; appendActivationUniformsData(attributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index 2ff909c30e62e..fb17202cd042f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -54,7 +54,7 @@ const createCumsumProgramInfo = outputs: [{dims: inputShape, dataType: 
inputType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axis}, + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) ] diff --git a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts index 9e1f58bbfa127..19a009c2eb79b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/einsum.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -272,8 +273,10 @@ const createEinsumProgramInfo = // filter is added to make sure that dimValue is never 0. const programUniformsInit: ProgramUniform[] = uniformsSymbols.filter((symbol) => einsumEquation.symbolToInfo.has(symbol)) - .map((symbol) => ({type: 'uint32', data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); - programUniformsInit.push({type: 'uint32', data: outputSize}); + .map( + (symbol) => + ({type: DataType.uint32, data: einsumEquation.symbolToInfo.get(symbol)?.dimValue || 0})); + programUniformsInit.push({type: DataType.uint32, data: outputSize}); const programUniforms: ProgramUniform[] = inputShapes.map((dims, _) => [...createTensorShapeVariables(dims)]) .reduce((acc, inputProgramUniforms) => acc.concat(inputProgramUniforms), programUniformsInit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index dd18bd23a5912..f8fdb63160380 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -85,7 +85,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => }; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape) ]; return { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index e1dc9a5e0ab7d..60067c014613b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {MAX_CLIP, MIN_CLIP} from '../../util'; import {ProgramUniform} from '../types'; @@ -36,9 +37,11 @@ export const getActivationSnippet = (attributes: InternalActivationAttributes, v export const appendActivationUniformsData = (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { if (attributes.activation === 'Clip') { - programUniform.push({type: 'float32', data: attributes.clipMax!}, {type: 'float32', data: attributes.clipMin!}); + programUniform.push( + {type: DataType.float, data: attributes.clipMax!}, {type: DataType.float, data: attributes.clipMin!}); } else if (attributes.activation === 'HardSigmoid') { - programUniform.push({type: 'float32', data: attributes.alpha!}, {type: 'float32', data: attributes.beta!}); + programUniform.push( + {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a945954adcaa4..a2d4e3d28f7c5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -46,8 +47,10 @@ const createGatherElementsProgramInfo = const output = outputVariable('output', inputOutputDataType, outputShape.length); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis} + ]; programUniforms.push(...createTensorShapeVariables(inputShape)); programUniforms.push(...createTensorShapeVariables(indicesShape)); programUniforms.push(...createTensorShapeVariables(outputShape)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index e2a62c6655c72..f2c71a9cd4188 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -34,9 +34,9 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const outputSize = Math.ceil(ShapeUtil.size(outputShape) / components); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims), + ...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts index a0d4021516bf7..76302e1af2e53 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gemm.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {GemmUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -45,8 +46,9 @@ const createGemmProgramInfo = (inputs: readonly TensorView[], attributes: GemmAt } const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, {type: 'uint32', data: K}, - {type: 'float32', data: attributes.alpha}, {type: 'float32', data: attributes.beta} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K}, {type: DataType.float, data: attributes.alpha}, + {type: DataType.float, data: attributes.beta} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; if (inputs.length === 3) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index a835c90bd5451..2096b898b5d40 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -25,7 +25,7 @@ const createInstanceNormProgramInfo = const inputShape = [xShape[0], xShape[1], normPackedSize]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: normSize}, {type: 'uint32', data: normPackedSize}]; + [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -132,8 +132,9 @@ const computeMean = const meanInputDependencies: ProgramInputTensorInfoDependency[] = ['type']; const meanProgramUniforms: ProgramUniform[] = [ - {type: 'uint32', data: wgSize}, {type: 'uint32', data: h}, {type: 'uint32', data: Math.floor(c / components)}, - {type: 'uint32', data: Math.floor(h * c / components)} + {type: DataType.uint32, data: wgSize}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(h * c / components)} ]; const getMeanShaderSource = (shaderHelper: ShaderHelper) => { @@ -182,8 +183,9 @@ const computeMean = {inputs: [input], outputs: [-1]})[0]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: unitsOfWork}, {type: 'uint32', data: h}, - {type: 'uint32', data: Math.floor(c / components)}, {type: 'uint32', data: Math.floor(WG * c / components)} + {type: DataType.uint32, data: unitsOfWork}, {type: DataType.uint32, data: h}, + {type: DataType.uint32, data: Math.floor(c / components)}, + {type: DataType.uint32, data: Math.floor(WG * c / components)} ]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type', 'type']; const getShaderSource = (shaderHelper: ShaderHelper) => { @@ -246,7 +248,7 @@ const createInstanceNormNHWCProgramInfo = const components = getMaxComponents(C); const outputSize = ShapeUtil.size(outputShape) / components; const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: H}, {type: 'uint32', data: Math.floor(C / components)}]; + [{type: DataType.uint32, data: H}, {type: DataType.uint32, data: Math.floor(C / components)}]; const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 
'type']; // first compute mean const channelScaleShift = computeMean(context, inputs[0], inputs[1], inputs[2], N, H, C, attributes.epsilon); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3c9f6ce71bb67..3f73d9cb7c5bc 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -49,8 +49,9 @@ const createLayerNormProgramInfo = const components = getMaxComponents(normSize); const inputDependencies: ProgramInputTensorInfoDependency[] = ['type', 'type']; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: normCount}, {type: 'float32', data: normSize}, - {type: 'uint32', data: Math.floor(normSize / components)}, {type: 'float32', data: attributes.epsilon} + {type: DataType.uint32, data: normCount}, {type: DataType.float, data: normSize}, + {type: DataType.uint32, data: Math.floor(normSize / components)}, + {type: DataType.float, data: attributes.epsilon} ]; if (bias) { inputDependencies.push('type'); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 188b88b2510d8..b263451b99134 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; @@ -29,8 +30,8 @@ export const createNaiveMatmulProgramInfo = const outputShapeInShader = [batchSize, M, N]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N}, - {type: 'uint32', data: K} + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: M}, {type: DataType.uint32, data: N}, + {type: DataType.uint32, data: K} ]; appendActivationUniformsData(activationAttributes, programUniforms); programUniforms.push( diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts index 6d22e3780efd9..5c5c849d99811 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -238,8 +239,10 @@ const addBiasTranspose = hiddenSize: number, biasOffset: number) => { const outputShape = [batchSize, sequenceLength, hiddenSize]; const outputSize = ShapeUtil.size(outputShape); - const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: biasOffset}, {type: 'uint32', data: hiddenSize}]; + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: biasOffset}, + {type: DataType.uint32, data: hiddenSize} + ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('qkv_with_bias', qkv.dataType, outputShape); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index c65b741e1105a..9f5e60773f080 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorDataTypeEnumToString} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; @@ -153,10 +153,9 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr const inputDims = inputs[0].dims; const outputSize = ShapeUtil.size(outputShape); const programUniforms: ProgramUniform[] = - [{type: 'uint32', data: outputSize}, {type: 'uint32', data: attributes.pads}]; + [{type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.pads}]; if (attributes.mode === 0) { - const tensorDataType = tensorDataTypeEnumToString(inputs[0].dataType) as ProgramUniform['type']; - programUniforms.push({type: tensorDataType, data: attributes.value}); + programUniforms.push({type: inputs[0].dataType, data: attributes.value}); } programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9e9b361c1af1c..70b8acc3146a0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -3,6 +3,7 @@ import {env} from 'onnxruntime-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -56,7 +57,8 @@ const getUniformAndPadInfo = ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: outputSize}, - programUniforms: [{type: 'uint32', data: reduceSize}] + programUniforms: [{type: DataType.uint32, data: reduceSize}] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index e8851ac546942..123eb38a1fb93 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -101,7 +101,7 @@ export const createReduceProgramInfo = outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: 
outputSize}, ...createTensorShapeVariables(inputShape), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape) ] }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index f68526acc0e63..edfd856aeb850 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -641,9 +642,9 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, - {type: 'float32', data: scales}, - {type: 'float32', data: roi}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(outputShape), ] diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 509a722f4b52a..7be9ceec6bc65 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -88,10 +88,10 @@ const createSkipLayerNormProgramInfo = const components = getMaxComponents(hiddenSize); const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, - {type: 'uint32', data: components}, - {type: 'uint32', data: hiddenSize}, - {type: 'float32', data: attributes.epsilon}, + {type: DataType.uint32, data: outputSize}, + {type: DataType.uint32, data: components}, + {type: DataType.uint32, data: hiddenSize}, + {type: DataType.float, data: attributes.epsilon}, ]; const getShaderSource = (shaderHelper: ShaderHelper) => { const uniformsArray: UniformsArrayType = [ diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 5212c6475dce0..6baa634f69f82 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -155,9 +155,9 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice ]; const programUniforms: ProgramUniform[] = [ - {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs}, - {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, + {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, + ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts index 324dc3af1a710..6f8bfa08d7b62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/softmax.ts @@ -5,6 +5,7 @@ // performance limitations when the reduced axis is long. Need to add // a optimized codepath for this. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -136,7 +137,7 @@ const createSoftmaxProgramInfo = (input: TensorView, attributes: SoftmaxAttribut getRunData: () => ({ outputs: [{dims: shape, dataType: input.dataType}], dispatchGroup: {x: rows}, - programUniforms: [{type: 'uint32', data: packedCols}] + programUniforms: [{type: DataType.uint32, data: packedCols}] }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index b8582614fa214..0b703de2ffa1c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -72,7 +73,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; let previousSum = 0; - const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}]; + const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: inputSize}]; for (let i = 0; i < attributes.numOutputs; i++) { previousSum += attributes.splitSizes[i]; sizeInSplitAxis[i] = previousSum; @@ -82,7 +83,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: 'uint32', data: sizeInSplitAxis}); + programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis}); programUniforms.push(...createTensorShapeVariables(inputShape)); outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 90a36a7bec2a9..b080767d2faac 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -80,7 +80,7 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) ], }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index ab9a9ac8dd1f0..920da04398832 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -65,7 +66,7 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: 'uint32', data: outputSize}, + {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), ], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 76929efb32537..1accfac18b876 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -53,7 +53,7 @@ const createElementwiseProgramInfo = dispatchGroup: {x: Math.ceil(ShapeUtil.size(inputTensors[0].dims) / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, + {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(input.dims) / 4)}, ], }) }); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 2ef9637bcda5e..51e8f56c229bd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -98,8 +98,9 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, programUniforms: [ - {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA), - ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC), + ...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB), + ...createTensorShapeVariables(outputShape) ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index a34b6190b7244..ba5b84fcfe067 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+import {DataType} from '../../wasm-common'; import {TensorView} from '../tensor-view'; import {ShaderHelper} from './ops/common'; @@ -26,7 +27,7 @@ export interface TensorInfo { } export interface ProgramUniform { - type: 'int32'|'float16'|'float32'|'uint32'; + type: DataType; data: number|readonly number[]; } From dd1f6ccc45e2cc852a5ff496019ce5d532855d76 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 30 Jan 2024 21:06:21 -0800 Subject: [PATCH 010/207] [js/webgpu] resolve codescan alert (#19343) ### Description resolve codescan alert: https://github.com/microsoft/onnxruntime/security/code-scanning/17687 --- js/web/lib/wasm/jsep/backend-webgpu.ts | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 58efa795dba48..4b544595d76bb 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -705,13 +705,11 @@ export class WebGpuBackend { captureBegin(): void { LOG_DEBUG('info', 'captureBegin'); - let sessionCommandList = this.capturedCommandList.get(this.currentSessionId!); - let sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); - if (!sessionCommandList) { - sessionCommandList = []; - this.capturedCommandList.set(this.currentSessionId!, sessionCommandList); - sessionPendingKernels = []; - this.capturedPendingKernels.set(this.currentSessionId!, sessionPendingKernels); + if (!this.capturedCommandList.get(this.currentSessionId!)) { + this.capturedCommandList.set(this.currentSessionId!, []); + } + if (!this.capturedPendingKernels.get(this.currentSessionId!)) { + this.capturedPendingKernels.set(this.currentSessionId!, []); } // flush the left commands before we change the status. this.flush(); From 4562c910fed0d8446497388415d103be40b65157 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Tue, 30 Jan 2024 21:53:18 -0800 Subject: [PATCH 011/207] Whisper Crash Fix (#19345) ### Description There is a current bug in the BeamSearch implementation of T5, GPT, and Whisper due to an interaction between two PRs merged in the past 7 months. First PR/code change is the addition of BeamSearchScorer GPU implementation. This PR accelerates some operations by executing them in the GPU and not the CPU. The approach for this code change didn't utilize a cudaStream when copying one particular variable from GPU to CPU (see nullptr value here: [[link](https://github.com/microsoft/onnxruntime/blob/b65d3d0a5374daa3bc9272c2c02763a8428660db/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h#L213)]). The second PR/code change was the alteration to utilize a cudaStream to initialize various memory buffers in BeamSearch (see `stream` included as the last argument in these allocations [[link](https://github.com/microsoft/onnxruntime/blob/d1431e1b78fb81bf90fdc58c9118cb011171f387/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h#L25)]). During the in-between period of these two PRs, I believe neither allocation utilized a stream and were thus synchronized. Once the latter PR was merged, the copy became desynchronized with the initialization due to different streams. The fix for this is to reintroduce the same stream into the copy operation added in the first PR. ### Motivation and Context This does not happen reliably on every hardware with every script due to the race condition nature, but the bug completely breaks ORT execution with a BeamSearch model. 
--------- Co-authored-by: Peter McAughan --- onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h | 2 +- onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h | 2 +- .../contrib_ops/cpu/transformers/beam_search_impl_whisper.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index dc72a038c3d58..b18e122980eda 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -258,7 +258,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index cd891a9508019..8f5cdc97f27e5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -214,7 +214,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 4d6643c68a98b..72e6d3930a548 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -226,7 +226,7 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe cpu_state.sequences.InitDevice(beam_state.sequences_device); ORT_RETURN_IF_ERROR(this->device_copy_int32_func_(beam_state.sequences_device.subspan(0, beam_state.sequences_device.size() / 2), cpu_state.sequences_space.subspan(0, cpu_state.sequences_space.size() / 2), - nullptr, + this->ort_stream_, DeviceCopyDirection::hostToDevice)); } From 3262e8df2f57c649ff0a75ccdc92e07ddaa8e4c4 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 30 Jan 2024 22:11:25 -0800 Subject: [PATCH 012/207] Introduce a Nominal Checkpoint for On-Device Training (#19232) --- .../core/flatbuffers/checkpoint_version.h | 4 +- .../ort_flatbuffers_py/fbs/ModuleState.py | 10 +- .../schema/ort_training_checkpoint.fbs | 4 + .../schema/ort_training_checkpoint.fbs.h | 20 +- .../core/graph/graph_flatbuffers_utils.cc | 8 + .../testdata/training_api/checkpoint.ckpt | Bin 1590376 -> 1590376 bytes .../training_api/custom_ops/checkpoint | Bin 456 -> 464 bytes .../testdata/training_api/nominal_checkpoint | Bin 0 -> 280 bytes .../training_api/ort_format/checkpoint | Bin 1608 -> 1632 bytes .../python/orttraining_pybind_state.cc | 14 +- .../python/training/api/checkpoint_state.py | 2 + .../orttraining/python/training/api/module.py | 3 + .../orttraining/python/training/artifacts.py | 66 ++-- 
.../training/onnxblock/checkpoint_utils.py | 15 +- .../orttraining_test_ort_apis_onnxblock.py | 23 ++ .../orttraining_test_ort_apis_py_bindings.py | 324 +++++++++-------- .../test/training_api/core/checkpoint_test.cc | 39 ++- .../training_api/core/training_api_tests.cc | 161 +++++++++ .../training_api/core/training_capi_tests.cc | 75 ++++ .../orttraining/training_api/checkpoint.cc | 33 +- .../orttraining/training_api/checkpoint.h | 3 +- .../include/onnxruntime_training_c_api.h | 5 +- .../include/onnxruntime_training_cxx_api.h | 5 + .../include/onnxruntime_training_cxx_inline.h | 17 +- .../orttraining/training_api/module.cc | 329 ++++++++++++------ orttraining/orttraining/training_api/module.h | 42 ++- .../onnxruntime_training_c_api.cc | 25 +- .../orttraining/training_api/optimizer.cc | 32 +- .../orttraining/training_api/optimizer.h | 14 +- .../training_api/training_session.cc | 11 +- 30 files changed, 973 insertions(+), 311 deletions(-) create mode 100644 onnxruntime/test/testdata/training_api/nominal_checkpoint diff --git a/onnxruntime/core/flatbuffers/checkpoint_version.h b/onnxruntime/core/flatbuffers/checkpoint_version.h index 6cad27c35024b..e6ee20bf508ce 100644 --- a/onnxruntime/core/flatbuffers/checkpoint_version.h +++ b/onnxruntime/core/flatbuffers/checkpoint_version.h @@ -13,7 +13,9 @@ namespace onnxruntime { // The format includes support for the ModuleState (stores the module parameters), OptimizerGroups // (stores the optimizer states), and PropertyBag // (stores custom user properties with support for int64, float and strings). -constexpr const int kCheckpointVersion = 1; +// Version 2: Introduces the On-Device Training nominal checkpoint state. +// Changes include the addition of the is_nominal_state field in the checkpoint's ModuleState. 
+constexpr const int kCheckpointVersion = 2; /** * @brief Check if the given checkpoint version is supported in this build diff --git a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py index 2be826fee2cc3..19c6b1b6f2753 100644 --- a/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py +++ b/onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/ModuleState.py @@ -74,9 +74,17 @@ def FrozenParamsIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6)) return o == 0 -def ModuleStateStart(builder): builder.StartObject(2) + # ModuleState + def IsNominalState(self): + o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8)) + if o != 0: + return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos)) + return False + +def ModuleStateStart(builder): builder.StartObject(3) def ModuleStateAddRequiresGradParams(builder, requiresGradParams): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(requiresGradParams), 0) def ModuleStateStartRequiresGradParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) def ModuleStateAddFrozenParams(builder, frozenParams): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(frozenParams), 0) def ModuleStateStartFrozenParamsVector(builder, numElems): return builder.StartVector(4, numElems, 4) +def ModuleStateAddIsNominalState(builder, isNominalState): builder.PrependBoolSlot(2, isNominalState, 0) def ModuleStateEnd(builder): return builder.EndObject() diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs index c8244b0a426f3..94757fa6d5bf5 100644 --- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs +++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs @@ -8,6 +8,10 @@ namespace onnxruntime.fbs; table ModuleState { requires_grad_params:[Tensor]; frozen_params:[Tensor]; + // Nominal state just means that the Tensors in the ModuleState + // are empty. i.e. The tensors are treated as named entities + // without any meaningful data. 
+ is_nominal_state:bool; } table ParameterOptimizerState { diff --git a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h index 48feebb197694..d205c5eb8f409 100644 --- a/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h +++ b/onnxruntime/core/flatbuffers/schema/ort_training_checkpoint.fbs.h @@ -39,7 +39,8 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ModuleStateBuilder Builder; enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { VT_REQUIRES_GRAD_PARAMS = 4, - VT_FROZEN_PARAMS = 6 + VT_FROZEN_PARAMS = 6, + VT_IS_NOMINAL_STATE = 8 }; const flatbuffers::Vector> *requires_grad_params() const { return GetPointer> *>(VT_REQUIRES_GRAD_PARAMS); @@ -47,6 +48,9 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const flatbuffers::Vector> *frozen_params() const { return GetPointer> *>(VT_FROZEN_PARAMS); } + bool is_nominal_state() const { + return GetField(VT_IS_NOMINAL_STATE, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_REQUIRES_GRAD_PARAMS) && @@ -55,6 +59,7 @@ struct ModuleState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyOffset(verifier, VT_FROZEN_PARAMS) && verifier.VerifyVector(frozen_params()) && verifier.VerifyVectorOfTables(frozen_params()) && + VerifyField(verifier, VT_IS_NOMINAL_STATE) && verifier.EndTable(); } }; @@ -69,6 +74,9 @@ struct ModuleStateBuilder { void add_frozen_params(flatbuffers::Offset>> frozen_params) { fbb_.AddOffset(ModuleState::VT_FROZEN_PARAMS, frozen_params); } + void add_is_nominal_state(bool is_nominal_state) { + fbb_.AddElement(ModuleState::VT_IS_NOMINAL_STATE, static_cast(is_nominal_state), 0); + } explicit ModuleStateBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -84,23 +92,27 @@ struct ModuleStateBuilder { inline flatbuffers::Offset CreateModuleState( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset>> requires_grad_params = 0, - flatbuffers::Offset>> frozen_params = 0) { + flatbuffers::Offset>> frozen_params = 0, + bool is_nominal_state = false) { ModuleStateBuilder builder_(_fbb); builder_.add_frozen_params(frozen_params); builder_.add_requires_grad_params(requires_grad_params); + builder_.add_is_nominal_state(is_nominal_state); return builder_.Finish(); } inline flatbuffers::Offset CreateModuleStateDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector> *requires_grad_params = nullptr, - const std::vector> *frozen_params = nullptr) { + const std::vector> *frozen_params = nullptr, + bool is_nominal_state = false) { auto requires_grad_params__ = requires_grad_params ? _fbb.CreateVector>(*requires_grad_params) : 0; auto frozen_params__ = frozen_params ? 
_fbb.CreateVector>(*frozen_params) : 0; return onnxruntime::fbs::CreateModuleState( _fbb, requires_grad_params__, - frozen_params__); + frozen_params__, + is_nominal_state); } struct ParameterOptimizerState FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { diff --git a/onnxruntime/core/graph/graph_flatbuffers_utils.cc b/onnxruntime/core/graph/graph_flatbuffers_utils.cc index 8e962403556dd..6d7ed94b2956d 100644 --- a/onnxruntime/core/graph/graph_flatbuffers_utils.cc +++ b/onnxruntime/core/graph/graph_flatbuffers_utils.cc @@ -392,6 +392,14 @@ Status LoadOrtTensorOrtFormat(const fbs::Tensor& fbs_tensor, const AllocatorPtr ort_tensor = onnxruntime::Tensor( tensor_dtype, TensorShape(tensor_dims->data(), tensor_dims->size()), allocator); + if (fbs_tensor.raw_data()->size() == 0U) { + // Empty tensor. Nothing to unpack. + // This check is necessary because an empty ort tensor will return a size of 1. + // As a result, the following call to UnpackTensor will fail since the src and + // dst sizes do not match (0 and 1 elements). + return Status::OK(); + } + // The tensor proto is used as a dummy here. The actual data is stored in the raw_data field of the flatbuffer. // The data is copied from the raw_data field to the ort_tensor. ONNX_NAMESPACE::TensorProto unused_tensor_proto; diff --git a/onnxruntime/test/testdata/training_api/checkpoint.ckpt b/onnxruntime/test/testdata/training_api/checkpoint.ckpt index d0b7d0deb654c52528df157aa4d4f8e01e2ea7e2..d1bc1f121c8e62e0a0372ffaa89b46721e8cff14 100644 GIT binary patch delta 86 zcmV~$%Mn9R06@|42@!<&M3!zZX4LUYvA8pqX+}@44Cll%#54TYBt4KsQpqHjhZIuE bQ!05$EseDDmQH&4$RMLkGRs$%i{HN89FP`o delta 86 zcmV~$$qj-~007Z^5J8YtS&mLGB-TMG947fknxQAk;JqtuSKN+Y2s;uakyIX%$y0JE bh!0_(h l|Nk3z@G;8s0p$gN0uViXKo*DxRK;Kd#5NOstta+K004Sm4F&)J delta 74 zcmcb>e1ch&hk=2?-zCHuNHKi*|Np-LkY=2yDlgdpWUv77G?*ZeZ34s&lWQ5RCvRbt QnH<9?voV5=k%@r;0ImKJL;wH) diff --git a/onnxruntime/test/testdata/training_api/nominal_checkpoint b/onnxruntime/test/testdata/training_api/nominal_checkpoint new file mode 100644 index 0000000000000000000000000000000000000000..2eadfeece2ed94623d746e7f09e4b9b47286af53 GIT binary patch literal 280 zcmWe(U|{ff32_Ee3>*x6K+FQf3=ASb%mfl-;9?M9-~o!U16g3e2$Bb35MTk~6+oN= z#1>F~+5i9l1%PZuAm#yL5E~hQ^rj^n=_O?*7K6l>5l|1Z7o@&CH8VYtY%LGEV(`VQh2n7femLF_Pu+W`QYV;m#^ literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint index ab35c9ad5acdee3b2249d8ea06edbd1c981aed45..83ef6aa4c30de0ae711650e69c71ee5e0a1d3473 100644 GIT binary patch delta 141 zcmX@X^MEH+gn@y<-zCJEfq{!bfPsgBgMkIeVgRxj7(jANKnwyLP&p1Tn*pd6B3HtU zV4q=SV0Z)Md;R+VA0*ER#5zE}!^XST%#%e}M3`3no#@uaV)XC-|IH07Cm1K^u*gg{ JVX*>Q0svVH7We=F delta 123 zcmaFBbArc>hk=2?-zCHuNHKi*|Np-LkY)s84h9||W&vUln}Go&&IH7Bkk}3^3=A)r w85pkpp1h5@O)%~cLc0!7*k)s@HS^>p%mOUCfCf*NVRM?y#U{1cgXI7t0D&YLvj6}9 diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 0c2bfa19e1671..4ab8db8565bf9 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -802,6 +802,9 @@ void addObjectMethodsForTraining(py::module& m) { .def("copy_parameter_from", [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name, OrtValue& value) -> void { + if (state->module_checkpoint_state.is_nominal_state) { + ORT_THROW("Cannot copy parameter to a nominal state. 
Please load all the parameter states first"); + } auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == state->module_checkpoint_state.named_parameters.end()) { ORT_THROW("Parameter with name ", parameter_name, " does not exist."); @@ -811,6 +814,9 @@ void addObjectMethodsForTraining(py::module& m) { }) .def("get_parameter", [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + if (state->module_checkpoint_state.is_nominal_state) { + ORT_THROW("Cannot get parameter from a nominal state. Please load the parameter states first"); + } auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == state->module_checkpoint_state.named_parameters.end()) { ORT_THROW("Parameter with name ", parameter_name, " does not exist."); @@ -851,6 +857,9 @@ void addObjectMethodsForTraining(py::module& m) { return std::make_unique(optimizer_model_uri, state, providers, session_options); })) .def("optimizer_step", [](PyOptimizer* optimizer) -> void { + // In case the optimizer was constructed using a nominal checkpoint, + // the optimizer state construction is delayed until the first call to Optimizer::Step(). + // It is expected that the model parameter state is available at this point. ORT_THROW_IF_ERROR(optimizer->optimizer_->Step()); }) .def("set_learning_rate", [](PyOptimizer* optimizer, float lr) -> void { @@ -893,7 +902,7 @@ void addObjectMethodsForTraining(py::module& m) { "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, const std::vector& non_trainable_tensor_protos_pybytes, - const std::string& checkpoint_path) { + const std::string& checkpoint_path, const bool nominal_checkpoint) { std::vector trainable_tensor_protos(trainable_tensor_protos_pybytes.size()); std::vector non_trainable_tensor_protos(non_trainable_tensor_protos_pybytes.size()); @@ -914,7 +923,8 @@ void addObjectMethodsForTraining(py::module& m) { ORT_THROW_IF_ERROR(onnxruntime::training::api::SaveCheckpoint(trainable_tensor_protos, non_trainable_tensor_protos, - ToPathString(checkpoint_path))); + ToPathString(checkpoint_path), + nominal_checkpoint)); }); m.def("save_checkpoint", diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index ba95cd04fce7e..cc4e84111c47c 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -222,6 +222,8 @@ def __init__(self, state: C.CheckpointState): def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: """Loads the checkpoint state from the checkpoint file + The checkpoint file can either be the complete checkpoint or the nominal checkpoint. + Args: checkpoint_uri: The path to the checkpoint file. diff --git a/orttraining/orttraining/python/training/api/module.py b/orttraining/orttraining/python/training/api/module.py index f8f6b4322ce79..a87cd6fdd93cf 100644 --- a/orttraining/orttraining/python/training/api/module.py +++ b/orttraining/orttraining/python/training/api/module.py @@ -178,6 +178,9 @@ def get_parameters_size(self, trainable_only: bool = True) -> int: def copy_buffer_to_parameters(self, buffer: OrtValue, trainable_only: bool = True) -> None: """Copies the OrtValue buffer to the training session parameters. 
+ In case the module was loaded from a nominal checkpoint, invoking this function is required + to load the updated parameters onto the checkpoint to complete it. + Args: buffer: The OrtValue buffer to copy to the training session parameters. """ diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index a57105545e114..7a4eb251bc5bc 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -43,7 +43,11 @@ def generate_artifacts( loss: Optional[Union[LossType, onnxblock.Block]] = None, optimizer: Optional[OptimType] = None, artifact_directory: Optional[Union[str, bytes, os.PathLike]] = None, - **extra_options, + prefix: str = "", + ort_format: bool = False, + custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, + additional_output_names: Optional[List[str]] = None, + nominal_checkpoint: bool = False, ) -> None: """Generates artifacts required for training with ORT training api. @@ -63,11 +67,16 @@ def generate_artifacts( optimizer: The optimizer enum to be used for training. If None, no optimizer model is generated. artifact_directory: The directory to save the generated artifacts. If None, the current working directory is used. - prefix (str): The prefix to be used for the generated artifacts. If not specified, no prefix is used. - ort_format (bool): Whether to save the generated artifacts in ORT format or not. Default is False. - custom_op_library (str | os.PathLike): The path to the custom op library. - If not specified, no custom op library is used. - additional_output_names (List[str]): List of additional output names to be added to the training/eval model. + prefix: The prefix to be used for the generated artifacts. If not specified, no prefix is used. + ort_format: Whether to save the generated artifacts in ORT format or not. Default is False. + custom_op_library: The path to the custom op library. + If not specified, no custom op library is used. + additional_output_names: List of additional output names to be added to the training/eval model in addition + to the loss output. Default is None. + nominal_checkpoint: Whether to generate the nominal checkpoint in addition to the complete checkpoint. + Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model + parameters. It can be used on the device to reduce overhead while constructing the training model + as well as to reduce the size of the checkpoint packaged with the on-device application. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` @@ -107,19 +116,19 @@ def __init__(self, _loss): self._loss = _loss def build(self, *inputs_to_loss): - if "additional_output_names" in extra_options: + if additional_output_names: # If additional output names is not a list, raise an error - if not isinstance(extra_options["additional_output_names"], list): + if not isinstance(additional_output_names, list): raise RuntimeError( - f"Unknown type provided for additional output names {type(extra_options['additional_output_names'])}. " + f"Unknown type provided for additional output names {type(additional_output_names)}. " "Expected additional output names to be a list of strings." 
) loss_output = self._loss(*inputs_to_loss) if isinstance(loss_output, tuple): - return (*loss_output, *tuple(extra_options["additional_output_names"])) + return (*loss_output, *tuple(additional_output_names)) else: - return (loss_output, *tuple(extra_options["additional_output_names"])) + return (loss_output, *tuple(additional_output_names)) return self._loss(*inputs_to_loss) @@ -143,58 +152,57 @@ def build(self, *inputs_to_loss): eval_model = None model_params = None - custom_op_library = extra_options.get("custom_op_library", None) + custom_op_library_path = None if custom_op_library is not None: logging.info("Custom op library provided: %s", custom_op_library) - custom_op_library = pathlib.Path(custom_op_library) + custom_op_library_path = pathlib.Path(custom_op_library) with onnxblock.base(model), onnxblock.custom_op_library( - custom_op_library + custom_op_library_path ) if custom_op_library is not None else contextlib.nullcontext(): _ = training_block(*[output.name for output in model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() - def _export_to_ort_format(model_path, output_dir, extra_options): - if extra_options.get("ort_format", False): - custom_op_library = extra_options.get("custom_op_library", None) - if custom_op_library is not None: - custom_op_library = pathlib.Path(custom_op_library) + def _export_to_ort_format(model_path, output_dir, ort_format, custom_op_library_path): + if ort_format: convert_onnx_models_to_ort( model_path, output_dir=output_dir, - custom_op_library_path=custom_op_library, + custom_op_library_path=custom_op_library_path, optimization_styles=[OptimizationStyle.Fixed], ) if artifact_directory is None: artifact_directory = pathlib.Path.cwd() - prefix = "" - if "prefix" in extra_options: - prefix = extra_options["prefix"] - logging.info("Using prefix %s for generated artifacts.", prefix) - artifact_directory = pathlib.Path(artifact_directory) + if prefix: + logging.info("Using prefix %s for generated artifacts.", prefix) + training_model_path = artifact_directory / f"{prefix}training_model.onnx" if os.path.exists(training_model_path): logging.info("Training model path %s already exists. Overwriting.", training_model_path) onnx.save(training_model, training_model_path) - _export_to_ort_format(training_model_path, artifact_directory, extra_options) + _export_to_ort_format(training_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved training model to %s", training_model_path) eval_model_path = artifact_directory / f"{prefix}eval_model.onnx" if os.path.exists(eval_model_path): logging.info("Eval model path %s already exists. Overwriting.", eval_model_path) onnx.save(eval_model, eval_model_path) - _export_to_ort_format(eval_model_path, artifact_directory, extra_options) + _export_to_ort_format(eval_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved eval model to %s", eval_model_path) checkpoint_path = artifact_directory / f"{prefix}checkpoint" if os.path.exists(checkpoint_path): logging.info("Checkpoint path %s already exists. 
Overwriting.", checkpoint_path) - onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path) + onnxblock.save_checkpoint(training_block.parameters(), checkpoint_path, nominal_checkpoint=False) logging.info("Saved checkpoint to %s", checkpoint_path) + if nominal_checkpoint: + nominal_checkpoint_path = artifact_directory / f"{prefix}nominal_checkpoint" + onnxblock.save_checkpoint(training_block.parameters(), nominal_checkpoint_path, nominal_checkpoint=True) + logging.info("Saved nominal checkpoint to %s", nominal_checkpoint_path) # If optimizer is not specified, skip creating the optimizer model if optimizer is None: @@ -225,5 +233,5 @@ def _export_to_ort_format(model_path, output_dir, extra_options): optimizer_model_path = artifact_directory / f"{prefix}optimizer_model.onnx" onnx.save(optim_model, optimizer_model_path) - _export_to_ort_format(optimizer_model_path, artifact_directory, extra_options) + _export_to_ort_format(optimizer_model_path, artifact_directory, ort_format, custom_op_library_path) logging.info("Saved optimizer model to %s", optimizer_model_path) diff --git a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py index bc50d4afa2fe1..de3453c630f9c 100644 --- a/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/checkpoint_utils.py @@ -6,18 +6,21 @@ import onnx -from onnxruntime.capi._pybind_state import get_model_after_loading_checkpoint as _internal_load_checkpoint_to_model -from onnxruntime.capi._pybind_state import save_checkpoint as _internal_save_checkpoint +from onnxruntime.capi._pybind_state import get_model_after_loading_checkpoint as _load_checkpoint_to_model +from onnxruntime.capi._pybind_state import save_checkpoint as _save_checkpoint def save_checkpoint( - parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], path_to_checkpoint: Union[str, os.PathLike] + parameters: Tuple[List[onnx.TensorProto], List[onnx.TensorProto]], + path_to_checkpoint: Union[str, os.PathLike], + nominal_checkpoint: bool = False, ) -> None: """Saves the parameters to the checkpoint directory path_to_checkpoint. Args: parameters tuple(trainable_params, non_trainable_params): The parameters to save to the checkpoint file. - path_to_checkpoint (str): The path to the checkpoint directory. + path_to_checkpoint: The path to the checkpoint directory. + nominal_checkpoint: If True, the checkpoint is saved as a nominal checkpoint. Default is False. """ if parameters is None: @@ -26,7 +29,7 @@ def save_checkpoint( trainable_params, non_trainable_params = parameters trainable_params = [param.SerializeToString() for param in trainable_params] non_trainable_params = [param.SerializeToString() for param in non_trainable_params] - _internal_save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint)) + _save_checkpoint(trainable_params, non_trainable_params, os.fspath(path_to_checkpoint), nominal_checkpoint) def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: onnx.ModelProto) -> None: @@ -37,4 +40,4 @@ def load_checkpoint_to_model(path_to_checkpoint: Union[str, os.PathLike], model: model (onnx.ModelProto): The model to load the checkpoint to. 
""" - model.ParseFromString(_internal_load_checkpoint_to_model(os.fspath(path_to_checkpoint), model.SerializeToString())) + model.ParseFromString(_load_checkpoint_to_model(os.fspath(path_to_checkpoint), model.SerializeToString())) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py index 910ddb34e2b52..3d41c8678278c 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_onnxblock.py @@ -1047,3 +1047,26 @@ def build(self, input1, input2): with tempfile.TemporaryDirectory() as temp_dir: artifacts.generate_artifacts(onnx_model, loss=CustomLossBlock(), artifact_directory=temp_dir) + + +def test_save_nominal_checkpoint(): + device = "cpu" + batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 + _, base_model = _get_models(device, batch_size, input_size, hidden_size, output_size) + + with tempfile.TemporaryDirectory() as temp_dir: + artifacts.generate_artifacts( + base_model, + requires_grad=["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"], + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=temp_dir, + nominal_checkpoint=True, + ) + + assert os.path.exists(os.path.join(temp_dir, "checkpoint")) + assert os.path.exists(os.path.join(temp_dir, "nominal_checkpoint")) + assert ( + os.stat(os.path.join(temp_dir, "checkpoint")).st_size + > os.stat(os.path.join(temp_dir, "nominal_checkpoint")).st_size + ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py index 34d8c24ccfab4..ce251b98447bf 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_ort_apis_py_bindings.py @@ -6,6 +6,7 @@ import os import pathlib import tempfile +from dataclasses import dataclass import numpy as np import onnx @@ -28,11 +29,22 @@ def build(self, output_name): return self.loss(output_name) +@dataclass +class Artifacts: + checkpoint_file_path: str + training_model_file_path: str + eval_model_file_path: str + optimizer_model_file_path: str + pt_model: torch.nn.Module + nominal_checkpoint_file_path: str | None = None + + def _create_training_artifacts( artifact_directory: str | os.PathLike, requires_grad: list[str] | None = None, frozen_params: list[str] | None = None, optimizer_type=artifacts.OptimType.AdamW, + nominal_checkpoint: bool = False, ): device = "cpu" batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10 @@ -51,14 +63,20 @@ def _create_training_artifacts( requires_grad=requires_grad, frozen_params=frozen_params, artifact_directory=artifact_directory, + nominal_checkpoint=nominal_checkpoint, ) training_model_file = os.path.join(artifact_directory, "training_model.onnx") eval_model_file = os.path.join(artifact_directory, "eval_model.onnx") optimizer_model_file = os.path.join(artifact_directory, "optimizer_model.onnx") checkpoint_file = os.path.join(artifact_directory, "checkpoint") + nominal_checkpoint_file = None + if nominal_checkpoint: + nominal_checkpoint_file = os.path.join(artifact_directory, "nominal_checkpoint") - return checkpoint_file, training_model_file, eval_model_file, optimizer_model_file, pt_model + return Artifacts( + checkpoint_file, training_model_file, eval_model_file, optimizer_model_file, 
pt_model, nominal_checkpoint_file + ) def test_train_step(): @@ -67,22 +85,16 @@ def test_train_step(): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - pt_model, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state) + model = Module(artifacts.training_model_file_path, state) model.train() ort_loss = model(inputs, labels) # Calculate loss using pytorch model to compare it with Module's output. - pt_outputs = pt_model(torch.from_numpy(inputs)) + pt_outputs = artifacts.pt_model(torch.from_numpy(inputs)) loss_fn = torch.nn.CrossEntropyLoss() pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) @@ -95,17 +107,11 @@ def test_eval_step(): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) model.train() model(inputs, labels) @@ -121,18 +127,12 @@ def test_optimizer_step(optimizer_type): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. - model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) model.train() old_flatten_params = model.get_contiguous_parameters() @@ -147,18 +147,12 @@ def test_optimizer_step(optimizer_type): @pytest.mark.parametrize("optimizer_type", [artifacts.OptimType.SGD, artifacts.OptimType.AdamW]) def test_get_and_set_lr(optimizer_type): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. 
- model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) # Test get and set learning rate. lr = optimizer.get_learning_rate() @@ -178,18 +172,11 @@ def test_scheduler_step(optimizer_type): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) - # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + artifacts = _create_training_artifacts(temp_dir, optimizer_type=optimizer_type) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. - model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) scheduler = LinearLRScheduler(optimizer, 1, 2, 0.2) # Test get and set learning rate. @@ -212,17 +199,11 @@ def test_training_module_checkpoint(): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Training Module and Training Optimizer. - model = Module(training_model_file_path, state) + model = Module(artifacts.training_model_file_path, state) model.train() model(inputs, labels) @@ -237,7 +218,7 @@ def test_training_module_checkpoint(): # Assert the checkpoint parameters remain after saving. new_state = CheckpointState.load_checkpoint(checkpoint_save_path) - new_model = Module(training_model_file_path, new_state) + new_model = Module(artifacts.training_model_file_path, new_state) new_params = new_model.get_contiguous_parameters() @@ -252,23 +233,17 @@ def test_copy_buffer_to_parameters(trainable_only, optimizer_type): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - optimizer_model_file_path, - _, - ) = _create_training_artifacts( + artifacts = _create_training_artifacts( temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"], optimizer_type=optimizer_type, ) - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module and Optimizer. - model = Module(training_model_file_path, state) - optimizer = Optimizer(optimizer_model_file_path, model) + model = Module(artifacts.training_model_file_path, state) + optimizer = Optimizer(artifacts.optimizer_model_file_path, model) # Keep a copy of the parameters. 
old_output_params = model.get_contiguous_parameters(trainable_only=trainable_only) @@ -295,19 +270,13 @@ def test_copy_buffer_to_parameters(trainable_only, optimizer_type): def test_export_model_for_inferencing(): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) # Export inference model inference_model_file_path = os.path.join(temp_dir, "inference_model.onnx") @@ -317,18 +286,12 @@ def test_export_model_for_inferencing(): def test_cuda_execution_provider(): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state, device="cuda") + model = Module(artifacts.training_model_file_path, state, device="cuda") params = model.get_contiguous_parameters() # Check if parameters are moved to cuda. @@ -341,19 +304,13 @@ def test_cuda_execution_provider(): ) def test_add_get_property(property_value): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - _ = Module(training_model_file_path, state) + _ = Module(artifacts.training_model_file_path, state) # Float values in python are double precision. # Convert to float32 to match the type of the property. @@ -367,8 +324,8 @@ def test_add_get_property(property_value): assert state.properties["property"] == property_value assert len(state.properties) == 1 - CheckpointState.save_checkpoint(state, checkpoint_file_path) - new_state = CheckpointState.load_checkpoint(checkpoint_file_path) + CheckpointState.save_checkpoint(state, artifacts.checkpoint_file_path) + new_state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) assert "property" in new_state.properties assert new_state.properties["property"] == property_value assert len(new_state.properties) == 1 @@ -376,21 +333,15 @@ def test_add_get_property(property_value): def test_get_input_output_names(): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. 
- model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) - training_model = onnx.load(training_model_file_path) + training_model = onnx.load(artifacts.training_model_file_path) assert model.input_names() == [input.name for input in training_model.graph.input][:2] assert model.output_names() == [output.name for output in training_model.graph.output][:1] @@ -518,23 +469,18 @@ def test_train_step_with_ort_values(): labels = OrtValue.ortvalue_from_numpy(labels_np) with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - _, - _, - pt_model, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) + # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. - model = Module(training_model_file_path, state) + model = Module(artifacts.training_model_file_path, state) model.train() ort_loss = model(inputs, labels) assert isinstance(ort_loss, OrtValue) # Calculate loss using pytorch model to compare it with Module's output. - pt_outputs = pt_model(torch.from_numpy(inputs_np)) + pt_outputs = artifacts.pt_model(torch.from_numpy(inputs_np)) loss_fn = torch.nn.CrossEntropyLoss() pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels_np).long()) @@ -549,17 +495,11 @@ def test_eval_step_with_ort_values(): labels = OrtValue.ortvalue_from_numpy(labels_np) with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - _, - ) = _create_training_artifacts(temp_dir) + artifacts = _create_training_artifacts(temp_dir) # Create Checkpoint State. - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) # Create a Module. 
- model = Module(training_model_file_path, state, eval_model_file_path) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path) model.train() model(inputs, labels) @@ -572,26 +512,20 @@ def test_eval_step_with_ort_values(): @pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_get_and_set_parameter_values(device): with tempfile.TemporaryDirectory() as temp_dir: - ( - checkpoint_file_path, - training_model_file_path, - eval_model_file_path, - _, - pt_model, - ) = _create_training_artifacts( + artifacts = _create_training_artifacts( temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] ) - state = CheckpointState.load_checkpoint(checkpoint_file_path) + state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) - model = Module(training_model_file_path, state, eval_model_file_path, device=device) + model = Module(artifacts.training_model_file_path, state, artifacts.eval_model_file_path, device=device) - state_dict = pt_model.state_dict() + state_dict = artifacts.pt_model.state_dict() assert len(state_dict) == len(state.parameters) for parameter_name, _ in state.parameters: assert parameter_name in state_dict - for name, pt_param in pt_model.named_parameters(): + for name, pt_param in artifacts.pt_model.named_parameters(): ort_param = state.parameters[name] assert ort_param.name == name assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) @@ -612,7 +546,7 @@ def test_get_and_set_parameter_values(device): labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() loss = model(inputs, labels) assert loss is not None - for name, _ in pt_model.named_parameters(): + for name, _ in artifacts.pt_model.named_parameters(): ort_param = state.parameters[name] assert ort_param.name == name if name in ["fc1.weight", "fc1.bias"]: @@ -624,3 +558,111 @@ def test_get_and_set_parameter_values(device): state.parameters["fc1.weight"] = original_param assert np.allclose(state.parameters["fc1.weight"].data, original_param) + + +def test_model_construction_with_nominal_checkpoint(): + with tempfile.TemporaryDirectory() as temp_dir: + artifacts = _create_training_artifacts(temp_dir, nominal_checkpoint=True) + + nominal_state = CheckpointState.load_checkpoint(artifacts.nominal_checkpoint_file_path) + model_with_nominal_state = Module( + artifacts.training_model_file_path, nominal_state, artifacts.eval_model_file_path + ) + optimizer_with_nominal_state = Optimizer(artifacts.optimizer_model_file_path, model_with_nominal_state) + + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + err_msg = "Please load the parameter states first" + + # Accessing the checkpoint parameter raises + state_dict = artifacts.pt_model.state_dict() + for param_name in state_dict: + assert param_name in nominal_state.parameters + with pytest.raises(Exception) as exc_info: + _ = nominal_state.parameters["fc1.weight"] + + assert err_msg in str(exc_info.value) + + err_msg = "Please load all the parameter states first" + with pytest.raises(Exception) as exc_info: + nominal_state.parameters["fc1.weight"] = np.ones((10, 10), dtype=np.float32) + + assert err_msg in str(exc_info.value) + + err_msg = "Please load the model parameters first." 
+ + # Getting contiguous parameters raises + with pytest.raises(Exception) as exc_info: + _ = model_with_nominal_state.get_contiguous_parameters() + + assert err_msg in str(exc_info.value) + + # Train step raises + with pytest.raises(Exception) as exc_info: + model_with_nominal_state.train() + model_with_nominal_state(inputs, labels) + + assert err_msg in str(exc_info.value) + + # Optimizer step raises + with pytest.raises(Exception) as exc_info: + optimizer_with_nominal_state.step() + + assert err_msg in str(exc_info.value) + + # Eval step raises + with pytest.raises(Exception) as exc_info: + model_with_nominal_state.eval() + model_with_nominal_state(inputs, labels) + + assert err_msg in str(exc_info.value) + + # Get parameters size does not raise + params_size = model_with_nominal_state.get_parameters_size() + assert params_size > 0 + + +def test_train_with_nominal_checkpoint(): + with tempfile.TemporaryDirectory() as temp_dir: + artifacts = _create_training_artifacts(temp_dir, nominal_checkpoint=True) + + # Create Checkpoint State with nominal checkpoint as well as the complete checkpoint. + complete_state = CheckpointState.load_checkpoint(artifacts.checkpoint_file_path) + nominal_state = CheckpointState.load_checkpoint(artifacts.nominal_checkpoint_file_path) + + # Create a Module with both complete and nominal checkpoint states. + model_with_complete_state = Module(artifacts.training_model_file_path, complete_state) + model_with_nominal_state = Module(artifacts.training_model_file_path, nominal_state) + + optimizer_with_complete_state = Optimizer(artifacts.optimizer_model_file_path, model_with_complete_state) + optimizer_with_nominal_state = Optimizer(artifacts.optimizer_model_file_path, model_with_nominal_state) + + parameter_buffer = model_with_complete_state.get_contiguous_parameters() + model_with_nominal_state.copy_buffer_to_parameters(parameter_buffer, trainable_only=False) + + model_with_complete_state.train() + model_with_nominal_state.train() + + # Generate random data for testing. + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + + ort_loss_1 = model_with_complete_state(inputs, labels) + ort_loss_2 = model_with_nominal_state(inputs, labels) + + # Calculate loss using pytorch model to compare it with both the Modules' output. + pt_outputs = artifacts.pt_model(torch.from_numpy(inputs)) + loss_fn = torch.nn.CrossEntropyLoss() + pt_loss = loss_fn(pt_outputs, torch.from_numpy(labels).long()) + + assert np.allclose(ort_loss_1, ort_loss_2) + assert np.allclose(ort_loss_1, pt_loss.detach().numpy()) + + optimizer_with_complete_state.step() + optimizer_with_nominal_state.step() + + new_params_1 = model_with_complete_state.get_contiguous_parameters() + new_params_2 = model_with_nominal_state.get_contiguous_parameters() + + assert np.allclose(new_params_1.numpy(), new_params_2.numpy()) diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 1369c9c69865a..5c53addb853e4 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -95,7 +95,8 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpoint_ThenLoad_CPU) { // Call Save APIs. 
PathString checkpoint_path{ ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))}; - ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path)); + ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path, + false /* nominal checkpoint */)); /// Phase 3 - Run load checkpoint APIs. /// And check the result comparable with initial parameter values. @@ -193,7 +194,8 @@ TEST(CheckpointApiTest, SaveOnnxModelAsCheckpointThenLoadFromBufferCPU) { // Call Save APIs. PathString checkpoint_path{ ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("e2e_ckpt_save_cpu"))}; - ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path)); + ASSERT_STATUS_OK(SaveCheckpoint(trainable_param_values, non_trainable_param_values, checkpoint_path, + false /* nominal checkpoint */)); /// Phase 3 - Run load checkpoint APIs. /// And check the result comparable with initial parameter values. @@ -435,4 +437,37 @@ TEST(CheckpointApiTest, SaveCustomPropertyAsCheckpoint_ThenLoad_CPU) { std::string restored_s_data = restored_property_bag.GetProperty(s_property_name); ASSERT_EQ(s_data, restored_s_data); } + +/** + * Loads a nominal checkpoint. Checks for nominal flag, and that the state is empty. + * Saves the checkpoint, and loads it again. Checks for nominal flag, and that the state is empty. + */ +TEST(CheckpointApiTest, LoadAndSaveNominalCheckpoint) { + PathString nominal_checkpoint_path{ORT_TSTR("testdata/training_api/nominal_checkpoint")}; + + CheckpointState checkpoint_state; + ASSERT_STATUS_OK(LoadCheckpoint(nominal_checkpoint_path, checkpoint_state)); + ASSERT_TRUE(checkpoint_state.module_checkpoint_state.is_nominal_state); + for (auto& [name, param] : checkpoint_state.module_checkpoint_state.named_parameters) { + ASSERT_TRUE(param->Data().IsTensor()); + // An empty tensor will have size 1. + ASSERT_EQ(param->Data().Get().Shape().Size(), 1); + } + + // Remove the temporary directory if it already exists. + auto ckpt_test_root_dir = ORT_TSTR("checkpointing_api_test_dir"); + TemporaryDirectory tmp_dir{ckpt_test_root_dir}; + PathString checkpoint_path{ + ConcatPathComponent(tmp_dir.Path(), ORT_TSTR("nominal_checkpoint_2"))}; + ASSERT_STATUS_OK(SaveCheckpoint(checkpoint_state, checkpoint_path, false)); + + CheckpointState checkpoint_state_2; + ASSERT_STATUS_OK(LoadCheckpoint(checkpoint_path, checkpoint_state_2)); + ASSERT_TRUE(checkpoint_state_2.module_checkpoint_state.is_nominal_state); + for (auto& [name, param] : checkpoint_state_2.module_checkpoint_state.named_parameters) { + ASSERT_TRUE(param->Data().IsTensor()); + // An empty tensor will have size 1. 
+ ASSERT_EQ(param->Data().Get().Shape().Size(), 1); + } +} } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index 2170f7957e6a6..e2232687d0b07 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -537,6 +537,167 @@ TEST(TrainingApiTest, OptimStep) { } } +TEST(TrainingApiTest, ModuleAndOptimizerWithNominalState) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; + auto optim_uri = MODEL_FOLDER "adamw.onnx"; + + onnxruntime::training::api::CheckpointState complete_state; + onnxruntime::training::api::CheckpointState nominal_state; + auto complete_checkpoint_path = MODEL_FOLDER "checkpoint.ckpt"; + auto nominal_checkpoint_path = MODEL_FOLDER "nominal_checkpoint"; + ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(complete_checkpoint_path, complete_state)); + ASSERT_STATUS_OK(onnxruntime::training::api::LoadCheckpoint(nominal_checkpoint_path, nominal_state)); + + ASSERT_FALSE(complete_state.module_checkpoint_state.is_nominal_state); + ASSERT_TRUE(nominal_state.module_checkpoint_state.is_nominal_state); + + onnxruntime::SessionOptions session_option; + std::unique_ptr env; + std::vector> providers; +#if defined(USE_CUDA) + providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); +#endif + ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::optional(onnxruntime::ToUTF8String(eval_model_uri)), + std::optional(onnxruntime::ToUTF8String(optim_uri))); + auto model_with_complete_state = std::make_unique( + model_identifier, &complete_state, session_option, + *env, providers); + auto model_with_nominal_state = std::make_unique( + model_identifier, &nominal_state, session_option, + *env, providers); + auto optim_with_complete_state = std::make_unique( + model_identifier, &complete_state, session_option, + *env, providers); + auto optim_with_nominal_state = std::make_unique( + model_identifier, &nominal_state, session_option, + *env, providers); + + // Before running the test, copy all the parameters to the nominal module. 
+ ASSERT_EQ(model_with_complete_state->GetParametersSize(), model_with_nominal_state->GetParametersSize()); + int64_t params_size = static_cast(model_with_nominal_state->GetParametersSize()); + OrtValue params_buffer; + Tensor::InitOrtValue(DataTypeImpl::GetType(), {params_size}, + onnxruntime::test::TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + params_buffer); + ASSERT_STATUS_OK(model_with_complete_state->CopyParametersToBuffer(params_buffer, false)); + ASSERT_STATUS_OK(model_with_nominal_state->CopyBufferToParameters(params_buffer, false)); + + ASSERT_STATUS_OK(optim_with_nominal_state->ConstructOptimizerStateAndInputs()); + + OrtValue input, target; + GenerateRandomInput(std::array{2, 784}, input); + target = onnxruntime::test::CreateInputOrtValueOnCPU( + std::array{2}, std::vector(2, 1)); + auto data_loader = std::vector>(4, std::vector{input, target}); + + for (auto it = data_loader.begin(); it != data_loader.end(); ++it) { + std::vector& inputs = *it; + std::vector complete_fetches; + std::vector nominal_fetches; + ASSERT_STATUS_OK(model_with_complete_state->TrainStep(inputs, complete_fetches)); + ASSERT_STATUS_OK(model_with_nominal_state->TrainStep(inputs, nominal_fetches)); + + ASSERT_GT(complete_fetches.size(), 0); + for (size_t i = 0; i < complete_fetches.size(); ++i) { + ASSERT_TRUE(complete_fetches[i].IsTensor()); + ASSERT_TRUE(nominal_fetches[i].IsTensor()); + const Tensor& complete_tensor = complete_fetches[i].Get(); + const Tensor& nominal_tensor = nominal_fetches[i].Get(); + ASSERT_EQ(complete_tensor.Shape(), nominal_tensor.Shape()); + ASSERT_EQ(complete_tensor.DataType(), nominal_tensor.DataType()); + + std::vector complete_fetches_vec; + std::vector nominal_fetches_vec; +#if defined(USE_CUDA) + CudaOrtValueToCpuVec(complete_fetches[i], complete_fetches_vec); + CudaOrtValueToCpuVec(nominal_fetches[i], nominal_fetches_vec); +#else + CpuOrtValueToVec(complete_fetches[i], complete_fetches_vec); + CpuOrtValueToVec(nominal_fetches[i], nominal_fetches_vec); +#endif + + for (size_t j = 0; j < complete_fetches_vec.size(); ++j) { + ASSERT_EQ(complete_fetches_vec[j], nominal_fetches_vec[j]); + } + } + + ASSERT_STATUS_OK(optim_with_complete_state->Step()); + ASSERT_STATUS_OK(optim_with_nominal_state->Step()); + + for (auto& [name, param] : model_with_complete_state->NamedParameters()) { + ASSERT_TRUE(param->Data().IsTensor()); + ASSERT_TRUE(param->Gradient().IsTensor()); + ASSERT_TRUE(model_with_nominal_state->NamedParameters().at(name)->Data().IsTensor()); + ASSERT_TRUE(model_with_nominal_state->NamedParameters().at(name)->Gradient().IsTensor()); + + const Tensor& complete_data = param->Data().Get(); + const Tensor& complete_grad = param->Gradient().Get(); + const Tensor& nominal_data = model_with_nominal_state->NamedParameters().at(name)->Data().Get(); + const Tensor& nominal_grad = model_with_nominal_state->NamedParameters().at(name)->Gradient().Get(); + + ASSERT_EQ(complete_data.Shape(), nominal_data.Shape()); + ASSERT_EQ(complete_data.DataType(), nominal_data.DataType()); + ASSERT_EQ(complete_grad.Shape(), nominal_grad.Shape()); + ASSERT_EQ(complete_grad.DataType(), nominal_grad.DataType()); + + std::vector complete_data_vec; + std::vector complete_grad_vec; + std::vector nominal_data_vec; + std::vector nominal_grad_vec; + +#if defined(USE_CUDA) + CudaOrtValueToCpuVec(param->Data(), complete_data_vec); + CudaOrtValueToCpuVec(param->Gradient(), complete_grad_vec); + CudaOrtValueToCpuVec(model_with_nominal_state->NamedParameters().at(name)->Data(), 
nominal_data_vec); + CudaOrtValueToCpuVec(model_with_nominal_state->NamedParameters().at(name)->Gradient(), nominal_grad_vec); +#else + CpuOrtValueToVec(param->Data(), complete_data_vec); + CpuOrtValueToVec(param->Gradient(), complete_grad_vec); + CpuOrtValueToVec(model_with_nominal_state->NamedParameters().at(name)->Data(), nominal_data_vec); + CpuOrtValueToVec(model_with_nominal_state->NamedParameters().at(name)->Gradient(), nominal_grad_vec); +#endif + + for (size_t j = 0; j < complete_data_vec.size(); ++j) { + ASSERT_EQ(complete_data_vec[j], nominal_data_vec[j]); + ASSERT_EQ(complete_grad_vec[j], nominal_grad_vec[j]); + } + } + + std::vector complete_eval_fetches; + std::vector nominal_eval_fetches; + ASSERT_STATUS_OK(model_with_complete_state->EvalStep(inputs, complete_eval_fetches)); + ASSERT_STATUS_OK(model_with_nominal_state->EvalStep(inputs, nominal_eval_fetches)); + + ASSERT_GT(complete_eval_fetches.size(), 0); + for (size_t i = 0; i < complete_eval_fetches.size(); ++i) { + ASSERT_TRUE(complete_eval_fetches[i].IsTensor()); + ASSERT_TRUE(nominal_eval_fetches[i].IsTensor()); + const Tensor& complete_tensor = complete_eval_fetches[i].Get(); + const Tensor& nominal_tensor = nominal_eval_fetches[i].Get(); + ASSERT_EQ(complete_tensor.Shape(), nominal_tensor.Shape()); + ASSERT_EQ(complete_tensor.DataType(), nominal_tensor.DataType()); + + std::vector complete_eval_fetches_vec; + std::vector nominal_eval_fetches_vec; +#if defined(USE_CUDA) + CudaOrtValueToCpuVec(complete_eval_fetches[i], complete_eval_fetches_vec); + CudaOrtValueToCpuVec(nominal_eval_fetches[i], nominal_eval_fetches_vec); +#else + CpuOrtValueToVec(complete_eval_fetches[i], complete_eval_fetches_vec); + CpuOrtValueToVec(nominal_eval_fetches[i], nominal_eval_fetches_vec); +#endif + + for (size_t j = 0; j < complete_eval_fetches_vec.size(); ++j) { + ASSERT_EQ(complete_eval_fetches_vec[j], nominal_eval_fetches_vec[j]); + } + } + } +} + } // namespace test } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index e46952d87c2bf..8f25e1e4c92b8 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -420,4 +420,79 @@ TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { } #endif +TEST(TrainingCApiTest, ModuleAndOptimizerWithNominalState) { + auto training_model_uri = MODEL_FOLDER "training_model.onnx"; + auto eval_model_uri = MODEL_FOLDER "eval_model.onnx"; + auto optimizer_model_uri = MODEL_FOLDER "adamw.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options_for_complete_state; + Ort::SessionOptions session_options_for_nominal_state; + Ort::CheckpointState complete_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::CheckpointState nominal_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "nominal_checkpoint"); + +#ifdef USE_CUDA + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_for_complete_state, 0)); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options_for_nominal_state, 0)); +#endif + + Ort::TrainingSession complete_training_session = Ort::TrainingSession(env, session_options_for_complete_state, complete_state, + training_model_uri, eval_model_uri, optimizer_model_uri); + Ort::TrainingSession nominal_training_session = Ort::TrainingSession(env, 
session_options_for_nominal_state, nominal_state, + training_model_uri, eval_model_uri, + optimizer_model_uri); + + Ort::Value params_buffer = complete_training_session.ToBuffer(false); + nominal_training_session.FromBuffer(params_buffer); + + for (size_t i = 0; i < 4U; ++i) { + std::vector x(2 * 784); + std::vector x_shape{2, 784}; + GenerateRandomData(x); + + std::vector labels{0, 8}; + std::vector labels_shape{2}; + + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + std::vector ort_inputs; + ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, x.data(), + x.size() * sizeof(float), + x_shape.data(), x_shape.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)); + ort_inputs.emplace_back(Ort::Value::CreateTensor(memory_info, labels.data(), + labels.size() * sizeof(int32_t), + labels_shape.data(), labels_shape.size(), + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)); + + std::vector complete_fetches = complete_training_session.TrainStep(ort_inputs); + std::vector nominal_fetches = nominal_training_session.TrainStep(ort_inputs); + + ASSERT_EQ(complete_fetches.size(), nominal_fetches.size()); + ASSERT_GT(complete_fetches.size(), 0U); + for (size_t j = 0; j < complete_fetches.size(); ++j) { + ASSERT_TRUE(complete_fetches[j].IsTensor()); + ASSERT_TRUE(nominal_fetches[j].IsTensor()); + + auto complete_tensor_info = complete_fetches[j].GetTensorTypeAndShapeInfo(); + auto nominal_tensor_info = nominal_fetches[j].GetTensorTypeAndShapeInfo(); + + ASSERT_EQ(complete_tensor_info.GetShape(), nominal_tensor_info.GetShape()); + ASSERT_EQ(complete_tensor_info.GetElementType(), nominal_tensor_info.GetElementType()); + + gsl::span complete_data = gsl::span(complete_fetches[j].GetTensorMutableData(), + complete_tensor_info.GetElementCount()); + gsl::span nominal_data = gsl::span(nominal_fetches[j].GetTensorMutableData(), + nominal_tensor_info.GetElementCount()); + + ASSERT_EQ(complete_data, nominal_data); + } + + complete_training_session.OptimizerStep(); + nominal_training_session.OptimizerStep(); + + complete_training_session.LazyResetGrad(); + nominal_training_session.LazyResetGrad(); + } +} + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/checkpoint.cc b/orttraining/orttraining/training_api/checkpoint.cc index dbcef78c3965c..720bdd7e68dd3 100644 --- a/orttraining/orttraining/training_api/checkpoint.cc +++ b/orttraining/orttraining/training_api/checkpoint.cc @@ -174,7 +174,7 @@ Status ToFile(const PathString& checkpoint_path, flatbuffers::FlatBufferBuilder& Status FromTensorProtos( gsl::span trainable_tensor_protos, gsl::span non_trainable_tensor_protos, - const PathString& checkpoint_path) { + const PathString& checkpoint_path, const bool nominal_checkpoint) { const auto check_unique = [](gsl::span tensor_protos, InlinedHashSet& unique_names) { for (const auto& tensor_proto : tensor_protos) { @@ -230,6 +230,7 @@ Status FromTensorProtos( fbs::ModuleStateBuilder module_state_builder(builder); module_state_builder.add_requires_grad_params(fbs_trainable_tensors); module_state_builder.add_frozen_params(fbs_non_trainable_tensors); + module_state_builder.add_is_nominal_state(nominal_checkpoint); flatbuffers::Offset fbs_module_state = module_state_builder.Finish(); fbs::CheckpointBuilder checkpoint_builder(builder); @@ -294,6 +295,7 @@ Status FromModuleState(const ModuleCheckpointState& module_state, fbs::ModuleStateBuilder module_state_builder(builder); module_state_builder.add_requires_grad_params(fbs_trainable_tensors); 
module_state_builder.add_frozen_params(fbs_non_trainable_tensors); + module_state_builder.add_is_nominal_state(module_state.is_nominal_state); fbs_module_state = module_state_builder.Finish(); return Status::OK(); @@ -513,6 +515,8 @@ Status ToModuleState( module_state.named_parameters.insert({name, param}); } + module_state.is_nominal_state = fbs_module_state.is_nominal_state(); + return Status::OK(); } @@ -646,6 +650,10 @@ Status ToModelProto(gsl::span checkpoint_bytes, ORT_RETURN_IF_NOT(frozen_params, "Checkpoint is invalid. Expected: Valid non-trainable params flatbuffer. Actual: nullptr."); + ORT_RETURN_IF(module_state->is_nominal_state(), + "Cannot load a nominal checkpoint to a model proto. " + "Expected: Complete checkpoint. Actual: Nominal checkpoint."); + InlinedHashMap param_tensor_protos; param_tensor_protos.reserve( static_cast(requires_grad_params->size()) + static_cast(frozen_params->size())); @@ -717,14 +725,33 @@ Status ToCheckpointState(gsl::span checkpoint_bytes, CheckpointSt } // namespace load +#if !defined(ORT_MINIMAL_BUILD) +InlinedVector Nominalize(gsl::span tensor_protos) { + InlinedVector nominal_tensor_protos; + nominal_tensor_protos.reserve(tensor_protos.size()); + for (const auto& tensor_proto : tensor_protos) { + ONNX_NAMESPACE::TensorProto nominal_tensor_proto; + nominal_tensor_proto.set_name(tensor_proto.name()); + nominal_tensor_proto.set_data_type(tensor_proto.data_type()); + nominal_tensor_protos.push_back(nominal_tensor_proto); + } + + return nominal_tensor_protos; +} +#endif + } // namespace #if !defined(ORT_MINIMAL_BUILD) Status SaveCheckpoint(gsl::span trainable_tensor_protos, gsl::span non_trainable_tensor_protos, - const PathString& checkpoint_path) { + const PathString& checkpoint_path, const bool nominal_checkpoint) { ORT_RETURN_IF_NOT(FLATBUFFERS_LITTLEENDIAN, "ORT training checkpoint format only supports little-endian machines"); - return save::FromTensorProtos(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path); + return nominal_checkpoint + ? save::FromTensorProtos(Nominalize(trainable_tensor_protos), Nominalize(non_trainable_tensor_protos), + checkpoint_path, nominal_checkpoint) + : save::FromTensorProtos(trainable_tensor_protos, non_trainable_tensor_protos, checkpoint_path, + nominal_checkpoint); } #endif diff --git a/orttraining/orttraining/training_api/checkpoint.h b/orttraining/orttraining/training_api/checkpoint.h index 5d8554662f48d..95d3820a33a70 100644 --- a/orttraining/orttraining/training_api/checkpoint.h +++ b/orttraining/orttraining/training_api/checkpoint.h @@ -49,11 +49,12 @@ Status SaveCheckpoint(const CheckpointState& state, const PathString& checkpoint * @param trainable_tensor_protos trainable parameters in TensorProto format. * @param non_trainable_tensor_protos non-trainable parameters in TensorProto format. * @param checkpoint_path file where checkpoint is saved. + * @param nominal_checkpoint flag indicating whether to save the complete checkpoint or the nominal checkpoint. 
* @return Status */ Status SaveCheckpoint(gsl::span trainable_tensor_protos, gsl::span non_trainable_tensor_protos, - const PathString& checkpoint_path); + const PathString& checkpoint_path, const bool nominal_checkpoint); #endif /** diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 0e8544a7639ba..ed6d151a595b4 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -132,6 +132,7 @@ struct OrtTrainingApi { * \note Note that the training session created with a checkpoint state uses this state to store the entire * training state (including model parameters, its gradients, the optimizer states and the properties). * As a result, it is required that the checkpoint state outlive the lifetime of the training session. + * \note Note that the checkpoint file can be either the complete checkpoint or the nominal checkpoint. * * \param[in] checkpoint_path Path to the checkpoint file * \param[out] checkpoint_state Checkpoint state that contains the states of the training session. @@ -463,10 +464,12 @@ struct OrtTrainingApi { * * The parameters_buffer argument has to be of the size given by OrtTrainingApi::GetParametersSize api call, * with matching setting for trainable_only argument. All the target parameters must be of the same - * datatype. This is a complementary function to OrtTrainingApi::CopyBufferToParameters + * datatype. This is a complementary function to OrtTrainingApi::CopyParametersToBuffer * and can be used to load updated buffer values onto the training state. * Parameter ordering is preserved. * User is responsible for allocating and freeing the resources used by the parameters_buffer. + * In case the training session was created with a nominal checkpoint, invoking this function is required + * to load the updated parameters onto the checkpoint to complete it. * * \param[in] sess The `this` pointer to the training session. * \param[in] trainable_only Whether to skip non-trainable parameters diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 218bef524200c..e78c16136ab3f 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -58,6 +58,8 @@ using Property = std::variant; * training state (including model parameters, its gradients, the optimizer states and the properties). * The Ort::TrainingSession does not hold a copy of the Ort::CheckpointState and as a result, it is required * that the checkpoint state outlive the lifetime of the training session. + * \note Note that the checkpoint state can be either the complete checkpoint state or the nominal checkpoint + * state depending on the version provided while loading the checkpoint. * */ class CheckpointState : public detail::Base { @@ -386,6 +388,9 @@ class TrainingSession : public detail::Base { Value ToBuffer(const bool only_trainable); /** \brief Loads the training session model parameters from a contiguous buffer + * + * In case the training session was created with a nominal checkpoint, invoking this function is required + * to load the updated parameters onto the checkpoint to complete it. * * \param[in] buffer Contiguous buffer to load the parameters from. 
*/ diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 7d1326a10f8f8..397cba0b0f9de 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -168,22 +168,23 @@ inline void TrainingSession::FromBuffer(Value& buffer) { auto buffer_size = buffer_shape.front(); + size_t session_buffer_size = 0U; + ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false)); + + if (buffer_size == static_cast(session_buffer_size)) { + ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false)); + return; + } + size_t session_buffer_size_trainable_only = 0U; ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size_trainable_only, true)); if (buffer_size == static_cast(session_buffer_size_trainable_only)) { ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, true)); return; - } - - size_t session_buffer_size = 0U; - ThrowOnError(GetTrainingApi().GetParametersSize(p_, &session_buffer_size, false)); - - if (buffer_size != static_cast(session_buffer_size)) { + } else { ThrowStatus(Status("Incorrect buffer size received.", OrtErrorCode::ORT_INVALID_ARGUMENT)); } - - ThrowOnError(GetTrainingApi().CopyBufferToParameters(p_, buffer, false)); } inline CheckpointState CheckpointState::LoadCheckpoint(const std::basic_string& path_to_checkpoint) { diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index cf49a01517d6b..41ed79d285533 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -6,6 +6,8 @@ #include "core/common/safeint.h" #include "core/common/string_utils.h" #include "core/framework/execution_provider.h" +#include "core/framework/mldata_type_utils.h" +#include "core/framework/tensorprotoutils.h" #include "core/session/inference_session.h" #include "core/session/environment.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -117,6 +119,75 @@ Status TransformModelInputsForInference(Graph& inference_graph, return Status::OK(); } #endif + +InlinedHashMap BuildParameterToInputNodeArgMap(const ModuleCheckpointState& state, + const InputDefList* model_inputs) { + ORT_ENFORCE(model_inputs != nullptr, "Model inputs are not defined."); + InlinedHashMap parameter_to_input_node_arg_map; + parameter_to_input_node_arg_map.reserve(state.named_parameters.size()); + for (const auto& input_def : *model_inputs) { + const std::string& input_name = input_def->Name(); + const auto param_it = state.named_parameters.find(input_name); + if (param_it == state.named_parameters.end()) { + continue; + } + parameter_to_input_node_arg_map[input_name] = input_def; + } + return parameter_to_input_node_arg_map; +} + +InlinedHashMap BuildParameterToGradInputIndexMap(gsl::span grad_names) { + InlinedHashMap param_name_to_grad_input_index_map; + param_name_to_grad_input_index_map.reserve(grad_names.size()); + for (size_t i = 0; i < grad_names.size(); ++i) { + std::string param_name; + utils::GetParamNameFromGradient(grad_names[i], param_name); + param_name_to_grad_input_index_map.insert({param_name, i}); + } + return param_name_to_grad_input_index_map; +} + +Status LoadParameter(const std::string& param_name, const Tensor& src_weight_tensor, + const SessionState& session_state, const bool force_load, + 
const InlinedHashMap& param_to_grad_index, + gsl::span grad_names, Parameter& param) { + InlinedVector node_info_vec; + ORT_THROW_IF_ERROR(session_state.GetInputNodeInfo(param_name, node_info_vec)); + const auto& node_info = node_info_vec.front(); + const auto target_device = *node_info.device; + for (auto it = node_info_vec.begin(); it != node_info_vec.end(); ++it) { + ORT_ENFORCE(target_device == *(it->device), "Inconsistent device requirements found for input: ", param_name); + } + + if (force_load || src_weight_tensor.Location().device.Type() != target_device.Type()) { + auto weight_allocator = session_state.GetAllocator(target_device); + ORT_ENFORCE(weight_allocator != nullptr); + + // Create a new tensor on the target_device and switch the source_ortvalue to point to this new tensor + auto dst_weight_tensor = std::make_unique(src_weight_tensor.DataType(), src_weight_tensor.Shape(), + weight_allocator); + ORT_THROW_IF_ERROR(session_state.GetDataTransferMgr().CopyTensor(src_weight_tensor, *dst_weight_tensor.get())); + auto ml_tensor_type = DataTypeImpl::GetType(); + param.Data().Init(dst_weight_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc()); + } + + if (param.RequiresGrad()) { + // Create gradient accumulation buffer. + auto grad_it = param_to_grad_index.find(param_name); + ORT_ENFORCE(grad_it != param_to_grad_index.end(), "Gradient buffer input not provided for param: ", + param_name); + + const size_t grad_input_index = grad_it->second; + auto& param_grad_name = grad_names[grad_input_index]; + + OrtValue param_grad; + ORT_THROW_IF_ERROR(utils::CreateZeroValuedOrtValueLike(session_state, param.Data(), param_grad)); + ORT_THROW_IF_ERROR(param.SetGrad(param_grad_name, param_grad)); + } + + return Status::OK(); +} + } // namespace Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { @@ -251,7 +322,6 @@ Module::Module(const ModelIdentifiers& model_identifiers, // user inputs, weights, gradients, reset_grad InlinedVector user_input_names, param_input_names, grad_input_names, reset_grad_name; - std::unordered_map param_name_to_grad_input_index_map; for (const auto& input_name : train_input_names) { auto it = state_->module_checkpoint_state.named_parameters.find(input_name); if (it != state_->module_checkpoint_state.named_parameters.end()) { @@ -259,7 +329,6 @@ Module::Module(const ModelIdentifiers& model_identifiers, } else if (input_name == ACCUMULATE_GRAD_CONTROL_INPUT_NAME) { reset_grad_name.emplace_back(input_name); } else if (std::string param_name; utils::GetParamNameFromGradient(input_name, param_name)) { - param_name_to_grad_input_index_map.insert({param_name, grad_input_names.size()}); grad_input_names.emplace_back(input_name); } else { user_input_names.emplace_back(input_name); @@ -268,11 +337,7 @@ Module::Module(const ModelIdentifiers& model_identifiers, gradients_.resize(grad_input_names.size()); - train_input_names_ = user_input_names; - train_user_input_count_ = user_input_names.size(); - train_input_names_.insert(train_input_names_.end(), param_input_names.begin(), param_input_names.end()); - train_input_names_.insert(train_input_names_.end(), grad_input_names.begin(), grad_input_names.end()); - train_input_names_.insert(train_input_names_.end(), reset_grad_name.begin(), reset_grad_name.end()); + train_input_names_ = TrainInputNames(user_input_names, param_input_names, grad_input_names); for (const auto& output_name : train_output_names) { if (std::string param_name; 
!utils::GetParamNameFromGradient(output_name, param_name)) { @@ -280,58 +345,24 @@ Module::Module(const ModelIdentifiers& model_identifiers, } } - // Loop each parameter, and allocate its memory based on the user-specified device. - auto& train_sess_state = train_sess_->GetSessionState(); - for (auto& param_name : param_input_names) { - auto params_iter = state_->module_checkpoint_state.named_parameters.find(param_name); - ORT_ENFORCE(params_iter != state_->module_checkpoint_state.named_parameters.end()); - - // Retrieve the target device for "param_name". - InlinedVector node_info_vec; - ORT_THROW_IF_ERROR(train_sess_state.GetInputNodeInfo(param_name, node_info_vec)); - const auto& node_info = node_info_vec.front(); - const auto target_device = *node_info.device; - for (auto it = node_info_vec.begin(); it != node_info_vec.end(); ++it) { - ORT_ENFORCE(target_device == *(it->device), "Inconsistent device requirements found for input: ", param_name); - } - - // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning) - // Only copies data if the target device is not the same as the current device the buffer is placed on - OrtValue& param_data = params_iter->second->Data(); - ORT_ENFORCE(param_data.IsTensor()); - const Tensor& param_data_tensor = param_data.Get(); - // If the source device type is already the same as target device skip copy - if (param_data_tensor.Location().device.Type() != target_device.Type()) { - // TODO: move this outside of the for loop? - auto target_allocator = train_sess_state.GetAllocator(target_device); - ORT_ENFORCE(target_allocator != nullptr); - - // Create a new tensor on the target_device and switch the source_ortvalue to point to this new tensor - auto target_tensor = std::make_unique(param_data_tensor.DataType(), param_data_tensor.Shape(), - target_allocator); - ORT_THROW_IF_ERROR(train_sess_state.GetDataTransferMgr().CopyTensor(param_data_tensor, *target_tensor.get())); - auto ml_tensor_type = DataTypeImpl::GetType(); - param_data.Init(target_tensor.release(), ml_tensor_type, ml_tensor_type->GetDeleteFunc()); - } - - weights_.push_back(param_data); - weight_names_.push_back(param_name); - - // Create gradient buffer when parameter requires gradient. - if (params_iter->second->RequiresGrad()) { - // Create gradient accumulation buffer. - auto it = param_name_to_grad_input_index_map.find(param_name); - ORT_ENFORCE(it != param_name_to_grad_input_index_map.end(), "Gradient buffer input not provided for param: ", - param_name); - - const size_t grad_input_index = it->second; - auto& param_grad_name = grad_input_names[grad_input_index]; - // TODO: don't pre-allocate the gradient buffer. - // Gradient usually stays on the same device of its parameter. - OrtValue param_grad; - ORT_THROW_IF_ERROR(utils::CreateZeroValuedOrtValueLike(train_sess_state, param_data, param_grad)); - ORT_THROW_IF_ERROR(params_iter->second->SetGrad(param_grad_name, param_grad)); - gradients_[grad_input_index] = params_iter->second->Gradient(); + if (!state_->module_checkpoint_state.is_nominal_state) { + // ORT_THROW_IF_ERROR(AllocateMemoryForWeights()); + // Loop each parameter, and allocate its memory based on the user-specified device. 
+ const auto param_to_grad_index = BuildParameterToGradInputIndexMap(train_input_names_.GradientInputNames()); + for (auto& param_name : train_input_names_.WeightsInputNames()) { + auto params_iter = state_->module_checkpoint_state.named_parameters.find(param_name); + ORT_ENFORCE(params_iter != state_->module_checkpoint_state.named_parameters.end()); + + OrtValue& param_data = params_iter->second->Data(); + ORT_ENFORCE(param_data.IsTensor(), "Expected: Parameter data should be of tensor type. Actual: ", + params_iter->second->Name(), " is not a tensor."); + ORT_THROW_IF_ERROR(LoadParameter(param_name, param_data.Get(), train_sess_->GetSessionState(), + false /* force_load */, param_to_grad_index, + train_input_names_.GradientInputNames(), *params_iter->second)); + weights_.push_back(param_data); + if (params_iter->second->RequiresGrad()) { + gradients_[param_to_grad_index.at(param_name)] = params_iter->second->Gradient(); + } } } @@ -414,16 +445,24 @@ std::string Module::GetEvalModelOutputName(size_t index) const { size_t Module::GetParametersSize(const bool trainable_only) const { SafeInt parameters_size = 0; - for (const auto& it : state_->module_checkpoint_state.named_parameters) { - if (trainable_only && !it.second->RequiresGrad()) { + const auto model_inputs_with_error = GetTrainingModelInputs(); + ORT_THROW_IF_ERROR(model_inputs_with_error.first); + ORT_ENFORCE(model_inputs_with_error.second, "Training model graph inputs are not defined."); + for (const auto& input_def : *model_inputs_with_error.second) { + const std::string& input_name = input_def->Name(); + const auto param_it = state_->module_checkpoint_state.named_parameters.find(input_name); + if (param_it == state_->module_checkpoint_state.named_parameters.end() || + (trainable_only && !param_it->second->RequiresGrad())) { continue; } - parameters_size += it.second->Data().Get().Shape().Size(); + parameters_size += onnxruntime::utils::GetTensorShapeFromTensorShapeProto(*input_def->Shape()).Size(); } return parameters_size; } std::vector> Module::Parameters() const { + ORT_ENFORCE(!state_->module_checkpoint_state.is_nominal_state, + "Cannot fetch parameters from a nominal checkpoint state. Please load the model parameters first."); std::vector> params; for (auto& it : state_->module_checkpoint_state.named_parameters) { params.push_back(it.second); @@ -432,23 +471,27 @@ std::vector> Module::Parameters() const { } std::unordered_map> Module::NamedParameters() const { + ORT_ENFORCE(!state_->module_checkpoint_state.is_nominal_state, + "Cannot fetch named parameters from a nominal checkpoint state. Please load the model parameters first."); return state_->module_checkpoint_state.named_parameters; } Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only) { - ORT_ENFORCE(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); - ORT_ENFORCE(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); + ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, + "Cannot copy parameters from a nominal checkpoint state. 
Please load the model parameters first."); + ORT_RETURN_IF_NOT(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); + ORT_RETURN_IF_NOT(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); auto* init_tensor = parameters_buffer.GetMutable(); ORT_ENFORCE(nullptr != init_tensor); auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); - ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, - "Parameters buffer size incorrect. Expected:", expected_buffer_size, - ", Actual:", init_tensor->Shape().Size()); + ORT_RETURN_IF(init_tensor->Shape().Size() != expected_buffer_size, + "Parameters buffer size incorrect. Expected:", expected_buffer_size, + ", Actual:", init_tensor->Shape().Size()); const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); size_t offset = 0; - for (const auto& param_name : weight_names_) { + for (const auto& param_name : train_input_names_.WeightsInputNames()) { auto& param = state_->module_checkpoint_state.named_parameters.at(param_name); if (trainable_only && !param->RequiresGrad()) { continue; @@ -458,7 +501,7 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr const TensorShape& shape = weight_tensor->Shape(); auto element_type = init_tensor->DataType(); - ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + ORT_RETURN_IF(weight_tensor->DataType() != element_type, "Data types must match."); const OrtMemoryInfo& info = init_tensor->Location(); std::unique_ptr p_tensor; @@ -470,54 +513,102 @@ Status Module::CopyParametersToBuffer(OrtValue& parameters_buffer, const bool tr data_buffer + offset, info); } else { - ORT_THROW("Unsupported type: ", element_type); + ORT_THROW("Unsupported type: ", element_type, " encountered while copying parameters to buffer. ", + "Only float is supported."); } - ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*weight_tensor, *p_tensor.get())); + ORT_RETURN_IF_ERROR(sess_data_transfer_manager.CopyTensor(*weight_tensor, *p_tensor.get())); offset += shape.Size(); } return Status::OK(); } Status Module::CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only) { - ORT_ENFORCE(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); - ORT_ENFORCE(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); - auto* init_tensor = parameters_buffer.GetMutable(); - ORT_ENFORCE(nullptr != init_tensor); + // In case of a nominal checkpoint state, all parameters need to be loaded into the model. + // i.e. trainable_only must be false. + ORT_RETURN_IF(trainable_only && state_->module_checkpoint_state.is_nominal_state, + "For nominal checkpoint state, all parameters need to be loaded into the model " + "(trainable_only = false)."); + ORT_RETURN_IF_NOT(parameters_buffer.IsAllocated(), "Parameters buffer should be pre-allocated."); + ORT_RETURN_IF_NOT(parameters_buffer.IsTensor(), "Parameters buffer should be of tensor type."); + auto* buffer_tensor = parameters_buffer.GetMutable(); + ORT_RETURN_IF(nullptr == buffer_tensor, "Expected valid parameter buffer. Actual: nullptr."); auto expected_buffer_size = static_cast(GetParametersSize(trainable_only)); - ORT_ENFORCE(init_tensor->Shape().Size() == expected_buffer_size, - "Parameters buffer size incorrect. 
Expected:", expected_buffer_size, - ", Actual:", init_tensor->Shape().Size()); + ORT_RETURN_IF(buffer_tensor->Shape().Size() != expected_buffer_size, + "Parameters buffer size incorrect. Expected:", expected_buffer_size, + ", Actual:", buffer_tensor->Shape().Size()); + auto& train_sess_state = train_sess_->GetSessionState(); const DataTransferManager& sess_data_transfer_manager = train_sess_->GetDataTransferManager(); + const auto model_inputs_with_error = GetTrainingModelInputs(); + ORT_RETURN_IF_ERROR(model_inputs_with_error.first); + ORT_RETURN_IF_NOT(model_inputs_with_error.second, "Training model graph inputs are not defined."); + const auto param_to_node_arg = BuildParameterToInputNodeArgMap(state_->module_checkpoint_state, + model_inputs_with_error.second); + const auto param_to_grad_index = BuildParameterToGradInputIndexMap(train_input_names_.GradientInputNames()); + + if (state_->module_checkpoint_state.is_nominal_state) { + // weights_ vector is not initialized for a nominal state. This function is expected to + // initialize the weights_. + ORT_ENFORCE(weights_.empty(), "Weights vector should be empty for a nominal state."); + } size_t offset = 0; - for (const auto& param_name : weight_names_) { + for (const auto& param_name : train_input_names_.WeightsInputNames()) { auto& param = state_->module_checkpoint_state.named_parameters.at(param_name); if (trainable_only && !param->RequiresGrad()) { continue; } OrtValue& weight = param->Data(); - auto* weight_tensor = weight.GetMutable(); - const TensorShape& shape = weight_tensor->Shape(); - auto element_type = init_tensor->DataType(); - ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + auto param_it = param_to_node_arg.find(param_name); + const TensorShape shape = onnxruntime::utils::GetTensorShapeFromTensorShapeProto( + *(param_it->second->Shape())); + const auto element_type = static_cast( + onnxruntime::utils::GetMLDataType(*param_it->second)) + ->GetElementType(); - const OrtMemoryInfo& info = init_tensor->Location(); - std::unique_ptr p_tensor; + const OrtMemoryInfo& info = buffer_tensor->Location(); + std::unique_ptr src_tensor; if (onnxruntime::utils::IsPrimitiveDataType(element_type)) { - float* data_buffer = init_tensor->MutableData(); - p_tensor = std::make_unique(element_type, - shape, - data_buffer + offset, - info); + float* data_buffer = buffer_tensor->MutableData(); + src_tensor = std::make_unique(element_type, + shape, + data_buffer + offset, + info); + } else { + ORT_THROW("Unsupported type: ", element_type, " encountered while copying buffer to parameters. ", + "Only float is supported."); + } + + if (state_->module_checkpoint_state.is_nominal_state) { + // If state is a nominal state, then we first need to allocate the memory for + // parameters and their gradients in the checkpoint state before copying the data. + ORT_RETURN_IF_ERROR(LoadParameter(param_name, *src_tensor, train_sess_state, true, + param_to_grad_index, train_input_names_.GradientInputNames(), + *param)); + weights_.push_back(param->Data()); + if (param->RequiresGrad()) { + // It is expected that the gradients_ vector is already initialized with the correct size + // in the Module constructor (even though the OrtValues contained in the vector are empty). + gradients_[param_to_grad_index.at(param_name)] = param->Gradient(); + } } else { - ORT_THROW("Unsupported type: ", element_type); + // If state is not a nominal state, then we can directly copy the data to the existing + // parameters in the checkpoint state. 
+ auto* weight_tensor = weight.GetMutable(); + ORT_ENFORCE(weight_tensor->DataType() == element_type, "Data types must match."); + ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*src_tensor.get(), *weight_tensor)); } - ORT_THROW_IF_ERROR(sess_data_transfer_manager.CopyTensor(*p_tensor.get(), *weight_tensor)); + offset += shape.Size(); } + + if (state_->module_checkpoint_state.is_nominal_state) { + // Once the parameters are loaded, the state is no longer a nominal state. + state_->module_checkpoint_state.is_nominal_state = false; + } + return Status::OK(); } @@ -527,6 +618,9 @@ Status Module::LazyResetGrad() { } Status Module::TrainStep(const std::vector& inputs, std::vector& outputs) { + ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, + "Cannot perform TrainStep with a nominal state. Please load the model parameters first."); + std::vector> params; std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); feeds.insert(feeds.end(), gradients_.begin(), gradients_.end()); @@ -535,7 +629,7 @@ Status Module::TrainStep(const std::vector& inputs, std::vector(!accumulate_gradient_, &reset_grad_input); feeds.push_back(reset_grad_input); - ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_, feeds, train_output_names_, &outputs)); + ORT_THROW_IF_ERROR(train_sess_->Run(RunOptions(), train_input_names_.AllInputNames(), feeds, train_output_names_, &outputs)); // Reset the flag after every step. In case the ResetGrad was called before running // the current step, it will have done the effective resetting during the @@ -546,6 +640,8 @@ Status Module::TrainStep(const std::vector& inputs, std::vector& inputs, std::vector& outputs) { + ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, + "Cannot perform EvalStep with a nominal state. Please load the model parameters first."); ORT_ENFORCE(nullptr != eval_sess_, "Evaluation session not initialized."); std::vector feeds{inputs}; feeds.insert(feeds.end(), weights_.begin(), weights_.end()); @@ -560,6 +656,8 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { + ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, + "Cannot export the model with a nominal state. Please load the model parameters first."); ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), "Eval model was not provided. Cannot export a model for inferencing."); @@ -586,7 +684,7 @@ Status Module::ExportModelForInferencing(const std::string& inference_model_path #endif size_t Module::GetTrainingModelInputCount() const noexcept { - return train_user_input_count_; + return train_input_names_.UserInputNames().size(); } size_t Module::GetEvalModelInputCount() const noexcept { @@ -594,10 +692,10 @@ size_t Module::GetEvalModelInputCount() const noexcept { } std::string Module::GetTrainingModelInputName(size_t index) const { - ORT_ENFORCE(index < train_user_input_count_, - "Train input name index out of range. Expected in range [0-", train_user_input_count_, "). Actual: ", + ORT_ENFORCE(index < train_input_names_.UserInputNames().size(), + "Train input name index out of range. Expected in range [0-", train_input_names_.UserInputNames().size(), "). 
Actual: ", index); - return train_input_names_.at(index); + return train_input_names_.UserInputNames()[index]; } std::string Module::GetEvalModelInputName(size_t index) const { @@ -615,6 +713,43 @@ std::pair Module::GetEvalModelInputs() cons return eval_sess_->GetModelInputs(); } +Module::TrainInputNames::TrainInputNames(gsl::span user_input_names, + gsl::span weights_input_names, + gsl::span gradient_input_names) { + train_input_names_.reserve(user_input_names.size() + + weights_input_names.size() + + gradient_input_names.size() + + 1U); // +1 for the reset gradient flag input + train_input_index_offsets_.reserve(3); + + train_input_names_.insert(train_input_names_.end(), + user_input_names.begin(), user_input_names.end()); + train_input_index_offsets_.push_back(train_input_names_.size()); + train_input_names_.insert(train_input_names_.end(), + weights_input_names.begin(), weights_input_names.end()); + train_input_index_offsets_.push_back(train_input_names_.size()); + train_input_names_.insert(train_input_names_.end(), + gradient_input_names.begin(), gradient_input_names.end()); + train_input_index_offsets_.push_back(train_input_names_.size()); + train_input_names_.push_back(ACCUMULATE_GRAD_CONTROL_INPUT_NAME); +} + +gsl::span Module::TrainInputNames::AllInputNames() const { return train_input_names_; } + +gsl::span Module::TrainInputNames::UserInputNames() const { + return gsl::span{train_input_names_.begin(), train_input_index_offsets_[0]}; +} + +gsl::span Module::TrainInputNames::WeightsInputNames() const { + return gsl::span{train_input_names_.begin() + train_input_index_offsets_[0], + train_input_index_offsets_[1] - train_input_index_offsets_[0]}; +} + +gsl::span Module::TrainInputNames::GradientInputNames() const { + return gsl::span{train_input_names_.begin() + train_input_index_offsets_[1], + train_input_index_offsets_[2] - train_input_index_offsets_[1]}; +} + } // namespace api } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index f323e6be72d49..917887404217f 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -53,6 +53,7 @@ struct ModuleCheckpointState { public: std::unordered_map> named_parameters; const DataTransferManager* train_session_data_transfer_mgr; + bool is_nominal_state = false; }; struct CheckpointState; @@ -87,19 +88,28 @@ struct Module { ~Module(); // Return the trainable/nontrainable parameters + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will raise an exception. std::vector> Parameters() const; + // Return the trainable/nontrainable parameters as a map + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will raise an exception. std::unordered_map> NamedParameters() const; // Reset and release the gradient buffer of all trainable params lazily. Status LazyResetGrad(); // Train Step – does forward and backward computation. The outputs will be the forward’s outputs. - // Gradients will be accumulated within the Parameter object + // Gradients will be accumulated within the Parameter object. + // If the parameter state is not available; i.e. 
the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will return an error. Status TrainStep(const std::vector& inputs, std::vector& outputs); // Eval Step – does forward computation. This will use a separate inference session // and take in a separate inference graph, while sharing the parameters + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will return an error. Status EvalStep(const std::vector& inputs, std::vector& outputs); // Returns the output count for training graph @@ -118,14 +128,20 @@ struct Module { size_t GetParametersSize(const bool trainable_only = true) const; // Copy parameters onto contiguous buffer held by parameters_buffer + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will return an error. Status CopyParametersToBuffer(OrtValue& parameters_buffer, const bool trainable_only = true); // Copy parameter values from contiguous buffer held by parameters_buffer onto parameters + // This function is responsible for completing the nominal checkpoint state. The checkpoint + // state will no longer be nominal after the successful completion of this function. Status CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only = true); #if !defined(ORT_MINIMAL_BUILD) // Load the eval model from eval_model_path_or_bytes and transform it for the purpose of - // inferencing, and serialize to given path + // inferencing, and serialize to given path. + // If the parameter state is not available; i.e. the module was created using the nominal checkpoint, + // and the state has not been loaded yet, then this function will return an error. 
Status ExportModelForInferencing(const std::string& inference_model_path, gsl::span graph_output_names) const; #endif @@ -152,11 +168,28 @@ struct Module { std::unique_ptr train_sess_{nullptr}; std::unique_ptr eval_sess_{nullptr}; - InlinedVector train_input_names_; + struct TrainInputNames { + private: + InlinedVector train_input_names_; + InlinedVector train_input_index_offsets_; // offset range[[0], [1]) = user input names + // offset range[[1], [2]) = weights input names + // offset range[[2], [3]) = gradient input names + public: + TrainInputNames() = default; + TrainInputNames(gsl::span user_input_names, + gsl::span weights_input_names, + gsl::span gradient_input_names); + + gsl::span AllInputNames() const; + gsl::span UserInputNames() const; + gsl::span WeightsInputNames() const; + gsl::span GradientInputNames() const; + }; + + TrainInputNames train_input_names_; InlinedVector train_output_names_; InlinedVector eval_input_names_; InlinedVector eval_output_names_; - InlinedVector weight_names_; InlinedVector weights_; InlinedVector gradients_; @@ -165,7 +198,6 @@ struct Module { bool accumulate_gradient_ = false; std::optional eval_model_path_; - size_t train_user_input_count_{0U}; size_t eval_user_input_count_{0U}; }; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 38a9aad9640ea..0ed41f670f9e3 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -568,9 +568,16 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtChe API_IMPL_BEGIN auto chkpt_state = reinterpret_cast(checkpoint_state); + if (chkpt_state->module_checkpoint_state.is_nominal_state) { + const std::string err_msg = + "Parameter type and shape cannot be retrieved from nominal checkpoint state. " + "Please load the parameter states first."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { - std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } @@ -586,9 +593,15 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState } auto chkpt_state = reinterpret_cast(checkpoint_state); + if (chkpt_state->module_checkpoint_state.is_nominal_state) { + const std::string err_msg = + "Parameter cannot be updated for nominal checkpoint state. 
Please load all the parameter states first."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { - std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( @@ -608,9 +621,15 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState } auto chkpt_state = reinterpret_cast(checkpoint_state); + if (chkpt_state->module_checkpoint_state.is_nominal_state) { + const std::string err_msg = + "Parameter cannot be retrieved from nominal checkpoint state. Please load the parameter states first."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { - std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + const std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); } diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index 7f583ce8f6e76..84c35e6100385 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -21,8 +21,8 @@ namespace { constexpr char GROUP_ZERO_NAME[] = "group0"; static constexpr std::array CommonOptimizerInputs{"learning_rate", "step", "params", "gradients"}; -Status GraphInputsAreExpected(gsl::span actual_graph_inputs, - gsl::span expected_graph_inputs) { +Status GraphInputsAreExpected(gsl::span actual_graph_inputs, + gsl::span expected_graph_inputs) { const auto stringify = [](const auto& container) { if (container.empty()) { return std::string("[]"); @@ -245,8 +245,17 @@ Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, if (!find_group_zero) state_->optimizer_checkpoint_state.group_named_optimizer_states.insert( {GROUP_ZERO_NAME, std::make_shared()}); - ORT_THROW_IF_ERROR(GenerateMomentumNamedStates(state_->optimizer_checkpoint_state)); - ORT_THROW_IF_ERROR(ConstructInputs()); + if (!state_->module_checkpoint_state.is_nominal_state) { + // Construct the optimizer state and inputs only if the complete state + // is available. + // For a nominal state, delay the construction of the optimizer state + // and inputs until the complete state is available. Once the complete + // state is available, the optimizer state and inputs can be constructed + // by invoking ConstructOptimizerStateAndInputs(). 
+ ORT_THROW_IF_ERROR(ConstructOptimizerStateAndInputs()); + } else { + delay_optimizer_state_contruction_ = true; + } } else { ORT_THROW_IF_ERROR(LoadStateDict(state_->optimizer_checkpoint_state)); } @@ -298,6 +307,10 @@ void Optimizer::Initialize(const ModelIdentifiers& model_identifiers, } Status Optimizer::Step() { + if (delay_optimizer_state_contruction_) { + ORT_RETURN_IF_ERROR(ConstructOptimizerStateAndInputs()); + } + OrtValue learning_rate_input, step_input; utils::WrapInOrtValue(optimizer_state_->learning_rate, &learning_rate_input); // Use step count + 1 before running optimizer step. @@ -375,6 +388,17 @@ Status Optimizer::LoadStateDict(OptimizerCheckpointState& optimizer_checkpoint_s return Status::OK(); } +Status Optimizer::ConstructOptimizerStateAndInputs() { + ORT_RETURN_IF(state_->module_checkpoint_state.is_nominal_state, + "The optimizer state cannot be constructed. Please load the model parameters first."); + ORT_RETURN_IF_ERROR(GenerateMomentumNamedStates(state_->optimizer_checkpoint_state)); + ORT_RETURN_IF_ERROR(ConstructInputs()); + + delay_optimizer_state_contruction_ = false; + + return Status::OK(); +} + } // namespace api } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h index d9bc4870bb7ed..031b11426539b 100644 --- a/orttraining/orttraining/training_api/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -123,6 +123,15 @@ struct Optimizer { return Status::OK(); } + // Constructs the optimizer state and prepares the model inputs. + // This is called once during the construction of the Optimizer if the model state is available. + // In case the optimizer was instantiated with a nominal checkpoint, this function must be + // called when the model state is available. + // The optimizer checks if the optimizer state needs to be constructed in the train step function. + // However, this is exposed as a public function in case the user wants to construct the optimizer + // state before the train step function is called. + Status ConstructOptimizerStateAndInputs(); + private: void Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, @@ -134,8 +143,7 @@ struct Optimizer { // Generates optimizer momentum states for parameters that require grad. Status GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states); - // Constructs the ortvalue inputs to be fed to the graph - // at each step. + // Constructs the ortvalue inputs to be fed to the graph at each step. 
Status ConstructInputs(); /** @@ -160,6 +168,8 @@ struct Optimizer { InlinedVector inputs_; int32_t group_count_{0}; + + bool delay_optimizer_state_contruction_{false}; }; } // namespace api diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 45f0f0ddcf7f4..78619947b8b18 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -112,7 +112,16 @@ Status TrainingSession::CopyParametersToBuffer(OrtValue& parameters_buffer, cons } Status TrainingSession::CopyBufferToParameters(OrtValue& parameters_buffer, const bool trainable_only) { - return module_->CopyBufferToParameters(parameters_buffer, trainable_only); + const bool was_nominal_state = state_->module_checkpoint_state.is_nominal_state; + ORT_RETURN_IF_ERROR(module_->CopyBufferToParameters(parameters_buffer, trainable_only)); + + // If the checkpoint state was nominal before loading the params, then we need to construct the + // optimizer state and inputs. + if (was_nominal_state) { + ORT_RETURN_IF_ERROR(optimizer_->ConstructOptimizerStateAndInputs()); + } + + return Status::OK(); } #if !defined(ORT_MINIMAL_BUILD) From d87f73ab44a2a038f396cc0a38a7062c4516c59c Mon Sep 17 00:00:00 2001 From: zesongw Date: Wed, 31 Jan 2024 16:20:07 +0800 Subject: [PATCH 013/207] [WebNN EP] Use GetVecUint32FromVecInt64 to simplify the code (#19324) - Use the function `GetVecUint32FromVecInt64` in helper.h to replace `transform`. - Change some `int32_t` to `uint32_t`. - Remove a useless `temp`. --- .../webnn/builders/impl/normalization_op_builder.cc | 5 +---- .../core/providers/webnn/builders/impl/pool_op_builder.cc | 7 +++---- .../core/providers/webnn/builders/impl/split_op_builder.cc | 5 +---- .../webnn/builders/impl/squeeze_unsqueeze_op_builder.cc | 5 +---- .../providers/webnn/builders/impl/transpose_op_builder.cc | 5 +---- .../core/providers/webnn/webnn_execution_provider.cc | 3 --- 6 files changed, 7 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d2470dfe7deb..50e04df4fe0f2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -125,10 +125,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder output = model_builder.GetBuilder().call("instanceNormalization", input, options); // Reshape back to the original output shape for 3D input. 
if (input_shape.size() != 4) { - std::vector output_shape; - std::transform(input_shape.begin(), input_shape.end(), - std::back_inserter(output_shape), - [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + std::vector output_shape = GetVecUint32FromVecInt64(input_shape); output = model_builder.GetBuilder().call( "reshape", output, emscripten::val::array(output_shape)); } diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 739c3b3f38def..8b3eecf35fcc8 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -81,7 +81,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto onnx_kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); @@ -94,12 +94,11 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto_pad_type, pads_out, model_builder.GetPreferredLayout() == DataLayout::NCHW)); - std::transform(pads_out.begin(), pads_out.end(), pads.begin(), - [](int64_t pad) -> int32_t { return static_cast(pad); }); + pads = GetVecUint32FromVecInt64(pads_out); } // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. 
- const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; + const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; options.set("padding", emscripten::val::array(padding)); const auto ceil_mode = helper.Get("ceil_mode", 0); diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index e9a600a5933af..91f21b196be54 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -83,10 +83,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector mapping_split; mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs); mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs); - std::vector converted_splits; - std::transform(mapping_split.cbegin(), mapping_split.cend(), - std::back_inserter(converted_splits), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); + std::vector converted_splits = GetVecUint32FromVecInt64(mapping_split); output_array = model_builder.GetBuilder().call("split", input, emscripten::val::array(converted_splits), diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 9f6ccb98f79dd..15149bd8fe821 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -87,10 +87,7 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil emscripten::val output = emscripten::val::undefined(); // Use WebNN's reshape to implement Squeeze/Unsqueeze. - std::vector new_shape; - std::transform( - input_shape.begin(), input_shape.end(), std::back_inserter(new_shape), - [](int64_t data) -> uint32_t { return SafeInt(data); }); + std::vector new_shape = GetVecUint32FromVecInt64(input_shape); // Sort axes_data in ascending order. 
std::sort(axes_data.begin(), axes_data.end()); if (op_type == "Squeeze") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index eca1521384643..79f60c51ace1b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -40,10 +40,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val options = emscripten::val::object(); - std::vector permutation; - std::transform(perm.cbegin(), perm.cend(), - std::back_inserter(permutation), - [](int64_t dim) -> int32_t { return SafeInt(dim); }); + std::vector permutation = GetVecUint32FromVecInt64(perm); options.set("permutation", emscripten::val::array(permutation)); emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index cfb96af557d35..29c8ca91fe72c 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -282,9 +282,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector temp(shape.size()); - transform(shape.begin(), shape.end(), temp.begin(), - [](int64_t dim) -> uint32_t { return SafeInt(dim); }); const void* inputBuffer = const_cast(input_tensor.GetTensorRawData()); inputs.emplace( input_name, From 2b361c04d68ec21bdd1ede33bcdf7a142bd292d1 Mon Sep 17 00:00:00 2001 From: Phoebe Chen Date: Thu, 1 Feb 2024 02:12:43 +0800 Subject: [PATCH 014/207] Fix Flatbuffer build issue. (#19296) ### Description Building on g++ 13.2.0 results in -Wstringop-overread errors on Linux. This commit addresses the flatbuffer build issue with the following changes: 1. Remove the Werror flag in the flarbuffer patch. 2. Add a compilation option to suppress the 'stringop-overflow' error in the Flatbuffers within the xnnpack provider. 
### Motivation and Context https://github.com/google/flatbuffers/issues/8119 https://github.com/microsoft/onnxruntime/pull/19239 Signed-off-by: Phoebe Chen --- cmake/CMakeLists.txt | 1 + cmake/onnxruntime_providers_xnnpack.cmake | 6 ++++++ cmake/patches/flatbuffers/flatbuffers.patch | 2 +- 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0eb224623f678..c8a88442fa746 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -641,6 +641,7 @@ else() check_cxx_compiler_flag(-Wunused-but-set-variable HAS_UNUSED_BUT_SET_VARIABLE) check_cxx_compiler_flag(-Wunused-variable HAS_UNUSED_VARIABLE) check_cxx_compiler_flag(-Wuseless-cast HAS_USELESS_CAST) + check_cxx_compiler_flag(-Wstringop-overflow HAS_STRINGOP_OVERFLOW) check_function_exists(reallocarray HAS_REALLOCARRAY) if (NOT APPLE AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_target_platform STREQUAL "aarch64") check_cxx_compiler_flag(-march=armv8.2-a+bf16 HAS_ARM64_BFLOAT16) diff --git a/cmake/onnxruntime_providers_xnnpack.cmake b/cmake/onnxruntime_providers_xnnpack.cmake index 9c00703ca0846..6342c24b2917e 100644 --- a/cmake/onnxruntime_providers_xnnpack.cmake +++ b/cmake/onnxruntime_providers_xnnpack.cmake @@ -19,6 +19,12 @@ flatbuffers::flatbuffers Boost::mp11 safeint_interface ) + # TODO fix stringop-overflow warnings + # Add compile option to suppress stringop-overflow error in Flatbuffers. + if (HAS_STRINGOP_OVERFLOW) + target_compile_options(onnxruntime_providers_xnnpack PRIVATE -Wno-error=stringop-overflow) + endif() + add_dependencies(onnxruntime_providers_xnnpack onnx ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_xnnpack PROPERTIES FOLDER "ONNXRuntime") diff --git a/cmake/patches/flatbuffers/flatbuffers.patch b/cmake/patches/flatbuffers/flatbuffers.patch index fb2678ef1bdce..f141d358c54b6 100644 --- a/cmake/patches/flatbuffers/flatbuffers.patch +++ b/cmake/patches/flatbuffers/flatbuffers.patch @@ -7,7 +7,7 @@ index 3987eac9..5e5462f1 100644 endif(CYGWIN) set(CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow") -+ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Werror -Wextra -Werror=shadow -Wno-error=stringop-overflow") ++ "${CMAKE_CXX_FLAGS} -Wall -pedantic -Wextra -Werror=shadow -Wno-error=stringop-overflow") set(FLATBUFFERS_PRIVATE_CXX_FLAGS "-Wold-style-cast") if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.4) if (CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 7.0) From ca8d4459d44c55817b61178b4b573eb645b6c4ec Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 31 Jan 2024 10:38:01 -0800 Subject: [PATCH 015/207] Add contrib Q/DQ ops to symbolic shape inference tool (#19340) ### Description Adds type/shape inferencing support for MSFT domain QuantizeLinear and DequantizeLinear operators to symbolic_shape_infer.py ### Motivation and Context Need a way to infer the types and shapes of Q/DQ ops in models that use the MSFT domain versions (e.g., int16 quantization). 
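Below is a minimal sketch of how the new inference paths can be exercised, mirroring the unit tests added in this patch. The `onnxruntime.tools.symbolic_shape_infer` import path and the tensor/graph names are assumptions for illustration only; the class can also be imported directly from `onnxruntime/python/tools/symbolic_shape_infer.py`.

```python
from onnx import TensorProto, helper

# Assumed import path; the implementation lives in onnxruntime/python/tools/symbolic_shape_infer.py.
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference

# Tiny graph with a 'com.microsoft' QuantizeLinear whose zero point is uint16.
initializers = [
    helper.make_tensor("scale", TensorProto.FLOAT, [], [1.0]),
    helper.make_tensor("zero_point", TensorProto.UINT16, [], [16]),
]
node = helper.make_node(
    "QuantizeLinear",
    inputs=["input_f32", "scale", "zero_point"],
    outputs=["output_u16"],
    domain="com.microsoft",
)
graph = helper.make_graph(
    [node],
    "QuantizeLinear_MSDomain_Example",
    [helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4])],
    [helper.make_tensor_value_info("output_u16", TensorProto.UNDEFINED, None)],
    initializers,
)
model = helper.make_model(graph)

# With this change, output_u16 resolves to UINT16 with shape ["b", 2, 3, 4]:
# the shape is propagated from input 0 and the element type is taken from the
# zero-point input (defaulting to uint8 when no zero point is given).
inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
for vi in list(inferred.graph.value_info) + list(inferred.graph.output):
    print(vi.name, vi.type.tensor_type.elem_type, vi.type.tensor_type.shape)
```

DequantizeLinear is handled symmetrically, except that the output element type is taken from the scale input (input index 1) rather than from the zero point.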
--- .../python/tools/symbolic_shape_infer.py | 27 +++ ...untime_test_python_symbolic_shape_infer.py | 202 ++++++++++++++++++ 2 files changed, 229 insertions(+) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef4c4ae906243..9823e8264e17b 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -197,6 +197,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "BiasGelu": self._infer_BiasGelu, "BiasSplitGelu": self._infer_BiasSplitGelu, "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention, + "DequantizeLinear": self._infer_DequantizeLinear, "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, "FastGelu": self._infer_FastGelu, "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, @@ -212,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, "PythonOp": self._infer_PythonOp, + "QuantizeLinear": self._infer_QuantizeLinear, "QuickGelu": self._infer_FastGelu, "RelativePositionBias": self._infer_RelativePositionBias, "RemovePadding": self._infer_RemovePadding, @@ -457,6 +459,8 @@ def _onnx_infer_single_node(self, node): "GemmFastGelu", "LayerNormalization", "LongformerAttention", + "DequantizeLinear", + "QuantizeLinear", "RelativePositionBias", "RemovePadding", "RestorePadding", @@ -979,6 +983,29 @@ def _infer_NhwcConv(self, node): # noqa: N802 ) ) + def _infer_DequantizeLinear(self, node): # noqa: N802 + # Get the output data type from the scale input (index 1, required). + output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + + def _infer_QuantizeLinear(self, node): # noqa: N802 + # Get the output data type from the zero-point input (index 2, optional). + # Otherwise, default to uint8 + output_dtype = onnx.TensorProto.UINT8 + if len(node.input) > 2 and node.input[2]: + output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type + + # Get the output shape from the first input. + output_shape = self._get_shape(node, 0) + + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape)) + def _infer_Einsum(self, node): # noqa: N802 # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 equation = get_attribute(node, "equation") diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 67db411ddc246..eca1430448e8e 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -392,6 +392,208 @@ def test_div_precision(self): self.assertEqual(len(output_dims), 1) self.assertEqual(output_dims[0].dim_value, 512) + def test_quantize_linear(self): + """ + Test ONNX QuantizeLinear op. + Check that the output shape is propagated from the first input and that the output data + type comes from the zero-point input. 
+ """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + helper.make_tensor( + "zero_point", + TensorProto.INT8, + [], + [16], + ), + ] + + nodes = [ + helper.make_node( + "QuantizeLinear", + inputs=[ + "input_f32", + "scale", + "zero_point", + ], + outputs=["output_s8"], + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_s8", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "QuantizeLinear_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_s8", TensorProto.INT8, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + + def test_quantize_linear_ms_domain(self): + """ + Test QuantizeLinear op ('com.microsoft' domain). + Check that the output shape is propagated from the first input and that the output data + type comes from the zero-point input. + """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + helper.make_tensor( + "zero_point", + TensorProto.UINT16, + [], + [16], + ), + ] + + nodes = [ + helper.make_node( + "QuantizeLinear", + inputs=[ + "input_f32", + "scale", + "zero_point", + ], + outputs=["output_u16"], + domain="com.microsoft", + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_u16", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "QuantizeLinear_MSDomain_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_u16", TensorProto.UINT16, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + + def test_quantize_linear_no_zp_input(self): + """ + Test QuantizeLinear op ('com.microsoft' domain). + Check that the output shape is propagated from the first input. + The zero-point input is missing, so the output data type should default to uint8. + """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + ] + + nodes = [ + helper.make_node( + "QuantizeLinear", + inputs=[ + "input_f32", + "scale", + ], + outputs=["output_u8"], + domain="com.microsoft", + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_u8", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "QuantizeLinear_NoZP_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + # Check that the output shape is propagated from the first input and that the + # output data type comes from the zero-point input. + expected_shapes = [ + helper.make_tensor_value_info("output_u8", TensorProto.UINT8, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + + def test_dequantize_linear_ms_domain(self): + """ + Test DequantizeLinear operator ('com.microsoft' domain). + Check that the output shape is propagated from the first input and that the output data + type comes from the scale input. 
+ """ + initializers = [ + helper.make_tensor( + "scale", + TensorProto.FLOAT, + [], + [1.0], + ), + helper.make_tensor( + "zero_point", + TensorProto.UINT16, + [], + [16], + ), + ] + + nodes = [ + helper.make_node( + "DequantizeLinear", + inputs=[ + "input_u16", + "scale", + "zero_point", + ], + outputs=["output_f32"], + domain="com.microsoft", + ), + ] + + inputs = [ + helper.make_tensor_value_info("input_u16", TensorProto.UINT16, ["b", 2, 3, 4]), + ] + + outputs = [ + helper.make_tensor_value_info("output_f32", TensorProto.UNDEFINED, None), + ] + + graph = helper.make_graph(nodes, "DequantizeLinear_MSDomain_Test", inputs, outputs, initializers) + model = helper.make_model(graph) + + inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) + + expected_shapes = [ + helper.make_tensor_value_info("output_f32", TensorProto.FLOAT, ["b", 2, 3, 4]), + ] + self._check_shapes(graph, inferred.graph, expected_shapes) + class TestSymbolicShapeInferenceForSlice(unittest.TestCase): def check_slice_of_concat(self, input_dims, start, end, step, expected_output_dim): From 55b60d8fe04408f7326704f87b48c9ba318091e6 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Wed, 31 Jan 2024 13:40:25 -0800 Subject: [PATCH 016/207] Turn off Neural Speed to avoid slowdowns (#19265) Disable Neural Speed to prevent the operation following MatMulNBits from significantly slowing down. --- cmake/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c8a88442fa746..34e7687e91876 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) From 68b6064be66cc16bae7e624a554e641cf21d2b06 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 1 Feb 2024 08:02:12 +1000 Subject: [PATCH 017/207] Fix reporting of unused initializers in subgraphs (#19341) ### Description Increment num_resolves_ inside the graph resolve finalization function so the subgraphs have the same value. This prevents incorrect output regarding removing unused initializers. 
### Motivation and Context #19141 --- onnxruntime/core/graph/graph.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index f71b7ecebcf1a..902839bee04ba 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -2798,12 +2798,13 @@ Status Graph::Resolve(const ResolveOptions& options) { graph.GraphProtoSyncNeeded(false); } + // set num_resolves_ here so the graph and any subgraphs all have the same value + ++graph.num_resolves_; + return Status::OK(); }; ORT_RETURN_IF_ERROR(ForThisAndAllSubgraphs(all_subgraphs, finalize_func)); - ++num_resolves_; - return Status::OK(); } From 1d6f13fb92204df344301ccc12a5f292a0cc44ed Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Thu, 1 Feb 2024 13:08:26 +0800 Subject: [PATCH 018/207] [VitisAI] Refactor the VAIEP to use MSFT's standalone API (#19058) ### Description Refactor the VAIEP to use MSFT's standalone API ### Motivation and Context Vitis ONNX RT VAI should switch to using the standalone API for ONNX EPs in order to decouple the EP from onnxruntime.dll and the providers.dll. This will help to simplify customer deployment of applications and use cases that need to share their onnxruntime.dll with other applications. --------- Co-authored-by: Zhenze Wang Co-authored-by: zz002 --- cmake/onnxruntime.cmake | 1 - cmake/onnxruntime_providers_vitisai.cmake | 32 +- cmake/onnxruntime_python.cmake | 11 +- cmake/onnxruntime_unittests.cmake | 1 - .../core/session/onnxruntime_c_api.h | 17 + .../core/session/onnxruntime_cxx_api.h | 3 + .../core/session/onnxruntime_cxx_inline.h | 19 + .../providers/provider_factory_creators.h | 4 - .../providers/shared_library/provider_api.h | 9 +- .../provider_bridge_provider.cc | 4 + .../shared_library/provider_interfaces.h | 87 +++++ .../shared_library/provider_wrappedtypes.h | 117 +++++- .../core/providers/vitisai/imp/attr_proto.cc | 120 +++--- .../core/providers/vitisai/imp/attr_proto.h | 46 +-- .../core/providers/vitisai/imp/capability.cc | 73 ++-- .../core/providers/vitisai/imp/global_api.cc | 367 +++++++----------- .../core/providers/vitisai/imp/graph.cc | 127 +++--- .../core/providers/vitisai/imp/node.cc | 11 +- .../core/providers/vitisai/imp/node_arg.cc | 155 ++------ .../core/providers/vitisai/imp/node_attrs.cc | 114 ------ .../providers/vitisai/imp/register_xir_ops.cc | 117 +----- .../providers/vitisai/imp/tensor_proto.cc | 100 ++--- .../core/providers/vitisai/imp/tensor_proto.h | 41 +- .../vitisai/include/vaip/capability.h | 3 +- .../vitisai/include/vaip/global_api.h | 13 +- .../providers/vitisai/include/vaip/graph.h | 22 +- .../providers/vitisai/include/vaip/my_ort.h | 26 +- .../providers/vitisai/include/vaip/node.h | 8 - .../providers/vitisai/include/vaip/node_arg.h | 7 +- .../vitisai/include/vaip/node_attrs.h | 46 --- .../vitisai/include/vaip/vaip_ort_api.h | 8 +- .../core/providers/vitisai/symbols.def | 2 + .../core/providers/vitisai/version_script.lds | 9 + .../vitisai/vitisai_execution_provider.cc | 71 +--- .../vitisai/vitisai_execution_provider.h | 4 +- .../vitisai/vitisai_provider_factory.cc | 38 +- onnxruntime/core/session/onnxruntime_c_api.cc | 1 + onnxruntime/core/session/ort_apis.h | 4 + .../core/session/provider_bridge_ort.cc | 242 +++++++++++- .../core/session/provider_registration.cc | 16 +- .../python/onnxruntime_pybind_state.cc | 2 +- setup.py | 3 + 42 files changed, 1000 insertions(+), 1101 deletions(-) delete mode 100644 onnxruntime/core/providers/vitisai/imp/node_attrs.cc delete 
mode 100644 onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h create mode 100644 onnxruntime/core/providers/vitisai/symbols.def create mode 100644 onnxruntime/core/providers/vitisai/version_script.lds diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index c900f4d4b09a5..2ead13e554197 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -189,7 +189,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_SNPE} ${PROVIDERS_TVM} ${PROVIDERS_RKNPU} - ${PROVIDERS_VITISAI} ${PROVIDERS_XNNPACK} ${PROVIDERS_WEBNN} ${PROVIDERS_AZURE} diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 0951c2d02664d..183a3e196af42 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -14,14 +14,19 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/*.h" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) - onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) - if(NOT MSVC) - target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) - endif(NOT MSVC) + onnxruntime_add_shared_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai ${ONNXRUNTIME_PROVIDERS_SHARED} nlohmann_json::nlohmann_json safeint_interface flatbuffers::flatbuffers) + target_link_libraries(onnxruntime_providers_vitisai PRIVATE ${ONNXRUNTIME_PROVIDERS_SHARED}) + if(MSVC) + onnxruntime_add_include_to_target(onnxruntime_providers_vitisai dbghelp) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-DEF:${ONNXRUNTIME_ROOT}/core/providers/vitisai/symbols.def") + else(MSVC) + set_property(TARGET onnxruntime_providers_vitisai APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/vitisai/version_script.lds -Xlinker --gc-sections") + endif(MSVC) target_include_directories(onnxruntime_providers_vitisai PRIVATE "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include" ${XRT_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/VitisAI) if(MSVC) @@ -30,17 +35,18 @@ target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4251") # for unused formal parameter target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4100") + # for type name first seen using 'class' now seen using 'struct' + target_compile_options(onnxruntime_providers_vitisai PRIVATE "/wd4099") else(MSVC) + target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) target_compile_options(onnxruntime_providers_vitisai PRIVATE -Wno-unused-parameter) endif(MSVC) set_target_properties(onnxruntime_providers_vitisai PROPERTIES FOLDER "ONNXRuntime") set_target_properties(onnxruntime_providers_vitisai PROPERTIES LINKER_LANGUAGE CXX) - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_vitisai - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - 
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() + install(TARGETS onnxruntime_providers_vitisai + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 2e3594f256f65..456344aa34d95 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -170,7 +170,6 @@ target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} ${PROVIDERS_TVM} - ${PROVIDERS_VITISAI} ${PROVIDERS_NNAPI} ${PROVIDERS_XNNPACK} ${PROVIDERS_COREML} @@ -852,6 +851,16 @@ if (onnxruntime_USE_DNNL) ) endif() +if (onnxruntime_USE_VITISAI) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${DNNL_DLL_PATH} $ + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_TENSORRT) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 714f35380ca02..6a4551ad94d9e 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -591,7 +591,6 @@ set(ONNXRUNTIME_TEST_LIBS # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NNAPI} ${PROVIDERS_JS} - ${PROVIDERS_VITISAI} ${PROVIDERS_QNN} ${PROVIDERS_SNPE} ${PROVIDERS_RKNPU} diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2ce9d361e8e56..5577c840c5379 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4569,6 +4569,23 @@ struct OrtApi { _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. 
+ * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 7a553f9f94006..ae4c4bef90c64 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -901,6 +901,9 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs = {}); SessionOptionsImpl& RegisterCustomOpsUsingFunction(const char* function_name); ///< Wraps OrtApi::RegisterCustomOpsUsingFunction + + ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_VitisAI + SessionOptionsImpl& AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options = {}); }; } // namespace detail diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 957e849cf5d4d..23246adff254a 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -885,6 +885,25 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_Ope return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_VitisAI(const std::unordered_map& provider_options) { + auto num_entries = provider_options.size(); + std::vector keys, values; + if (num_entries > 0) { + keys.reserve(num_entries); + values.reserve(num_entries); + + for (const auto& entry : provider_options) { + keys.push_back(entry.first.c_str()); + values.push_back(entry.second.c_str()); + } + } + + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_VitisAI(this->p_, keys.data(), values.data(), num_entries)); + + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::RegisterCustomOpsLibrary(const ORTCHAR_T* library_name, const CustomOpConfigs& custom_op_configs) { diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 42a58097e1635..6a4ab6a3d2113 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,10 +78,6 @@ #include "core/providers/tvm/tvm_provider_factory_creator.h" #endif -#if defined(USE_VITISAI) -#include "core/providers/vitisai/vitisai_provider_factory_creator.h" -#endif - #if defined(USE_XNNPACK) #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 1e3a528d87721..b78279040acb6 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -95,12 +95,15 @@ enum OperatorStatus : int { }; // onnx Protobuf types (All of these are direct mappings to the onnx types except for the Repeated*Field ones which map to a Repeated*Field type) -struct int64s; // RepeatedField +struct int64s; // 
RepeatedField +struct float32s; // RepeatedField struct AttributeProto; struct GraphProto; struct ModelProto; struct NodeProto; struct SparseTensorProto; +struct StringStringEntryProto; +struct StringStringEntryProtos; // RepeatedPtrField struct TensorProto; struct TensorProtos; // RepeatedPtrField struct TensorShapeProto_Dimension; @@ -113,6 +116,9 @@ struct TypeProto_Sequence; struct TypeProto; struct ValueInfoProto; struct ValueInfoProtos; // RepeatedPtrField +struct InferenceContext; +class GraphInferencer; +using InferenceFunction = std::function; } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -249,6 +255,7 @@ constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; constexpr const char* kCannExecutionProvider = "CANNExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; +constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 6dbe103791e43..da17135878fe5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -492,6 +492,10 @@ template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } +Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) { + return g_host->UnpackInitializerData(tensor, model_path, unpacked_tensor); +} } // namespace utils diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a216b2bfc6d04..f5a8327443864 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -91,6 +91,7 @@ using HashValue = uint64_t; using NodeIndex = size_t; // We can't just reinterpret_cast this one, since it's an unordered_map of object BY VALUE (can't do anything by value on the real types) // using NodeAttributes = std::unordered_map; +using ModelMetaData = std::unordered_map; using InitializedTensorSet = std::unordered_map; @@ -201,6 +202,8 @@ struct ProviderHost { virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t 
expected_size) = 0; + virtual Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) = 0; virtual uint16_t math__floatToHalf(float f) = 0; virtual float math__halfToFloat(uint16_t h) = 0; @@ -261,12 +264,32 @@ struct ProviderHost { virtual void logging__Capture__operator_delete(logging::Capture* p) noexcept = 0; virtual std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept = 0; + // Env + virtual Env& Env__Default() = 0; + // Utils::DataTypeUtils virtual const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) = 0; // int64s virtual int int64s__size(const ONNX_NAMESPACE::int64s* p) = 0; virtual const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) = 0; + virtual void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) = 0; + virtual const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) = 0; + + // float32s + virtual void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) = 0; + virtual const float* float32s__data(const ONNX_NAMESPACE::float32s* p) = 0; + virtual int float32s__size(const ONNX_NAMESPACE::float32s* p) = 0; + + // StringStringEntryProto + virtual std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; + virtual std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) = 0; + + // StringStringEntryProtos + virtual void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) = 0; #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional @@ -283,6 +306,7 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::TensorShapeProto& TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; virtual int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) = 0; + virtual void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) = 0; #if !defined(DISABLE_SPARSE_TENSORS) // TypeProto_SparseTensor @@ -327,9 +351,17 @@ struct ProviderHost { virtual float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; virtual const ::std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) = 0; virtual const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t size) = 0; + virtual void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float size) = 0; + virtual void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& size) = 0; virtual int64_t AttributeProto__i(const 
ONNX_NAMESPACE::AttributeProto* p) = 0; virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0; + virtual const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; + virtual void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) = 0; virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0; virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0; virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0; @@ -352,6 +384,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0; + virtual std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) = 0; virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0; // ModelProto @@ -367,6 +400,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) = 0; virtual void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) = 0; // NodeProto virtual std::unique_ptr NodeProto__construct() = 0; @@ -381,19 +415,33 @@ struct ProviderHost { virtual void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) = 0; virtual void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) = 0; virtual bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) = 0; + virtual const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) = 0; virtual bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; virtual const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) = 0; virtual int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) = 0; virtual void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) = 0; + virtual ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) = 
0; + virtual void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) = 0; + virtual void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) = 0; virtual bool TensorProto_DataType_IsValid(int value) = 0; // TensorProtos virtual ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) = 0; + virtual ONNX_NAMESPACE::TensorProto& TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) = 0; // TensorShapeProto_Dimension virtual int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; @@ -403,6 +451,8 @@ struct ProviderHost { virtual bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; virtual bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; virtual void TensorShapeProto_Dimension__clear_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p) = 0; + virtual const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const = 0; + virtual void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) = 0; // TensorShapeProto_Dimensions virtual std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) = 0; @@ -426,6 +476,8 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) = 0; + virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0; + // ConfigOptions virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; @@ -651,6 +703,7 @@ struct ProviderHost { virtual void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) = 0; virtual const NodeAttributes& Node__GetAttributes(const Node* p) noexcept = 0; + virtual void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) = 0; virtual size_t Node__GetInputEdgesCount(const Node* p) noexcept = 0; virtual size_t Node__GetOutputEdgesCount(const Node* p) noexcept = 0; @@ -660,10 +713,13 @@ struct ProviderHost { virtual std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept = 0; + virtual std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept = 0; virtual std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept = 0; virtual void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) = 0; + virtual int Node__NodeType(const Node* p) const noexcept = 0; virtual const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) = 0; virtual std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const = 0; @@ -674,6 +730,7 @@ struct ProviderHost { virtual const ONNX_NAMESPACE::NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept = 0; virtual bool NodeArg__Exists(const NodeArg* p) const noexcept = 0; virtual const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) 
noexcept = 0; + virtual Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) = 0; // NodeAttributes virtual std::unique_ptr NodeAttributes__construct() = 0; @@ -691,12 +748,18 @@ struct ProviderHost { virtual std::unique_ptr NodeAttributes__find(const NodeAttributes* p, const std::string& key) = 0; virtual void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) = 0; virtual void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; + virtual void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) = 0; virtual void NodeAttributes__reserve(NodeAttributes* p, size_t size) = 0; // Model + virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, + const PathString& model_path, const logging::Logger& logger) = 0; virtual void Model__operator_delete(Model* p) = 0; virtual Graph& Model__MainGraph(Model* p) = 0; virtual std::unique_ptr Model__ToProto(Model* p) = 0; + virtual std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) = 0; + virtual const ModelMetaData& Model__MetaData(const Model* p) const noexcept = 0; + virtual Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) = 0; // Graph virtual std::unique_ptr Graph__CreateGraphViewer(const Graph* p) = 0; @@ -714,6 +777,7 @@ struct ProviderHost { virtual void Graph__SetOutputs(Graph* p, gsl::span outputs) = 0; virtual const std::vector& Graph__GetInputs(const Graph* p) noexcept = 0; + virtual std::vector Graph__Nodes(const Graph* p) = 0; virtual bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) = 0; virtual const Node* Graph__ParentNode(const Graph* p) const = 0; @@ -723,6 +787,26 @@ struct ProviderHost { virtual const Path& Graph__ModelPath(const Graph* p) const = 0; virtual const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept = 0; virtual bool Graph__IsSubgraph(const Graph* p) = 0; + virtual const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const = 0; + virtual const Model& Graph__GetModel(const Graph* p) = 0; + virtual void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const = 0; + virtual Graph& Graph__SetGraphResolveNeeded(Graph* p) = 0; + virtual void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) = 0; + + virtual std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const = 0; + virtual void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) = 0; + virtual void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) = 0; + virtual void Graph__RemoveNode(Graph* p, NodeIndex index) = 0; + virtual Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) = 0; + virtual void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) = 0; + virtual const 
ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const = 0; + virtual const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) = 0; virtual int Graph__MaxNodeIndex(const Graph* p) const noexcept = 0; virtual Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept = 0; virtual const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const = 0; @@ -757,11 +841,14 @@ struct ProviderHost { virtual const std::vector& GraphViewer__GetInputsIncludingInitializers(const GraphViewer* p) noexcept = 0; virtual void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept = 0; + virtual const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const = 0; // Path virtual PathString Path__ToPathString(const Path* p) noexcept = 0; virtual const std::vector& Path__GetComponents(const Path* p) noexcept = 0; virtual bool Path__IsEmpty(const Path* p) noexcept = 0; + virtual std::unique_ptr Path__construct() = 0; + virtual void Path__operator_delete(ONNX_NAMESPACE::Path* p) = 0; // OpKernel virtual const Node& OpKernel__Node(const OpKernel* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index f46c76fd3421b..dde4005c80b9d 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -52,11 +52,34 @@ namespace ONNX_NAMESPACE { struct int64s final { int size() const { return g_host->int64s__size(this); } const int64_t& Get(int index) const { return g_host->int64s__Get(this, index); } + const int64_t* data() const { return g_host->int64s__data(this); } const int64_t& operator[](int index) const { return Get(index); } - + void Reserve(int size) { g_host->int64s__Reserve(this, size); } PROVIDER_DISALLOW_ALL(int64s) }; +struct float32s final { + void Reserve(int size) { g_host->float32s__Reserve(this, size); } + const float* data() const { return g_host->float32s__data(this); } + int size() const { return g_host->float32s__size(this); } + PROVIDER_DISALLOW_ALL(float32s) +}; + +struct StringStringEntryProto final { + std::string* mutable_key() { return g_host->StringStringEntryProto__mutable_key(this); } + std::string* mutable_value() { return g_host->StringStringEntryProto__mutable_value(this); } + + PROVIDER_DISALLOW_ALL(StringStringEntryProto) +}; + +struct StringStringEntryProtos final { + void Clear() { g_host->StringStringEntryProtos__Clear(this); } + StringStringEntryProto* Add() { return g_host->StringStringEntryProtos__Add(this); } + int size() { return g_host->StringStringEntryProtos__size(this); } + StringStringEntryProto& at(int index) { return g_host->StringStringEntryProtos__at(this, index); } + + PROVIDER_DISALLOW_ALL(StringStringEntryProtos) +}; struct AttributeProto final { static std::unique_ptr Create() { return g_host->AttributeProto__construct(); } void operator=(const AttributeProto& v) { g_host->AttributeProto__operator_assign(this, v); } @@ -71,9 +94,18 @@ struct AttributeProto final { float floats(int i) const { return g_host->AttributeProto__floats(this, i); } const std::string& strings(int i) const { return g_host->AttributeProto__strings(this, i); } const int64s& ints() const { return g_host->AttributeProto__ints(this); } + const float32s& floats() const { 
return g_host->AttributeProto__floats(this); } + int64s* mutable_ints() { return g_host->AttributeProto__mutable_ints(this); } + float32s* mutable_floats() { return g_host->AttributeProto__mutable_floats(this); } + void add_ints(int64_t value) { g_host->AttributeProto__add_ints(this, value); } + void add_floats(float value) { g_host->AttributeProto__add_floats(this, value); } + void add_strings(const ::std::string& value) { g_host->AttributeProto__add_strings(this, value); } + int64_t i() const { return g_host->AttributeProto__i(this); } float f() const { return g_host->AttributeProto__f(this); } + const ONNX_NAMESPACE::TensorProto& t() const { return g_host->AttributeProto__t(this); } void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); } + void set_f(const float& value) { return g_host->AttributeProto__set_f(this, value); } void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); } const ::std::string& s() const { return g_host->AttributeProto__s(this); } void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); } @@ -121,6 +153,8 @@ struct GraphProto final { NodeProto* add_node() { return g_host->GraphProto__add_node(this); } NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); } + std::string* mutable_name() { return g_host->GraphProto__mutable_name(this); } + GraphProto() = delete; GraphProto(const GraphProto&) = delete; }; @@ -133,7 +167,7 @@ struct ModelProto final { bool SerializeToOstream(std::ostream& output) const { return g_host->ModelProto__SerializeToOstream(this, output); } bool ParseFromString(const std::string& data) { return g_host->ModelProto__ParseFromString(this, data); } std::string SerializeAsString() const { return g_host->ModelProto__SerializeAsString(this); } - + StringStringEntryProtos* mutable_metadata_props() { return g_host->ModelProto__mutable_metadata_props(this); }; const GraphProto& graph() const { return g_host->ModelProto__graph(this); } GraphProto* mutable_graph() { return g_host->ModelProto__mutable_graph(this); } @@ -162,17 +196,22 @@ struct TensorProto final { void operator=(const TensorProto& v) { g_host->TensorProto__operator_assign(this, v); } bool has_name() const { return g_host->TensorProto__has_name(this); } + void set_name(const ::std::string& name) { return g_host->TensorProto__set_name(this, name); } + const ::std::string& name() const { return g_host->TensorProto__name(this); } int dims_size() const { return g_host->TensorProto__dims_size(this); } const int64s& dims() const { return g_host->TensorProto__dims(this); } + void add_dims(int64_t value) { g_host->TensorProto__add_dims(this, value); } bool has_data_location() const { return g_host->TensorProto__has_data_location(this); } TensorProto_DataLocation data_location() const { return TensorProto_DataLocation(g_host->TensorProto__data_location(this)); } bool has_raw_data() const { return g_host->TensorProto__has_raw_data(this); } const std::string& raw_data() const { return g_host->TensorProto__raw_data(this); } + std::string* mutable_raw_data() { return g_host->TensorProto__mutable_raw_data(this); } int32_t data_type() const { return g_host->TensorProto__data_type(this); } + void set_data_type(int32_t type) { return g_host->TensorProto__set_data_type(this, type); } typedef TensorProto_DataType DataType; static constexpr DataType UNDEFINED = TensorProto_DataType_UNDEFINED; @@ -180,6 +219,13 @@ struct TensorProto final { static bool DataType_IsValid(int 
value) { return g_host->TensorProto_DataType_IsValid(value); } void copy_from(const TensorProto* other) { return g_host->TensorProto__CopyFrom(this, other); } + StringStringEntryProtos* mutable_external_data() { return g_host->TensorProto__mutable_external_data(this); }; + void clear_float_data() { return g_host->TensorProto__clear_float_data(this); } + void clear_int32_data() { return g_host->TensorProto__clear_int32_data(this); } + void clear_string_data() { return g_host->TensorProto__clear_string_data(this); } + void clear_int64_data() { return g_host->TensorProto__clear_int64_data(this); } + void clear_double_data() { return g_host->TensorProto__clear_double_data(this); } + void clear_uint64_data() { return g_host->TensorProto__clear_uint64_data(this); } TensorProto() = delete; TensorProto(const TensorProto&) = delete; @@ -187,6 +233,8 @@ struct TensorProto final { struct TensorProtos final { TensorProto* Add() { return g_host->TensorProtos__Add(this); } + int size() { return g_host->TensorProtos__size(this); } + TensorProto& at(int index) { return g_host->TensorProtos__at(this, index); } PROVIDER_DISALLOW_ALL(TensorProtos) }; @@ -205,6 +253,8 @@ struct TensorShapeProto_Dimension final { bool has_dim_value() const { return g_host->TensorShapeProto_Dimension__has_dim_value(this); } bool has_dim_param() const { return g_host->TensorShapeProto_Dimension__has_dim_param(this); } void clear_dim_value() { return g_host->TensorShapeProto_Dimension__clear_dim_value(this); } + const std::string& denotation() const { return g_host->TensorShapeProto_Dimension__denotation(this); } + void set_denotation(const std::string& value) { g_host->TensorShapeProto_Dimension__set_denotation(this, value); } PROVIDER_DISALLOW_ALL(TensorShapeProto_Dimension) }; @@ -232,6 +282,7 @@ struct TypeProto_Tensor final { const TensorShapeProto& shape() const { return g_host->TypeProto_Tensor__shape(this); } TensorShapeProto* mutable_shape() { return g_host->TypeProto_Tensor__mutable_shape(this); } int32_t elem_type() const { return g_host->TypeProto_Tensor__elem_type(this); } + void set_elem_type(int32_t value) { g_host->TypeProto_Tensor__set_elem_type(this, value); } PROVIDER_DISALLOW_ALL(TypeProto_Tensor) }; @@ -315,7 +366,6 @@ struct ValueInfoProtos final { PROVIDER_DISALLOW_ALL(ValueInfoProtos) }; - } // namespace ONNX_NAMESPACE namespace onnxruntime { @@ -603,6 +653,10 @@ struct Function final { }; struct Node final { + enum class Type { + Primitive = 0, + Fused = 1, + }; const std::string& Name() const noexcept { return g_host->Node__Name(this); } const std::string& Description() const noexcept { return g_host->Node__Description(this); } const std::string& Domain() const noexcept { return g_host->Node__Domain(this); } @@ -626,6 +680,10 @@ struct Node final { void ToProto(ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) const { return g_host->Node__ToProto(this, proto, update_subgraphs); } const NodeAttributes& GetAttributes() const noexcept { return g_host->Node__GetAttributes(this); } + void AddAttribute(const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) { + g_host->Node__AddAttribute(this, attr_name, value); + } + size_t GetInputEdgesCount() const noexcept { return g_host->Node__GetInputEdgesCount(this); } size_t GetOutputEdgesCount() const noexcept { return g_host->Node__GetOutputEdgesCount(this); } @@ -661,12 +719,15 @@ struct Node final { std::unique_ptr impl_; }; + EdgeConstIterator InputEdgesBegin() const noexcept { return g_host->Node__InputEdgesBegin(this); } + 
EdgeConstIterator InputEdgesEnd() const noexcept { return g_host->Node__InputEdgesEnd(this); } EdgeConstIterator OutputEdgesBegin() const noexcept { return g_host->Node__OutputEdgesBegin(this); } EdgeConstIterator OutputEdgesEnd() const noexcept { return g_host->Node__OutputEdgesEnd(this); } void ForEachDef(std::function func, bool include_missing_optional_defs = false) const { g_host->Node__ForEachDef(this, func, std::move(include_missing_optional_defs)); } const std::unordered_map>& GetAttributeNameToMutableSubgraphMap() { return g_host->Node__GetAttributeNameToMutableSubgraphMap(this); } std::unordered_map> GetAttributeNameToSubgraphMap() const { return g_host->Node__GetAttributeNameToSubgraphMap(this); } + Type NodeType() const noexcept { return Type(g_host->Node__NodeType(this)); } PROVIDER_DISALLOW_ALL(Node) }; @@ -678,6 +739,7 @@ struct NodeArg final { const NodeArgInfo& ToProto() const noexcept { return g_host->NodeArg__ToProto(this); } bool Exists() const noexcept { return g_host->NodeArg__Exists(this); } const ONNX_NAMESPACE::TypeProto* TypeAsProto() const noexcept { return g_host->NodeArg__TypeAsProto(this); } + Status OverrideTypesHelper(const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) { return g_host->NodeArg__OverrideTypesHelper(this, input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); } PROVIDER_DISALLOW_ALL(NodeArg) }; @@ -698,6 +760,8 @@ struct NodeAttributes final { IteratorHolder> find(const std::string& key) const { return g_host->NodeAttributes__find(this, key); } void insert(const NodeAttributes& v) { return g_host->NodeAttributes__insert(this, v); } void emplace(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__emplace(this, k, v); } + void insert_or_assign(const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) { g_host->NodeAttributes__insert_or_assign(this, k, v); } + void reserve(size_t size) { g_host->NodeAttributes__reserve(this, size); } NodeAttributes() = delete; @@ -705,11 +769,18 @@ struct NodeAttributes final { }; struct Model final { + static std::unique_ptr Create(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, + const logging::Logger& logger) { + return g_host->Model__construct(std::move(model_proto), model_path, logger); + } static void operator delete(void* p) { g_host->Model__operator_delete(reinterpret_cast(p)); } + static Status Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) { return g_host->Model__Load(file_path, model_proto); } Graph& MainGraph() { return g_host->Model__MainGraph(this); } std::unique_ptr ToProto() { return g_host->Model__ToProto(this); } + std::unique_ptr ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) { return g_host->Model__ToGraphProtoWithExternalInitializers(this, external_file_name, file_path, initializer_size_threshold); } + const ModelMetaData& MetaData() const noexcept { return g_host->Model__MetaData(this); } Model() = delete; Model(const Model&) = delete; @@ -732,6 +803,7 @@ struct Graph final { void SetOutputs(gsl::span outputs) { return g_host->Graph__SetOutputs(this, outputs); } const std::vector& GetInputs() const noexcept { return g_host->Graph__GetInputs(this); } + std::vector Nodes() const noexcept { return g_host->Graph__Nodes(this); } bool GetInitializedTensor(const std::string& tensor_name, const 
ONNX_NAMESPACE::TensorProto*& value) const { return g_host->Graph__GetInitializedTensor(this, tensor_name, value); } @@ -742,6 +814,37 @@ struct Graph final { const Path& ModelPath() const { return g_host->Graph__ModelPath(this); } const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->Graph__GetInputsIncludingInitializers(this); } bool IsSubgraph() const { return g_host->Graph__IsSubgraph(this); } + const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->Graph__GetProducerNode(this, node_arg_name); } + const Model& GetModel() const { return g_host->Graph__GetModel(this); } + void ReverseDFSFrom(gsl::span from, const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const { + g_host->Graph__ReverseDFSFrom(this, from, enter, leave, comp, stop); + } + Graph& SetGraphResolveNeeded() { return g_host->Graph__SetGraphResolveNeeded(this); } + void RemoveInitializedTensor(const std::string& tensor_name) { g_host->Graph__RemoveInitializedTensor(this, tensor_name); } + + std::vector GetConsumerNodes(const std::string& node_arg_name) const { + return g_host->Graph__GetConsumerNodes(this, node_arg_name); + } + void AddEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { + g_host->Graph__AddEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void RemoveEdge(NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, int dst_arg_index) { + g_host->Graph__RemoveEdge(this, src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void RemoveNode(NodeIndex index) { g_host->Graph__RemoveNode(this, index); } + Node& FuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name) { + return g_host->Graph__FuseSubGraph(this, sub_graph, fused_node_name); + } + void UpdateProducerNode(const std::string& node_arg_name, NodeIndex node_index) { + g_host->Graph__UpdateProducerNode(this, node_arg_name, node_index); + } + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const { + return g_host->Graph__GetConstantInitializer(this, name, check_outer_scope); + } + const InitializedTensorSet& GetAllInitializedTensors() const noexcept { return g_host->Graph__GetAllInitializedTensors(this); } int MaxNodeIndex() const noexcept { return g_host->Graph__MaxNodeIndex(this); } const Node* GetNode(NodeIndex node_index) const noexcept { return g_host->Graph__GetNode(this, node_index); } Node* GetNode(NodeIndex node_index) noexcept { return g_host->Graph__GetNode(this, node_index); } @@ -783,6 +886,7 @@ class GraphViewer final { const std::vector& GetInputsIncludingInitializers() const noexcept { return g_host->GraphViewer__GetInputsIncludingInitializers(this); } void ToProto(ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) const { g_host->GraphViewer__ToProto(this, graph_proto, include_initializers, include_outer_scope_args); } + const Node* GetProducerNode(const std::string& node_arg_name) const { return g_host->GraphViewer__GetProducerNode(this, node_arg_name); } GraphViewer() = delete; GraphViewer(const GraphViewer&) = delete; @@ -790,11 +894,16 @@ class GraphViewer final { }; struct Path final { + static std::unique_ptr Create() { return g_host->Path__construct(); } + static void operator delete(void* p) { g_host->Path__operator_delete(reinterpret_cast(p)); } + PathString ToPathString() const 
noexcept { return g_host->Path__ToPathString(this); } const std::vector& GetComponents() const noexcept { return g_host->Path__GetComponents(this); } bool IsEmpty() const noexcept { return g_host->Path__IsEmpty(this); } - PROVIDER_DISALLOW_ALL(Path) + Path() = delete; + Path(const Path&) = delete; + void operator=(const Path&) = delete; }; struct OpKernelContext final { diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc index 29bc886fb5ed4..1392ecef1b72d 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.cc @@ -2,126 +2,106 @@ // Licensed under the MIT License. #include "./attr_proto.h" -#include "./vai_assert.h" - #include #include #include #include -namespace vaip { +#include "core/providers/shared_library/provider_api.h" -ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, - int64_t value) { - auto ret = new onnx::AttributeProto(); +#include "./vai_assert.h" + +namespace vaip { +ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_INT); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); ret->set_i(value); - return ret; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, - float value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_FLOAT); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT); ret->set_f(value); - return ret; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_string( - const std::string& name, const std::string& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_STRING); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRING); ret->set_s(value); - return ret; + return ret.release(); } ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( const std::string& name, const ONNX_NAMESPACE::TensorProto& value) { - auto ret = new onnx::AttributeProto(); + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_TENSOR); - *ret->mutable_t() = value; - return ret; + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR); + *ret->add_tensors() = value; + return ret.release(); } -ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( - const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_INTS); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INTS); ret->mutable_ints()->Reserve((int)value.size()); for (auto v : value) { ret->add_ints(v); } - return ret; + return ret.release(); } - ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( const 
std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_FLOATS); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS); ret->mutable_floats()->Reserve((int)value.size()); for (auto v : value) { ret->add_floats(v); } - return ret; + return ret.release(); } - -ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( - const std::string& name, const std::vector& value) { - auto ret = new onnx::AttributeProto(); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value) { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); ret->set_name(name); - ret->set_type(onnx::AttributeProto_AttributeType_STRINGS); - ret->mutable_strings()->Reserve((int)value.size()); + ret->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS); for (auto& v : value) { ret->add_strings(v); } - return ret; + return ret.release(); } - -int64_t attr_proto_get_int(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INT, attr.DebugString()); +int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT, attr.name()); return attr.i(); } - -float attr_proto_get_float(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOAT, attr.DebugString()); +float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT, attr.name()); return attr.f(); } - -const std::string& attr_proto_get_string(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRING, attr.DebugString()); +const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRING, attr.name()); return attr.s(); } - -const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( - const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_TENSOR, attr.DebugString()); +const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_TENSOR, attr.name()); return attr.t(); } - -gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_INTS, attr.DebugString()); +gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INTS, attr.name()); return gsl::span(attr.ints()); } - -gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_FLOATS, attr.DebugString()); +gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_FLOATS, attr.name()); return gsl::span(attr.floats()); } - -std::vector attr_proto_get_strings( - const ONNX_NAMESPACE::AttributeProto& attr) { - vai_assert(attr.type() == onnx::AttributeProto_AttributeType_STRINGS, attr.DebugString()); - return std::vector(attr.strings().begin(), attr.strings().end()); -} - -ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const 
std::string& name, - int64_t value) { - ONNX_NAMESPACE::AttributeProto ret; - ret.set_name(name); - ret.set_i(value); +std::vector attr_proto_get_strings(const ONNX_NAMESPACE::AttributeProto& attr) { + vai_assert(attr.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_STRINGS, attr.name()); + std::vector ret; + ret.reserve(attr.strings_size()); + for (int i = 0; i < attr.strings_size(); i++) { + ret.push_back(attr.strings(i)); + } return ret; } - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/attr_proto.h b/onnxruntime/core/providers/vitisai/imp/attr_proto.h index 32ba8fa672d74..f4d56dd618a8c 100644 --- a/onnxruntime/core/providers/vitisai/imp/attr_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/attr_proto.h @@ -2,46 +2,26 @@ // Licensed under the MIT License. #pragma once #include - +#include "vaip/my_ort.h" #include "core/common/gsl.h" -#include "onnx/onnx_pb.h" namespace vaip { -ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, - int64_t value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, - float value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, - const std::string& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor( - const std::string& name, const ONNX_NAMESPACE::TensorProto& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints( - const std::string& name, const std::vector& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats( - const std::string& name, const std::vector& value); -ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings( - const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_int(const std::string& name, int64_t value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_float(const std::string& name, float value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_string(const std::string& name, const std::string& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_tensor(const std::string& name, const ONNX_NAMESPACE::TensorProto& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_ints(const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_floats(const std::string& name, const std::vector& value); +ONNX_NAMESPACE::AttributeProto* attr_proto_new_strings(const std::string& name, const std::vector& value); /// attr_proto getters int64_t attr_proto_get_int(const ONNX_NAMESPACE::AttributeProto& attr); float attr_proto_get_float(const ONNX_NAMESPACE::AttributeProto& attr); -const std::string& attr_proto_get_string( - const ONNX_NAMESPACE::AttributeProto& attr); - -const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor( - const onnx::AttributeProto& attr); -gsl::span attr_proto_get_ints(const onnx::AttributeProto& attr); -gsl::span attr_proto_get_floats(const onnx::AttributeProto& attr); -std::vector attr_proto_get_strings( - const ONNX_NAMESPACE::AttributeProto& attr); - -/// attr_proto makers -ONNX_NAMESPACE::AttributeProto attr_proto_from_i64(const std::string& name, - int64_t); - -/// -using attr_proto_func_t = std::function; +const std::string& attr_proto_get_string(const ONNX_NAMESPACE::AttributeProto& attr); +const ONNX_NAMESPACE::TensorProto& attr_proto_get_tensor(const ONNX_NAMESPACE::AttributeProto& attr); +gsl::span attr_proto_get_ints(const ONNX_NAMESPACE::AttributeProto& attr); +gsl::span attr_proto_get_floats(const ONNX_NAMESPACE::AttributeProto& attr); +std::vector attr_proto_get_strings(const 
ONNX_NAMESPACE::AttributeProto& attr); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/capability.cc b/onnxruntime/core/providers/vitisai/imp/capability.cc index a55180bd2ee5e..58522a45a151e 100644 --- a/onnxruntime/core/providers/vitisai/imp/capability.cc +++ b/onnxruntime/core/providers/vitisai/imp/capability.cc @@ -3,15 +3,10 @@ #include "vaip/capability.h" #include "./vai_assert.h" -#include "core/graph/basic_types.h" - -#include "./attr_proto.h" - namespace vaip { using namespace ::onnxruntime; -static std::vector node_names_to_nodes(const GraphViewer& graph, - const std::vector& node_names) { +static std::vector node_names_to_nodes(const GraphViewer& graph, const std::vector& node_names) { auto ret = std::vector(); ret.reserve(node_names.size()); for (auto& onnx_node_name : node_names) { @@ -24,53 +19,45 @@ static std::vector node_names_to_nodes(const GraphViewer& graph, } std::unique_ptr XirSubgraphToComputeCapability1(const onnxruntime::GraphViewer& graph, vaip_core::ExecutionProvider* ep, size_t index) { - auto meta_def = std::make_unique(); - meta_def->constant_initializers = *ep->get_meta_def_constant_initializer(); - meta_def->inputs = *ep->get_meta_def_inputs(); - meta_def->outputs = *ep->get_meta_def_outputs(); - auto indexed_subgraph = std::make_unique(); - auto indexed_subgraph_ptr = indexed_subgraph.get(); - indexed_subgraph_ptr->nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->constant_initializers() = *ep->get_meta_def_constant_initializer(); + meta_def->inputs() = *ep->get_meta_def_inputs(); + meta_def->outputs() = *ep->get_meta_def_outputs(); + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); static auto g_counter = 1; - meta_def->name = std::string("vitis_ai_ep_") + std::to_string(g_counter++); - meta_def->domain = "com.xilinx"; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - auto index_proto = std::unique_ptr(vaip::attr_proto_new_int("index", (int64_t)index)); - meta_def->attributes["index"] = *index_proto; + meta_def->name() = std::string("vitis_ai_ep_") + std::to_string(g_counter++); + meta_def->domain() = "com.xilinx"; + meta_def->since_version() = 1; + meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; + auto index_proto = ONNX_NAMESPACE::AttributeProto::Create(); + index_proto->set_name("index"); + index_proto->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType_INT); + index_proto->set_i(index); + meta_def->attributes()["index"] = *index_proto; indexed_subgraph->SetMetaDef(std::move(meta_def)); - return std::make_unique(std::move(indexed_subgraph)); + return ComputeCapability::Create(std::move(indexed_subgraph)); } std::vector> GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph, vaip_core::DllSafe>>* eps, - const std::set& all_not_support_optypes) { - std::set all_compute_capability_nodes; + const std::set& all_support_optypes_by_eps) { + std::set all_nodes_included_eps; for (auto& ep : **eps) { - auto nodes = *ep->get_meta_def_nodes(); - for (auto n : nodes) - all_compute_capability_nodes.insert(n); + auto nodes = node_names_to_nodes(graph, *ep->get_meta_def_nodes()); + all_nodes_included_eps.insert(nodes.begin(), nodes.end()); } + + std::vector node_indexs = graph.GetNodesInTopologicalOrder(); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return 
all_nodes_included_eps.count(index) > 0; }), node_indexs.end()); + node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end()); + std::vector> result; - for (auto& n : graph.Nodes()) { - if ((!all_compute_capability_nodes.count(n.Name())) && all_not_support_optypes.count(n.OpType())) { - auto meta_def = std::make_unique(); - meta_def->name = n.OpType(); - meta_def->domain = n.Domain(); - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - auto indexed_subgraph = std::make_unique(); - indexed_subgraph->nodes.push_back(n.Index()); - for (auto i : n.InputDefs()) { - meta_def->inputs.push_back(i->Name()); - } - for (auto i : n.OutputDefs()) { - meta_def->outputs.push_back(i->Name()); - } - indexed_subgraph->SetMetaDef(std::move(meta_def)); - result.emplace_back(std::make_unique(std::move(indexed_subgraph))); - } + for (auto& n : node_indexs) { + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = {n}; + result.emplace_back(ComputeCapability::Create(std::move(indexed_subgraph))); } return result; } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index b629c8eff9097..f609d40f459b7 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -1,20 +1,18 @@ - // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. + #include "vaip/global_api.h" #include +#include +#include #include #include "./vai_assert.h" -#include "core/common/exceptions.h" -#include "core/common/logging/logging.h" +#include "core/common/exceptions.h" #include "core/framework/error_code_helper.h" - -#include "core/graph/model.h" -#include "core/session/ort_env.h" -#include "core/session/onnxruntime_cxx_api.h" +#include "core/providers/shared/common.h" #include @@ -55,16 +53,14 @@ struct OrtVitisAIEpAPI { std::vector>* (*compile_onnx_model_with_options)( const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); void Ensure() { - if (handle_) return; - auto full_path = Env::Default().GetRuntimePath() + - PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); - ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); - ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( - handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); - auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", - reinterpret_cast(&compile_onnx_model_with_options)); - auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", - reinterpret_cast(&compile_onnx_model_3)); + if (handle_) + return; + auto& env = Provider_GetHost()->Env__Default(); + auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); + auto status1 = env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); + auto status2 = 
env.GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", (void**)&compile_onnx_model_3); if (!status1.IsOK() && !status2.IsOK()) { ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); ORT_THROW(status1); @@ -76,6 +72,12 @@ struct OrtVitisAIEpAPI { }; static OrtVitisAIEpAPI s_library_vitisaiep; +static std::shared_ptr s_kernel_registry_vitisaiep; +static std::vector s_domains_vitisaiep; +static vaip_core::OrtApiForVaip the_global_api; +std::shared_ptr get_kernel_registry_vitisaiep() { return s_kernel_registry_vitisaiep; } +const std::vector& get_domains_vitisaiep() { return s_domains_vitisaiep; } + static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { auto iter = config.find("config_file"); if (iter == config.end()) { @@ -105,121 +107,142 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config return ""; } } -vaip_core::DllSafe>> compile_onnx_model_with_options( - const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { + +vaip_core::DllSafe>> compile_onnx_model( + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { +#ifndef _WIN32 + auto model_path = graph_viewer.ModelPath().ToPathString(); +#else + using convert_t = std::codecvt_utf8; + std::wstring_convert strconverter; + auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); +#endif if (s_library_vitisaiep.compile_onnx_model_with_options) { - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options)); } else { auto json_str = config_to_json_str(options); - return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph_viewer.GetGraph(), json_str.c_str())); } } -std::vector initialize_vitisai_ep() { - s_library_vitisaiep.Ensure(); - Status status = Status::OK(); - try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, - "onnxruntime-vitisai-ep"}; - std::ignore = OrtEnv::GetInstance(lm_info, status); - } catch (onnxruntime::OnnxRuntimeException& /*e*/) { +struct MyCustomOpKernel : OpKernel { + MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { + op_kernel_ = + op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); } - auto domains = std::vector(); - domains.reserve(100); - s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { - vaip::register_xir_ops(domains); + + ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } + + Status Compute(OpKernelContext* ctx) const override { + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); + return Status::OK(); } - return domains; + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); + + const OrtCustomOp& op_; + void* op_kernel_; +}; + +void create_kernel_registry(std::vector domains) { + s_kernel_registry_vitisaiep = KernelRegistry::Create(); + for (const auto& domain : domains) { + for 
(const auto* op : domain->custom_ops_) { + auto def_builder = KernelDefBuilder::Create(); + def_builder->SetName(op->GetName(op)); + def_builder->SetDomain(domain->domain_.c_str()); + def_builder->SinceVersion(1); + if (op->version > 12) { + auto input_count = op->GetInputTypeCount(op); + for (auto i = 0u; i < input_count; i++) { + def_builder->InputMemoryType(op->GetInputMemoryType(op, i), i); + } + } + def_builder->Provider(onnxruntime::kVitisAIExecutionProvider); + KernelCreateFn kernel_create_fn = + [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + // out = std::make_unique(info, *op); + return Status::OK(); + }; + std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn)); + } + } +} +void initialize_vitisai_ep() { + s_library_vitisaiep.Ensure(); + s_domains_vitisaiep.reserve(100); + s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), s_domains_vitisaiep); + vaip::register_xir_ops(s_domains_vitisaiep); + create_kernel_registry(s_domains_vitisaiep); } -static vaip_core::OrtApiForVaip the_global_api; vaip_core::OrtApiForVaip* create_org_api_hook() { + InitProviderOrtApi(); + the_global_api.host_ = Provider_GetHost(); assert(Ort::Global::api_ != nullptr); the_global_api.ort_api_ = Ort::Global::api_; the_global_api.model_load = [](const std::string& filename) -> Model* { - ONNX_NAMESPACE::ModelProto model_proto; + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); auto& logger = logging::LoggingManager::DefaultLogger(); auto file_path = ToPathString(filename); - auto status = Model::Load(file_path, model_proto); + auto status = Model::Load(file_path, *model_proto); vai_assert(status.IsOK(), "load model proto error"); - auto model = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto model = Model::Create(std::move(*model_proto), file_path, logger); return model.release(); }; the_global_api.model_delete = [](Model* model) { delete model; }; - the_global_api.model_clone = [](const Model& model) -> Model* { + + the_global_api.model_clone = [](const Model& const_model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = const_cast(model).ToProto(); - auto file_path = model.ModelPath().ToPathString(); - auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); + auto& model = const_cast(const_model); + auto model_proto = model.ToProto(); + auto file_path = model.MainGraph().ModelPath().ToPathString(); + auto ret = Model::Create(std::move(*model_proto), file_path, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) { const_cast(model.MetaData())[key] = value; }; - the_global_api.model_get_meta_data = [](const Model& model, - const std::string& key) -> vaip_core::DllSafe { - auto& m = model.MetaData(); - auto it = m.find(key); - auto ret = std::string(); - if (it != m.end()) { - ret = it->second; + the_global_api.model_get_meta_data = + [](const Model& model, const std::string& key) -> vaip_core::DllSafe { + if (model.MetaData().count(key)) { + return vaip_core::DllSafe(model.MetaData().at(key)); } - return vaip_core::DllSafe(ret); + return vaip_core::DllSafe(std::string()); }; - 
the_global_api.model_has_meta_data = [](const Model& model, const std::string& key) -> int { - auto& m = model.MetaData(); - return m.find(key) != m.end() ? 1 : 0; + return int(model.MetaData().count(key)); }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; - the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { - auto ret = std::vector(); - auto inputs = graph.GetInputs(); - for (auto input : inputs) { - vai_assert(input->Exists(), input->Name()); - ret.push_back(input); - } - return vaip_core::DllSafe(std::move(ret)); + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> auto { + return vaip_core::DllSafe(graph.GetInputs()); }; - the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.GetOutputs()); }; - - the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { - return graph.SetOutputs(outputs); + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) { + graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - - the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; - + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { + return graph.GetNode(index); + }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, const std::vector& input_args, - const std::vector& output_args, - vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { - return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), domain); - }; - + the_global_api.graph_add_node = vaip::graph_add_node; the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; - the_global_api.graph_resolve = [](Graph& graph, bool force) { if (force) { graph.SetGraphResolveNeeded(); @@ -227,129 +250,57 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { auto status = graph.Resolve(); return status.Code(); }; - - the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { + the_global_api.graph_get_consumer_nodes_unsafe = [](const Graph& graph, const std::string& node_arg_name) -> auto { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { - auto& node_refererence = graph.Nodes(); - std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); - return vaip_core::DllSafe(std::move(nodes)); - }; + 
the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> auto { return vaip_core::DllSafe(graph.Nodes()); }; the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { + const auto& enter, const auto& leave, const auto& stop) { graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; - the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return node.Index(); }; the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - - the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast(node.GetMutableAttributes()); - }; - - the_global_api.node_type_is_fused = [](const Node& node) { - return node.NodeType() == onnxruntime::Node::Type::Fused; + the_global_api.node_get_attributes = [](Node& node) -> NodeAttributes& { + return const_cast(node.GetAttributes()); }; - the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == Node::Type::Fused; }; + the_global_api.node_get_function_body = [](const Node& node) -> const auto& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { - return node_arg.Name(); - }; + the_global_api.node_arg_get_name_unsafe = + [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; the_global_api.node_arg_new = vaip::node_arg_new; - the_global_api.node_arg_is_exists = vaip::node_arg_is_exists; + the_global_api.node_arg_is_exists = [](const NodeArg& node_arg) { return node_arg.Exists(); }; the_global_api.node_arg_is_constant = vaip::node_arg_is_constant; the_global_api.node_arg_get_shape_i64_unsafe = vaip::node_arg_get_shape_i64; the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; + the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; - the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { - auto data_type = ONNX_NAMESPACE::TensorProto::UNDEFINED; - switch (type) { - case 1: - data_type = ONNX_NAMESPACE::TensorProto::FLOAT; - break; - case 2: - data_type = ONNX_NAMESPACE::TensorProto::UINT8; - break; - case 3: - data_type = ONNX_NAMESPACE::TensorProto::INT8; - break; - - case 4: - data_type = ONNX_NAMESPACE::TensorProto::UINT16; - break; - case 5: - data_type = 
ONNX_NAMESPACE::TensorProto::INT16; - break; - case 6: - data_type = ONNX_NAMESPACE::TensorProto::INT32; - break; - case 7: - data_type = ONNX_NAMESPACE::TensorProto::INT64; - break; - case 8: - data_type = ONNX_NAMESPACE::TensorProto::STRING; - break; - case 9: - data_type = ONNX_NAMESPACE::TensorProto::BOOL; - break; - case 10: - data_type = ONNX_NAMESPACE::TensorProto::FLOAT16; - break; - case 11: - data_type = ONNX_NAMESPACE::TensorProto::DOUBLE; - break; - case 12: - data_type = ONNX_NAMESPACE::TensorProto::UINT32; - break; - case 13: - data_type = ONNX_NAMESPACE::TensorProto::UINT64; - break; - case 14: - data_type = ONNX_NAMESPACE::TensorProto::COMPLEX64; - break; - case 15: - data_type = ONNX_NAMESPACE::TensorProto::COMPLEX128; - break; - case 16: - data_type = ONNX_NAMESPACE::TensorProto::BFLOAT16; - break; - default: - vai_assert(false, "TensorProto::DataType not supoort"); - } - return vaip::node_arg_set_element_type(node_arg, data_type); - }; + the_global_api.node_arg_set_element_type = vaip::node_arg_set_element_type; /// attr proto - the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { - return new onnx::AttributeProto(v); - }; - the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { - return attr_proto.name(); - }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { - attr_proto->set_name(name); + the_global_api.attr_proto_delete = [](ONNX_NAMESPACE::AttributeProto* v) { delete v; }; + the_global_api.attr_proto_clone = [](const ONNX_NAMESPACE::AttributeProto& v) -> ONNX_NAMESPACE::AttributeProto* { + auto ret = ONNX_NAMESPACE::AttributeProto::Create(); + *ret = v; + return ret.release(); }; + the_global_api.attr_proto_get_name = [](const auto& attr_proto) -> const std::string& { return attr_proto.name(); }; + the_global_api.attr_proto_set_name = [](auto* attr_proto, const auto& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; the_global_api.attr_proto_new_float = vaip::attr_proto_new_float; the_global_api.attr_proto_new_string = vaip::attr_proto_new_string; @@ -364,31 +315,24 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const ONNX_NAMESPACE::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes - the_global_api.node_attributes_new = []() { - return reinterpret_cast(new NodeAttributes()); - }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); + the_global_api.node_attributes_new = []() { return NodeAttributes::Create().release(); }; + the_global_api.node_attributes_add = [](NodeAttributes& p, ONNX_NAMESPACE::AttributeProto&& attr) { + p.insert_or_assign(attr.name(), std::move(attr)); }; - the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { - delete reinterpret_cast(p); - }; - the_global_api.node_attributes_get = [](vaip_core::NodeAttributes& p, - const std::string& name) -> 
ONNX_NAMESPACE::AttributeProto* { - auto& attr = reinterpret_cast(p); - auto it = attr.find(name); - if (it == attr.end()) { - return nullptr; + + the_global_api.node_attributes_delete = [](NodeAttributes* p) { delete p; }; + the_global_api.node_attributes_get = + [](const NodeAttributes& attr, const std::string& name) -> const ONNX_NAMESPACE::AttributeProto* { + if (attr.count(name)) { + return &attr.at(name); } - return &it->second; + return nullptr; }; - the_global_api.node_attributes_get_keys = - [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = [](NodeAttributes& attr) -> vaip_core::DllSafe> { auto ret = std::vector(); - auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); for (auto& it : attr) { ret.push_back(it.first); @@ -396,35 +340,16 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = - [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { - return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); - }; - - the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; - - the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - - the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; - }; - the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; - }; - the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; - + the_global_api.tensor_proto_get_shape_unsafe = vaip::tensor_proto_get_shape; + the_global_api.tensor_proto_data_type = [](const ONNX_NAMESPACE::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_delete = [](ONNX_NAMESPACE::TensorProto* tp) { delete tp; }; + the_global_api.tensor_proto_new_floats = vaip::tensor_proto_new_floats; + the_global_api.tensor_proto_new_i32 = vaip::tensor_proto_new_i32; + the_global_api.tensor_proto_new_i64 = vaip::tensor_proto_new_i64; + the_global_api.tensor_proto_new_i8 = vaip::tensor_proto_new_i8; + the_global_api.tensor_proto_raw_data_size = [](const auto& tensor) { return tensor.raw_data().size(); }; the_global_api.tensor_proto_as_raw = vaip::tensor_proto_as_raw; - the_global_api.tensor_proto_get_name = vaip::tensor_proto_get_name; + the_global_api.tensor_proto_get_name = [](const auto& tensor) -> const std::string& { return tensor.name(); }; the_global_api.get_lib_name = []() -> vaip_core::DllSafe { return vaip_core::DllSafe(std::string("onnxruntime.") + std::string(ORT_VERSION)); diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index cca680baf7dc0..061bc414fcec7 100644 --- 
a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -2,27 +2,15 @@ // Licensed under the MIT License. #include "vaip/graph.h" -#include - -#include "./vai_assert.h" #include #include #include #include #include #include -#include "onnx/onnx-ml.pb.h" -#ifdef _MSC_VER -#pragma warning(push) -// 'type' : forcing value to bool 'true' or 'false' (performance warning) -#pragma warning(disable : 4800) -#endif -#include -#ifdef _MSC_VER -#pragma warning(pop) -#endif -using convert_t = std::codecvt_utf8; -std::wstring_convert strconverter; + +#include "core/providers/shared_library/provider_api.h" +#include "./vai_assert.h" #include "vaip/node.h" #include "vaip/node_arg.h" @@ -38,23 +26,14 @@ struct NodeEdgeT { static void graph_remove_node(Graph& graph, const Node& node) { auto remove_edges = std::vector(); - auto begin = node.InputEdgesBegin(); - auto end = node.InputEdgesEnd(); - for (auto it = begin; it != end; ++it) { - remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), - it->GetSrcArgIndex(), - it->GetDstArgIndex()}); + for (auto it = node.InputEdgesBegin(); it != node.InputEdgesEnd(); ++it) { + remove_edges.push_back(NodeEdgeT{it->GetNode().Index(), node.Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); } - begin = node.OutputEdgesBegin(); - end = node.OutputEdgesEnd(); - for (auto it = begin; it != end; ++it) { - remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), - it->GetSrcArgIndex(), - it->GetDstArgIndex()}); + for (auto it = node.OutputEdgesBegin(); it != node.OutputEdgesEnd(); ++it) { + remove_edges.push_back(NodeEdgeT{node.Index(), it->GetNode().Index(), it->GetSrcArgIndex(), it->GetDstArgIndex()}); } for (auto it : remove_edges) { - graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, - it.dst_arg_index); + graph.RemoveEdge(it.src_node_index, it.dst_node_index, it.src_arg_index, it.dst_arg_index); } graph.RemoveNode(node.Index()); } @@ -68,13 +47,9 @@ static std::vector node_get_implicit_input_node_args(const Node& } return ret; } - -Node& graph_add_node(Graph& graph, const std::string& name, - const std::string& op_type, const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes& attributes, - const std::string& domain) { +Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, + const std::vector& input_args, const std::vector& output_args, + const NodeAttributes& attributes, const std::string& domain) { std::vector inputs; inputs.reserve(input_args.size()); for (auto i : input_args) { @@ -85,8 +60,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto i : output_args) { outputs.push_back(const_cast(i)); } - auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, - &attributes, domain); + auto& ret = graph.AddNode(name, op_type, description, inputs, outputs, &attributes, domain); auto src_arg_index = 0; for (auto& o : outputs) { auto consumers = graph.GetConsumerNodes(o->Name()); @@ -96,8 +70,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, for (auto ni : *tmp_inputs) { auto name1 = ni.node_arg->Name(); if (name1 == o->Name()) { - graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, - dst_arg_index); + graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index); } dst_arg_index = dst_arg_index + 1; } @@ -105,8 +78,7 @@ Node& graph_add_node(Graph& graph, const 
std::string& name,
     for (auto implicit_node_arg : node_get_implicit_input_node_args(*consumer)) {
       auto name1 = implicit_node_arg->Name();
       if (name1 == o->Name()) {
-        graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index,
-                      dst_arg_index);
+        graph.AddEdge(ret.Index(), consumer->Index(), src_arg_index, dst_arg_index);
       }
       dst_arg_index = dst_arg_index + 1;
     }
@@ -132,44 +104,39 @@ void graph_remove_node(Graph& graph, const NodeInput& node_input) {
 void graph_save(const Graph& graph, const std::string& filename, const std::string& filename_dat, size_t initializer_size_threshold) {
   auto& model = const_cast(graph.GetModel());
-  auto model_proto = ONNX_NAMESPACE::ModelProto();
+  std::unique_ptr model_proto;
   if (initializer_size_threshold == std::numeric_limits::max()) {
     model_proto = model.ToProto();
   } else {
-    model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat,
-                                                             ToPathString(filename),
-                                                             initializer_size_threshold);
+    model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, graph.ModelPath().ToPathString(), initializer_size_threshold);
   }
   auto& metadata = model.MetaData();
   if (!metadata.empty()) {
-    model_proto.mutable_metadata_props()->Clear();
+    auto metadata_props = model_proto->mutable_metadata_props();
+    metadata_props->Clear();
     for (auto& m : metadata) {
-      auto prop = model_proto.mutable_metadata_props()->Add();
+      auto prop = metadata_props->Add();
       *prop->mutable_key() = m.first;
       *prop->mutable_value() = m.second;
     }
   }
   // use relative path as data storage.
-  auto graph_proto = model_proto.mutable_graph();
-  *graph_proto = graph.ToGraphProto();
-  for (auto i = 0; i < graph_proto->initializer_size(); ++i) {
-    auto initializer = graph_proto->mutable_initializer(i);
-    for (auto j = 0; j < initializer->external_data_size(); ++j) {
-      auto external_data = initializer->mutable_external_data(j);
-      if (external_data->key() == "location") {
-        *external_data->mutable_value() = std::filesystem::path(external_data->value()).filename().u8string();
-      }
+  auto graph_proto = model_proto->mutable_graph();
+  *graph_proto = *graph.ToGraphProto();
+  for (int i = 0; i < graph_proto->mutable_initializer()->size(); i++) {
+    auto mutable_external_data = graph_proto->mutable_initializer()->at(i).mutable_external_data();
+    for (int j = 0; j < mutable_external_data->size(); j++) {
+      auto& external_data = mutable_external_data->at(j);
+      if (*external_data.mutable_key() == "location")
+        *external_data.mutable_value() = std::filesystem::path(*external_data.mutable_value()).filename().u8string();
     }
   }
-  int fd = -1;
-  Status status = Env::Default().FileOpenWr(filename, fd);
-  vai_assert(status.IsOK(), status.ErrorMessage());
-  google::protobuf::io::FileOutputStream output(fd);
-  const bool result = model_proto.SerializeToZeroCopyStream(&output) && output.Flush();
-  vai_assert(result, "model serialize to zero cipy stream error");
-  status = Env::Default().FileClose(fd);
-  vai_assert(status.IsOK(), status.ErrorMessage());
+
+  std::fstream output(filename, std::ios::out | std::ios::trunc | std::ios::binary);
+  bool result = model_proto->SerializeToOstream(output);
+  output << std::flush;
+  vai_assert(result, "model serialize to ostream error");
 }
 Node& graph_fuse(Graph& graph, const std::string& name,
@@ -178,25 +145,25 @@ Node& graph_fuse(Graph& graph, const std::string& name,
                  const std::vector& inputs,
                  const std::vector& outputs,
                  const std::vector& constant_initializers) {
-  auto meta_def = std::make_unique();
-  auto indexed_subgraph = std::make_unique();
-  indexed_subgraph->nodes = nodes;
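The rewritten graph_save above keeps only the final path component of each initializer's external-data "location" entry and then serializes the ModelProto straight to a std::fstream. As a standalone illustration only (standard library, hypothetical function name, not part of this patch), the filename-trimming step amounts to:

// Minimal sketch of the "location" rewrite performed in graph_save above.
// Uses only the standard library; to_relative_location is a made-up name.
#include <filesystem>
#include <iostream>
#include <string>

std::string to_relative_location(const std::string& location) {
  // filename() drops every directory component, mirroring the
  // std::filesystem::path(...).filename().u8string() call in the patch.
  return std::filesystem::path(location).filename().string();
}

int main() {
  std::cout << to_relative_location("/tmp/export/model.onnx.data") << "\n";  // prints "model.onnx.data"
  return 0;
}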
- meta_def->inputs = inputs; - meta_def->outputs = outputs; - meta_def->constant_initializers = constant_initializers; - meta_def->name = "super_layer"; - meta_def->domain = "com.xilinx"; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; + auto meta_def = IndexedSubGraph_MetaDef::Create(); + meta_def->inputs() = inputs; + meta_def->outputs() = outputs; + meta_def->constant_initializers() = constant_initializers; + meta_def->name() = "super_layer"; + meta_def->domain() = "com.xilinx"; + meta_def->since_version() = 1; + meta_def->status() = ONNX_NAMESPACE::EXPERIMENTAL; + + auto indexed_subgraph = IndexedSubGraph::Create(); + indexed_subgraph->Nodes() = nodes; indexed_subgraph->SetMetaDef(std::move(meta_def)); + auto& fused_node = graph.FuseSubGraph(*indexed_subgraph, name); auto function_body = fused_node.GetFunctionBody(); if (function_body) { - auto& mygraph = function_body->Body(); - // auto proto = graph.ToGraphProtoWithExternal("exteranl.dat", 128); - auto proto = mygraph.ToGraphProto(); - *proto.mutable_name() = name; - fused_node.AddAttribute("body", proto); + auto proto = function_body->Body().ToGraphProto(); + *proto->mutable_name() = name; + fused_node.AddAttribute("body", *proto); } for (auto&& o : fused_node.OutputDefs()) { graph.UpdateProducerNode(o->Name(), fused_node.Index()); diff --git a/onnxruntime/core/providers/vitisai/imp/node.cc b/onnxruntime/core/providers/vitisai/imp/node.cc index 6d65ad4e8c408..0565171fb7f40 100644 --- a/onnxruntime/core/providers/vitisai/imp/node.cc +++ b/onnxruntime/core/providers/vitisai/imp/node.cc @@ -4,9 +4,8 @@ #include "./vai_assert.h" #include "attr_proto.h" -#include "core/graph/graph_utils.h" -#include "core/graph/node_arg.h" #include "vaip/node_arg.h" +#include "core/providers/shared_library/provider_api.h" namespace vaip { @@ -29,7 +28,6 @@ vaip_core::DllSafe> node_get_inputs(const Node& node) { } return vaip_core::DllSafe(ret); } - vaip_core::DllSafe> node_get_output_node_args(const Node& node) { auto outputs = node.OutputDefs(); auto size = outputs.size(); @@ -42,11 +40,4 @@ vaip_core::DllSafe> node_get_output_node_args(const } return vaip_core::DllSafe(ret); } - -vaip_core::DllSafe> node_get_output_shape(const Node& node, int index) { - auto outputs = node.OutputDefs(); - assert((size_t)index < outputs.size()); - return node_arg_get_shape_i64(*outputs[index]); -} - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_arg.cc b/onnxruntime/core/providers/vitisai/imp/node_arg.cc index 3bdeb09698d49..a54cbef91c398 100644 --- a/onnxruntime/core/providers/vitisai/imp/node_arg.cc +++ b/onnxruntime/core/providers/vitisai/imp/node_arg.cc @@ -2,25 +2,16 @@ // Licensed under the MIT License. 
#include "vaip/node_arg.h" #include "./vai_assert.h" - -#include +#include "core/providers/shared_library/provider_api.h" #include "./tensor_proto.h" -#include "core/graph/node_arg.h" namespace vaip { - -bool node_arg_is_exists(const NodeArg& node_arg) { - return node_arg.Exists(); -} bool node_arg_is_constant(const Graph& graph, const NodeArg& node_arg) { assert(node_arg.Exists()); assert(!node_arg.Name().empty()); - auto constant_tensor_proto = - graph.GetConstantInitializer(node_arg.Name(), true); - return constant_tensor_proto != nullptr; + return graph.GetConstantInitializer(node_arg.Name(), true) != nullptr; } - vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& node_arg) { auto shape = node_arg.Shape(); if (nullptr == shape) return vaip_core::DllSafe>(); @@ -32,104 +23,42 @@ vaip_core::DllSafe> node_arg_get_shape_i64(const NodeArg& n } return vaip_core::DllSafe(shape_vector); } - -static void LayoutTransformRule_set_shape(onnx::TensorShapeProto& shape_proto, - const std::vector& shape) { - assert(shape.size() == static_cast(shape_proto.dim_size())); - auto rank = shape_proto.dim_size(); +void node_arg_set_shape_i64(const NodeArg& node_arg, const std::vector& shape) { + auto shape_proto = const_cast(node_arg.Shape()); + assert(shape_proto != nullptr); + assert(shape.size() == static_cast(shape_proto->dim_size())); + auto rank = shape_proto->dim_size(); for (auto i = 0; i < rank; ++i) { - shape_proto.mutable_dim(i)->set_dim_value(shape[i]); + shape_proto->mutable_dim(i)->set_dim_value(shape[i]); } } - -static void LayoutTransformRule_set_shape(onnx::TypeProto& type_proto, - const std::vector& shape) { - assert(type_proto.value_case() == onnx::TypeProto::kTensorType); - //<< type_proto.DebugString(); - auto& tensor_type = *type_proto.mutable_tensor_type(); - auto& shape_prot = *tensor_type.mutable_shape(); - return LayoutTransformRule_set_shape(shape_prot, shape); -} - -static void LayoutTransformRule_set_shape(NodeArg* node_arg, - const std::vector& shape) { - assert(node_arg != nullptr); - auto* type_proto = node_arg->TypeAsProto(); - assert(type_proto != nullptr); - return LayoutTransformRule_set_shape( - *const_cast(type_proto), shape); -} - -void node_arg_set_shape_i64(const NodeArg& node_arg, - const std::vector& shape) { - LayoutTransformRule_set_shape(const_cast(&node_arg), shape); -} - -static std::vector LayoutTransformRule_get_denotation( - const onnx::TensorShapeProto& shape) { +vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) { + auto shape = node_arg.Shape(); + if (shape == nullptr) { + return vaip_core::DllSafe>(); + } auto ret = std::vector(); - auto rank = shape.dim_size(); - ret.reserve(rank); + auto rank = shape->dim_size(); for (auto i = 0; i < rank; ++i) { - auto& d = shape.dim(i).denotation(); - ret.push_back(d); + ret.push_back(shape->dim(i).denotation()); } - return ret; + return vaip_core::DllSafe>(ret); } - -static vaip_core::DllSafe> LayoutTransformRule_get_denotation( - const onnx::TypeProto& type_proto) { - vai_assert(type_proto.value_case() == onnx::TypeProto::kTensorType, type_proto.DebugString()); - auto& tensor_type = type_proto.tensor_type(); - if (!tensor_type.has_shape()) { - return vaip_core::DllSafe>(); - } - auto& shape = tensor_type.shape(); - auto denotation = LayoutTransformRule_get_denotation(shape); - return vaip_core::DllSafe>(denotation); -} - -static vaip_core::DllSafe> LayoutTransformRule_get_denotation( - const NodeArg* node_arg) { - assert(node_arg != nullptr); - auto* type_proto = node_arg->TypeAsProto(); - 
assert(type_proto != nullptr);
-  return LayoutTransformRule_get_denotation(*type_proto);
-}
-
-vaip_core::DllSafe> node_arg_get_denotation(const NodeArg& node_arg) {
-  return LayoutTransformRule_get_denotation(&node_arg);
-}
-
-static onnx::TensorShapeProto* node_arg_get_tensor_mutable_shape(
-    NodeArg* node_arg) {
-  assert(node_arg != nullptr);
-  auto type_proto = const_cast(node_arg->TypeAsProto());
-  assert(type_proto != nullptr);
-  vai_assert(type_proto->value_case() == onnx::TypeProto::kTensorType,
-             type_proto->DebugString());
-  return type_proto->mutable_tensor_type()->mutable_shape();
-}
-
-static void LayoutTransformRule_set_denotation(
-    onnx::TensorShapeProto& shape, const std::vector& denotation) {
-  assert(denotation.size() == static_cast(shape.dim_size()));
-  auto rank = shape.dim_size();
+void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation) {
+  auto shape_proto = const_cast(node_arg.Shape());
+  assert(shape_proto != nullptr);
+  assert(denotation.size() == static_cast(shape_proto->dim_size()));
+  auto rank = shape_proto->dim_size();
   for (auto i = 0; i < rank; ++i) {
-    shape.mutable_dim(i)->set_denotation(denotation[i]);
+    shape_proto->mutable_dim(i)->set_denotation(denotation[i]);
   }
 }
-void node_arg_set_denotation(const NodeArg& node_arg,
-                             const std::vector& denotation) {
-  auto mutable_shape =
-      node_arg_get_tensor_mutable_shape(const_cast(&node_arg));
-
-  return LayoutTransformRule_set_denotation(*mutable_shape, denotation);
-}
-
-void node_arg_set_element_type(NodeArg& node_arg,
-                               onnx::TensorProto::DataType data_type) {
-  auto type_proto = const_cast(node_arg.TypeAsProto());
+void node_arg_set_element_type(NodeArg& node_arg, int type) {
+  if (type < 0 || type > 16) {
+    vai_assert(false, "TensorProto::DataType not supported");
+  }
+  auto data_type = static_cast(type);
+  auto type_proto = const_cast(node_arg.TypeAsProto());
   assert(type_proto != nullptr);
   auto current_elem_type = type_proto->mutable_tensor_type()->elem_type();
   auto input_elem_type = data_type;
@@ -138,24 +67,12 @@ void node_arg_set_element_type(NodeArg& node_arg,
   current_elem_type, true);
   vai_assert(status.IsOK(), status.ErrorMessage());
 }
-void node_arg_set_shape(NodeArg& node_arg, std::vector shape) {
-  auto type_proto = const_cast(node_arg.TypeAsProto());
-  assert(type_proto != nullptr);
-  for (auto i = 0u; i < shape.size(); i++) {
-    type_proto->mutable_tensor_type()
-        ->mutable_shape()
-        ->mutable_dim(i)
-        ->set_dim_value(shape[i]);
-  }
-}
-
 const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(
     const Graph& graph, const NodeArg& node_arg) {
   auto tensor_proto = graph.GetConstantInitializer(node_arg.Name(), true);
   assert(tensor_proto != nullptr);
   return *tensor_proto;
 }
-
 int node_arg_get_element_type(const NodeArg& node_arg) {
   auto type_proto = node_arg.TypeAsProto();
   assert(type_proto != nullptr);
@@ -164,9 +81,7 @@ int node_arg_get_element_type(const NodeArg& node_arg) {
   }
   return type_proto->tensor_type().elem_type();
 }
-
-NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg,
-                        const std::string& name) {
+NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, const std::string& name) {
   vai_assert(name != node_arg.Name(), "node arg must have a new unique name");
   vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists.
")); auto type_proto = node_arg.TypeAsProto(); @@ -174,12 +89,10 @@ NodeArg& node_arg_clone(Graph& graph, const NodeArg& node_arg, auto& ret = graph.GetOrCreateNodeArg(name, type_proto); return ret; } - -NodeArg& node_arg_new(Graph& graph, - const std::string& name, const std::vector* shape, int element_type) { +NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vector* shape, int element_type) { vai_assert(graph.GetNodeArg(name) == nullptr, std::string("node arg " + name + " already exists. ")); - auto type_proto = onnx::TypeProto(); - auto tensor_type = type_proto.mutable_tensor_type(); + auto type_proto = ONNX_NAMESPACE::TypeProto::Create(); + auto tensor_type = type_proto->mutable_tensor_type(); tensor_type->set_elem_type(element_type); if (shape != nullptr) { auto shape_proto = tensor_type->mutable_shape(); @@ -189,8 +102,6 @@ NodeArg& node_arg_new(Graph& graph, } else { assert(tensor_type->has_shape() == false); } - auto& ret = graph.GetOrCreateNodeArg(name, &type_proto); - return ret; + return graph.GetOrCreateNodeArg(name, type_proto.release()); } - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc b/onnxruntime/core/providers/vitisai/imp/node_attrs.cc deleted file mode 100644 index e438266e2a4c0..0000000000000 --- a/onnxruntime/core/providers/vitisai/imp/node_attrs.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#include "vaip/node_attrs.h" -#include "./vai_assert.h" - -namespace vaip { -static onnx::AttributeProto make_attribute(const std::string& name, - int64_t value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::INT); - ret.set_i(value); - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::vector value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::INTS); - for (auto v : value) { - ret.add_ints(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::string& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::STRING); - ret.set_s(value); - return ret; -} -static onnx::AttributeProto make_attribute( - const std::string& name, const std::vector& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::STRINGS); - for (auto v : value) { - ret.add_strings(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const std::vector& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::FLOATS); - for (auto v : value) { - ret.add_floats(v); - } - return ret; -} - -static onnx::AttributeProto make_attribute(const std::string& name, - const onnx::TensorProto& value) { - auto ret = onnx::AttributeProto(); - ret.set_name(name); - ret.set_type(onnx::AttributeProto::TENSOR); - *(ret.mutable_t()) = std::move(value); - return ret; -} // namespace vaip - -NodeAttr::NodeAttr(const std::string& name, int64_t value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::string& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& 
name, - const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const std::vector& value) - : attribute_proto_{make_attribute(name, value)} {} - -NodeAttr::NodeAttr(const std::string& name, const onnx::TensorProto& value) - : attribute_proto_{make_attribute(name, value)} {} - -onnx::AttributeProto& NodeAttr::get() { return attribute_proto_; } - -NodeAttributesBuiler::NodeAttributesBuiler(size_t capacity) : attrs_{} { - attrs_.reserve(capacity); -} - -NodeAttributes NodeAttributesBuiler::build() { - auto ret = NodeAttributes(); - ret.reserve(attrs_.size()); - for (auto& node_attr : attrs_) { - onnx::AttributeProto& attr_proto = node_attr.get(); - auto name = attr_proto.name(); - ret.insert(std::make_pair(name, std::move(attr_proto))); - } - attrs_.clear(); - return ret; -} - -void NodeAttributesBuiler::merge_into(Node& node) { - merge_into(node.GetMutableAttributes()); -} - -void NodeAttributesBuiler::merge_into(NodeAttributes& attrs) { - for (auto& attr : attrs_) { - vai_assert(attr.get().has_name(), std::string("attr must has name " + attr.get().DebugString())); - auto name = attr.get().name(); - attrs.insert_or_assign(std::move(name), std::move(attr.get())); - } -} -} // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index ee8dfc6d03d12..97ed2d3b4b8a1 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -1,130 +1,25 @@ - - // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. + #include "./register_xir_ops.h" #include "./vai_assert.h" - -#include "core/common/logging/logging.h" -#include "core/common/status.h" - -#include "core/framework/customregistry.h" - +#include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" -#include "core/session/custom_ops.h" -#include "core/session/inference_session.h" -#include "onnx/defs/schema.h" -#include "onnx/defs/shape_inference.h" using namespace onnxruntime; -namespace vaip { - -static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { - auto* shape = ctx.getAttribute("shape"); - auto* data_type = ctx.getAttribute("data_type"); - if (data_type->s() == "float32") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT); - } else if (data_type->s() == "int8") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT8); - } else if (data_type->s() == "uint8") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::UINT8); - } else if (data_type->s() == "int32") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT32); - } else if (data_type->s() == "int64") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); - } else if (data_type->s() == "int1") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); - } else if (data_type->s() == "bfloat16") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); - } else if (data_type->s() == "float16") { - updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); - } else { - vai_assert(false, ", not supported data_type: " + data_type->s()); - } - if (shape != nullptr) { - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::appendDim(ONNX_NAMESPACE::getOutputShape(ctx, 0), shape->ints(i)); - } - } else { - // set scalar type. 
- auto* output_shape = ONNX_NAMESPACE::getOutputShape(ctx, 0); - output_shape->clear_dim(); - } - return; -} - -static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); -} - -static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { - auto num_inputs = ctx.getNumInputs(); - - // Run inferencing on the subgraph - ONNX_NAMESPACE::GraphInferencer* graphInferencer = ctx.getGraphAttributeInferencer("body"); - if (!graphInferencer) { - fail_type_inference("body is missing."); - } - - std::vector input_data; - std::vector subgraph_input_types; - for (size_t i = 0; i < num_inputs; ++i) { - input_data.push_back(ctx.getInputData(i)); - subgraph_input_types.push_back(ctx.getInputType(i)); - } - std::vector output_types; - output_types = - graphInferencer->doInferencing(subgraph_input_types, input_data); - - auto num_outputs = ctx.getNumOutputs(); - auto num_of_the_subgraph_outputs = output_types.size(); - if (num_outputs != num_of_the_subgraph_outputs) { - fail_type_inference("super layer has ", num_outputs, - " but subgraphs produce ", num_of_the_subgraph_outputs); - } - for (size_t i = 0, end = output_types.size(); i < end; ++i) { - auto subgraph_output = output_types[i]; - auto* super_layer_output = ctx.getOutputType(i); - *super_layer_output = *subgraph_output; - } -} +namespace vaip { void register_xir_ops(const std::vector& domains) { - std::shared_ptr custom_registry; - auto status = CreateCustomRegistry(gsl::span(domains), custom_registry); - vai_assert(status.IsOK(), status.ErrorMessage()); for (auto domain : domains) { for (auto op : domain->custom_ops_) { auto name = op->GetName(op); - auto schema1 = custom_registry->GetOpschemaRegistry()->GetSchema(name, ORT_API_VERSION, domain->domain_); - auto schema2 = ::ONNX_NAMESPACE::OpSchema(); - schema2.SetName(schema1->Name()); - schema2.SetDomain(schema1->domain()); - auto n = 0; - for (auto input : schema1->inputs()) { - schema2.Input(n, input.GetName(), input.GetDescription(), std::string("T") + std::to_string(n), input.GetOption(), false, input.GetMinArity(), input.GetDifferentiationCategory()); - schema2.TypeConstraint(std::string("T") + std::to_string(n), DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); - n = n + 1; - } - auto m = n; - n = 0; - for (auto output : schema1->outputs()) { - auto type_str = std::string("T") + std::to_string(n + m); - schema2.Output(n, output.GetName(), output.GetDescription(), type_str, output.GetOption(), false, output.GetMinArity(), output.GetDifferentiationCategory()); - schema2.TypeConstraint(type_str, DataTypeImpl::ToString(DataTypeImpl::AllTensorTypes()), "all types"); - n = n + 1; - } - schema2.SinceVersion(1); - schema2.AllowUncheckedAttributes(); if ((std::string)name == "super_layer") { - schema2.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 1); } else if ((std::string)name == "FixNeuron") { - schema2.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 2); } else { - schema2.TypeAndShapeInferenceFunction(xir_shape_infer); + Provider_GetHost()->RegisterSchema(domain->domain_, op, 3); } - ONNX_NAMESPACE::RegisterSchema(schema2, ORT_API_VERSION); } } } diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc 
b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index db03354bf4c44..48dcd220a150c 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -1,20 +1,19 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #include "./tensor_proto.h" -#include "./vai_assert.h" -#include "core/framework/tensorprotoutils.h" #include #include +#include "./vai_assert.h" +#include "core/providers/shared_library/provider_api.h" namespace vaip { - -gsl::span tensor_proto_as_raw( - const ONNX_NAMESPACE::TensorProto& tensor) { +gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor) { auto& mut_tensor = const_cast(tensor); if (!tensor.has_raw_data()) { std::vector unpacked_tensor; - auto s = onnxruntime::utils::UnpackInitializerData(tensor, onnxruntime::Path(), unpacked_tensor); + auto path = onnxruntime::Path::Create(); + auto s = onnxruntime::utils::UnpackInitializerData(tensor, *path, unpacked_tensor); mut_tensor.mutable_raw_data()->resize(unpacked_tensor.size()); mut_tensor.clear_float_data(); mut_tensor.clear_int32_data(); @@ -27,78 +26,51 @@ gsl::span tensor_proto_as_raw( return gsl::span(tensor.raw_data().data(), tensor.raw_data().size()); } -size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor) { - return tensor.raw_data().size(); -} - -std::vector tensor_proto_get_shape( - const onnx::TensorProto& tensor_proto) { +vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor_proto) { auto ret = std::vector(); int rank = tensor_proto.dims_size(); if (rank > 0) { - ret.reserve((size_t)rank); - for (auto i = 0; i < rank; ++i) { - ret.push_back(tensor_proto.dims(i)); + auto& dims = tensor_proto.dims(); + for (auto i = 0; i < dims.size(); ++i) { + ret.push_back(dims[i]); } } - return ret; + return vaip_core::DllSafe(ret); } - -const std::string& tensor_proto_get_name( - const ONNX_NAMESPACE::TensorProto& tensor) { - return tensor.name(); +static ONNX_NAMESPACE::TensorProto* tensor_proto_new(const std::string& name, const std::vector& shape, + int data_type, const char* data, size_t data_size) { + auto tensor_proto = ONNX_NAMESPACE::TensorProto::Create(); + tensor_proto->set_name(name); + for (auto s : shape) { + tensor_proto->add_dims(s); + } + tensor_proto->set_data_type(data_type); + tensor_proto->mutable_raw_data()->assign(data, data_size); + return tensor_proto.release(); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT32); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT32, + reinterpret_cast(&data[0]), data.size() * sizeof(int32_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - 
tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT64); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT64, + reinterpret_cast(&data[0]), data.size() * sizeof(int64_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::INT8); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_INT8, + reinterpret_cast(&data[0]), data.size() * sizeof(int8_t)); } -ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( - const std::string& name, const std::vector& shape, - const std::vector& data) { - auto tensor_proto = ONNX_NAMESPACE::TensorProto(); - tensor_proto.set_name(name); - tensor_proto.mutable_dims()->Clear(); - tensor_proto.mutable_dims()->Add(shape.begin(), shape.end()); - tensor_proto.set_data_type(ONNX_NAMESPACE::TensorProto::FLOAT); - tensor_proto.mutable_raw_data()->assign( - reinterpret_cast(&data[0]), data.size() * sizeof(float)); - return tensor_proto; +ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, + const std::vector& data) { + return tensor_proto_new(name, shape, ONNX_NAMESPACE::TensorProto_DataType_FLOAT, + reinterpret_cast(&data[0]), data.size() * sizeof(float)); } } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 00aa388c809c1..292905ca734f1 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -1,31 +1,20 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#pragma once -// -#include "core/common/gsl.h" -#include "onnx/onnx_pb.h" -namespace vaip { - -gsl::span tensor_proto_as_raw( - const ONNX_NAMESPACE::TensorProto& tensor); -size_t tensor_proto_raw_data_size(const ONNX_NAMESPACE::TensorProto& tensor); - -std::vector tensor_proto_get_shape( - const ONNX_NAMESPACE::TensorProto& tensor); -const std::string& tensor_proto_get_name( - const ONNX_NAMESPACE::TensorProto& tensor); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i8( - const std::string& name, const std::vector& shape, - const std::vector& data); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i32( - const std::string& name, const std::vector& shape, - const std::vector& data); -ONNX_NAMESPACE::TensorProto tensor_proto_new_i64( - const std::string& name, const std::vector& shape, - const std::vector& data); - -ONNX_NAMESPACE::TensorProto tensor_proto_new_floats( - const std::string& name, const std::vector& shape, - const std::vector& data); +#include "vaip/my_ort.h" +#include "vaip/vaip_gsl.h" +#include "vaip/dll_safe.h" +namespace vaip { +gsl::span tensor_proto_as_raw(const ONNX_NAMESPACE::TensorProto& tensor); +vaip_core::DllSafe> tensor_proto_get_shape(const ONNX_NAMESPACE::TensorProto& tensor); +const std::string& tensor_proto_get_name(const ONNX_NAMESPACE::TensorProto& tensor); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i8(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i32(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_i64(const std::string& name, const std::vector& shape, + const std::vector& data); +ONNX_NAMESPACE::TensorProto* tensor_proto_new_floats(const std::string& name, const std::vector& shape, + const std::vector& data); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/capability.h b/onnxruntime/core/providers/vitisai/include/vaip/capability.h index d6b5ae34decc2..e7644dbe86354 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/capability.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/capability.h @@ -2,8 +2,7 @@ // Licensed under the MIT License. #pragma once -#include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" +#include "core/providers/shared_library/provider_api.h" #include "vaip/custom_op.h" namespace vaip { using namespace ::onnxruntime; diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index c446ab3aefcc5..1f8b8802e86b4 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -2,16 +2,15 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. 
#pragma once -#include -#include -#include - +#include "core/providers/shared_library/provider_api.h" +#define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/provider_options.h" #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" -std::vector initialize_vitisai_ep(); -vaip_core::DllSafe>> compile_onnx_model_with_options( - const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); +void initialize_vitisai_ep(); +vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); +std::shared_ptr get_kernel_registry_vitisaiep(); +const std::vector& get_domains_vitisaiep(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h index 9def8645709fb..292fb2bb38b2b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h @@ -1,25 +1,19 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once -#include #include "./node.h" +#include "vaip/my_ort.h" namespace vaip { using namespace onnxruntime; void graph_remove_node(Graph& graph, const NodeInput& node_input); -Node& graph_add_node(Graph& graph, const std::string& name, - const std::string& op_type, const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - const NodeAttributes& attributes, - const std::string& domain); - -void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, size_t initializer_size_threshold); -Node& graph_fuse(Graph& graph, const std::string& name, - const std::string& op_type, - const std::vector& nodes, - const std::vector& inputs, - const std::vector& outputs, +Node& graph_add_node(Graph& graph, const std::string& name, const std::string& op_type, const std::string& description, + const std::vector& input_args, const std::vector& output_args, + const NodeAttributes& attributes, const std::string& domain); +void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, + size_t initializer_size_threshold); +Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, + const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h index d43ef1253715c..46fc4ac9b2a5d 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/my_ort.h @@ -9,15 +9,17 @@ #include namespace onnxruntime { -class Model; -class Graph; -class GraphViewer; -class Node; -class NodeArg; +struct Model; +struct Graph; +struct GraphViewer; +struct Node; +struct NodeArg; +struct ProviderHost; +struct NodeAttributes; } // namespace onnxruntime namespace ONNX_NAMESPACE { -class AttributeProto; -class TensorProto; +struct AttributeProto; +struct TensorProto; #ifndef USE_VITISAI enum TensorProto_DataType : int { TensorProto_DataType_UNDEFINED = 0, @@ -68,6 +70,7 @@ using onnxruntime::GraphViewer; using onnxruntime::Model; using onnxruntime::Node; using onnxruntime::NodeArg; +using onnxruntime::NodeAttributes; struct 
ModelDeleter { VAIP_DLL_SPEC void operator()(Model* tp) const; }; @@ -75,22 +78,17 @@ using ModelPtr = std::unique_ptr; struct AttributeProtoDeleter { VAIP_DLL_SPEC void operator()(AttributeProto* p) const; }; -using AttributeProtoPtr = - std::unique_ptr; +using AttributeProtoPtr = std::unique_ptr; struct TensorProtoDeleter { VAIP_DLL_SPEC void operator()(TensorProto* tp) const; }; using TensorProtoPtr = std::unique_ptr; -/// I cannot forward declare a using directive, because -/// std::unorderd_map required AttributeProto must be defiend. -class NodeAttributes; struct NodeAttributesDeleter { VAIP_DLL_SPEC void operator()(NodeAttributes* p) const; }; -using NodeAttributesPtr = - std::unique_ptr; +using NodeAttributesPtr = std::unique_ptr; /// get node's input /// when Node* is nullptr, it is a tensor in the initializer. /// node_arg is always non-null. diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node.h b/onnxruntime/core/providers/vitisai/include/vaip/node.h index bad7660f66744..31d9d4bd73b8b 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node.h @@ -2,10 +2,6 @@ // Licensed under the MIT License. #pragma once - -#include - -#include "core/graph/node_arg.h" #include "vaip/dll_safe.h" #include "vaip/my_ort.h" namespace vaip { @@ -17,8 +13,4 @@ vaip_core::DllSafe> node_get_inputs(const Node& node); /// to support multiple outputs vaip_core::DllSafe> node_get_output_node_args(const Node& node); -/// get output shape -/// index is usually zero, because most operators only have a single output. -vaip_core::DllSafe> node_get_output_shape(const Node& node, int index = 0); - } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h index 76432fc5b3a68..fca641c5e11c8 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h @@ -2,9 +2,8 @@ // Licensed under the MIT License. #pragma once -#include #include "vaip/dll_safe.h" -#include +#include "vaip/my_ort.h" namespace vaip { using namespace onnxruntime; @@ -26,9 +25,7 @@ void node_arg_set_shape_i64(const NodeArg& node_arg, void node_arg_set_denotation(const NodeArg& node_arg, const std::vector& denotation); void node_arg_set_element_type(NodeArg& node_arg, - ONNX_NAMESPACE::TensorProto::DataType data_type); -void node_arg_set_shape(NodeArg& node_arg, std::vector shape); - + int data_type); const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(const Graph& graph, const NodeArg& node_arg); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h b/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h deleted file mode 100644 index 49cd1aad89f4f..0000000000000 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_attrs.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once -#include - -#include - -#include "core/graph/basic_types.h" -namespace vaip { -using namespace onnxruntime; -class NodeAttr { - public: - NodeAttr(const std::string& name, int64_t value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const std::string& value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const std::vector& value); - NodeAttr(const std::string& name, const onnx::TensorProto& value); - - onnx::AttributeProto& get(); - - private: - onnx::AttributeProto attribute_proto_; -}; - -class NodeAttributesBuiler { - public: - explicit NodeAttributesBuiler(size_t capacity = 10); - NodeAttributesBuiler(const NodeAttributesBuiler&) = delete; - NodeAttributesBuiler(NodeAttributesBuiler&&) = default; - /// after build, all attrs_ are cleared. - NodeAttributes build(); - /// for efficiency reason, after merge_into, all attrs_ are moved. - void merge_into(Node& node); - void merge_into(NodeAttributes& attrs); - template - NodeAttributesBuiler& add(const std::string& name, T&& value) { - attrs_.emplace_back(name, std::forward(value)); - return *this; - } - - private: - std::vector attrs_; -}; -} // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 0d7d5f6220d06..ae5f71d66269c 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,6 +13,7 @@ struct OrtApi; namespace vaip_core { struct OrtApiForVaip { + onnxruntime::ProviderHost* host_; const OrtApi* ort_api_; // model Model* (*model_load)(const std::string& file); // [0] @@ -49,7 +50,7 @@ struct OrtApiForVaip { const std::string& description, const std::vector& input_args, const std::vector& output_args, - NodeAttributes& attributes, + const NodeAttributes& attributes, const std::string& domain); // [18] void (*graph_save)(const Graph& graph, const std::string& filename, const std::string& dat_filename, @@ -119,8 +120,8 @@ struct OrtApiForVaip { NodeAttributes* (*node_attributes_new)(); // [46] void (*node_attributes_delete)(NodeAttributes* p); // [47] void (*node_attributes_add)(NodeAttributes& p, AttributeProto&& attr); // [48] - AttributeProto* (*node_attributes_get)(NodeAttributes& p, - const std::string& name); // [49] + const AttributeProto* (*node_attributes_get)(const NodeAttributes& p, + const std::string& name); // [49] DllSafe> (*node_attributes_get_keys)( NodeAttributes& p); // [50] /// attr proto @@ -194,5 +195,4 @@ VAIP_DLL_SPEC const OrtApiForVaip* api(); ? ::vaip_core::api()->name \ : (assert(false && #name " is not set"), nullptr)) #endif -VAIP_DLL_SPEC void initialize_ort(); } // namespace vaip_core diff --git a/onnxruntime/core/providers/vitisai/symbols.def b/onnxruntime/core/providers/vitisai/symbols.def new file mode 100644 index 0000000000000..4ec2f7914c208 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/symbols.def @@ -0,0 +1,2 @@ +EXPORTS + GetProvider diff --git a/onnxruntime/core/providers/vitisai/version_script.lds b/onnxruntime/core/providers/vitisai/version_script.lds new file mode 100644 index 0000000000000..2c8e9c4b3ed64 --- /dev/null +++ b/onnxruntime/core/providers/vitisai/version_script.lds @@ -0,0 +1,9 @@ +#_init and _fini should be local +VERS_1.0 { + global: + GetProvider; + + # Hide everything else. 
+ local: + *; +}; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 5f20b32cd6dc4..6fc09f3495aa1 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -1,91 +1,34 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. -#include "core/graph/graph_utils.h" #include "vitisai_execution_provider.h" #include -#include #include #include -#include "core/common/common.h" - #include "vaip/capability.h" #include "vaip/global_api.h" -#include "core/session/custom_ops.h" -#include "core/session/inference_session.h" using namespace ONNX_NAMESPACE; namespace onnxruntime { - constexpr const char* VITISAI = "VITISAI"; -static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { -#ifndef _WIN32 - auto model_path = graph_viewer.ModelPath().ToPathString(); -#else - using convert_t = std::codecvt_utf8; - std::wstring_convert strconverter; - auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); -#endif - return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); -} - -struct MyCustomOpKernel : OpKernel { - MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - op_kernel_ = - op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); - } - - ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } - - Status Compute(OpKernelContext* ctx) const override { - op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); - return Status::OK(); - } - - private: - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MyCustomOpKernel); - - const OrtCustomOp& op_; - void* op_kernel_; -}; - -VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) +VitisAIExecutionProvider::VitisAIExecutionProvider( + const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { - custom_op_domains_ = initialize_vitisai_ep(); - registry_ = std::make_shared(); CreateKernelRegistry(); } void VitisAIExecutionProvider::CreateKernelRegistry() { - for (const auto& domain : custom_op_domains_) { + for (const auto& domain : get_domains_vitisaiep()) { for (const auto* op : domain->custom_ops_) { - KernelDefBuilder def_builder; - def_builder.SetName(op->GetName(op)); - def_builder.SetDomain(domain->domain_); - def_builder.SinceVersion(1); - if (op->version > 12) { - auto input_count = op->GetInputTypeCount(op); - for (auto i = 0u; i < input_count; i++) { - def_builder.InputMemoryType(op->GetInputMemoryType(op, i), i); - } - } - def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, - std::unique_ptr& out) -> Status { - out = std::make_unique(info, *op); - return Status::OK(); - }; - std::ignore = registry_->Register(def_builder, kernel_create_fn); vitisai_optypes_.insert(op->GetName(op)); } } } -std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } +std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return get_kernel_registry_vitisaiep(); } std::vector> VitisAIExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { @@ -111,9 
+54,9 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; - const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); - assert(attr != nullptr); - size_t index = (size_t)attr->i(); + auto& attrs = fused_node_graph.fused_node.get().GetAttributes(); + assert(attrs.count("index")); + size_t index = attrs.at("index").i(); compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index e86b53339d4d2..186427be4fab2 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -9,8 +9,7 @@ #include #include -#include "core/framework/execution_provider.h" -#include "core/framework/customregistry.h" +#include "core/providers/shared_library/provider_api.h" #include "core/session/onnxruntime_c_api.h" // we cannot include vaip/vaip.hpp here because header file referred by @@ -21,7+20,6 @@ class DllSafe; class ExecutionProvider; } // namespace vaip_core namespace onnxruntime { - // Logical device representation. class VitisAIExecutionProvider : public IExecutionProvider { public: diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 4c416124ca8f2..5895e1973f231 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -11,7 +11,6 @@ #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" -#include "core/providers/shared_library/provider_host_api.h" using namespace onnxruntime; namespace onnxruntime { @@ -30,10 +29,37 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr VitisAIProviderFactoryCreator::Create( - const ProviderOptions& provider_options) { - initialize_vitisai_ep(); - return std::make_shared(provider_options); -} +struct VitisAI_Provider : Provider { + // Takes a pointer to a provider specific structure to create the factory. For example, with OpenVINO it is a pointer to an OrtOpenVINOProviderOptions structure + std::shared_ptr + CreateExecutionProviderFactory(const void* options) override { + return std::make_shared(GetProviderOptions(options)); + } + // Convert provider options struct to ProviderOptions which is a map + ProviderOptions GetProviderOptions(const void* options) override { + auto vitisai_options = reinterpret_cast(options); + return *vitisai_options; + } + // Update provider options from key-value string configuration + void UpdateProviderOptions(void* options, const ProviderOptions& provider_options) override { + auto vitisai_options = reinterpret_cast(options); + for (const auto& entry : provider_options) { + vitisai_options->insert_or_assign(entry.first, entry.second); + } + }; + // Get provider specific custom op domain list. Provider has the responsibility to release OrtCustomOpDomain instances it creates. 
+ void GetCustomOpDomainList(IExecutionProviderFactory*, std::vector&) override{}; + // Called right after loading the shared library, if this throws any errors Shutdown() will be called and the library unloaded + void Initialize() override { initialize_vitisai_ep(); } + // Called right before unloading the shared library + void Shutdown() override {} +} g_provider; } // namespace onnxruntime + +extern "C" { + +ORT_API(onnxruntime::Provider*, GetProvider) { + return &onnxruntime::g_provider; +} +} diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 91a7f0d930b51..dec8754ea244f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2724,6 +2724,7 @@ static constexpr OrtApi ort_api_1_to_18 = { &OrtApis::SetDeterministicCompute, &OrtApis::KernelContext_ParallelFor, &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, + &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index c1caafa4dcad3..9ce94ba89a942 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -509,4 +509,8 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_OpenVINO_V2, _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); } // namespace OrtApis diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2e445e4982d24..32ae15e71acc6 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -57,6 +57,8 @@ namespace ONNX_NAMESPACE { // We use these names in the provider API because we don't have the protobuf definitions of the RepeatedField* types using int64s = google::protobuf::RepeatedField; +using float32s = google::protobuf::RepeatedField; +using StringStringEntryProtos = google::protobuf::RepeatedPtrField; using TensorProtos = google::protobuf::RepeatedPtrField; using TensorShapeProto_Dimensions = google::protobuf::RepeatedPtrField; using ValueInfoProtos = google::protobuf::RepeatedPtrField; @@ -77,6 +79,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; #include "core/providers/migraphx/migraphx_provider_factory_creator.h" #include "core/providers/openvino/openvino_provider_factory_creator.h" #include "core/providers/tensorrt/tensorrt_provider_factory_creator.h" +#include "core/providers/vitisai/vitisai_provider_factory_creator.h" #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/cann/cann_provider_factory.h" @@ -123,6 +126,7 @@ ProviderInfo_Dnnl& GetProviderInfo_Dnnl(); ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); +ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector& ops); struct TensorShapeProto_Dimension_Iterator_Impl : TensorShapeProto_Dimension_Iterator { 
TensorShapeProto_Dimension_Iterator_Impl(google::protobuf::internal::RepeatedPtrIterator&& v) : v_{std::move(v)} {} @@ -274,7 +278,10 @@ struct ProviderHostImpl : ProviderHost { Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint32_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint64_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } - + Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& tensor, const Path& model_path, + /*out*/ std::vector& unpacked_tensor) override { + return utils::UnpackInitializerData(tensor, model_path, unpacked_tensor); + } uint16_t math__floatToHalf(float f) override { return math::floatToHalf(f); } float math__halfToFloat(uint16_t h) override { return math::halfToFloat(h); } @@ -352,12 +359,32 @@ struct ProviderHostImpl : ProviderHost { void logging__Capture__operator_delete(logging::Capture* p) noexcept override { delete p; } std::ostream& logging__Capture__Stream(logging::Capture* p) noexcept override { return p->Stream(); } + // Env + Env& Env__Default() override { return Env::Default(); } + // Utils::DataTypeUtils (wrapped) const std::string* Utils__DataTypeUtils__ToType(const ONNX_NAMESPACE::TypeProto& type_proto) override { return ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(type_proto); } // int64s (wrapped) int int64s__size(const ONNX_NAMESPACE::int64s* p) override { return p->size(); } const int64_t& int64s__Get(const ONNX_NAMESPACE::int64s* p, int index) override { return p->Get(index); } + void int64s__Reserve(ONNX_NAMESPACE::int64s* p, int size) override { p->Reserve(size); }; + const int64_t* int64s__data(const ONNX_NAMESPACE::int64s* p) override { return p->data(); } + + // float32s + void float32s__Reserve(ONNX_NAMESPACE::float32s* p, int size) override { p->Reserve(size); }; + const float* float32s__data(const ONNX_NAMESPACE::float32s* p) override { return p->data(); } + int float32s__size(const ONNX_NAMESPACE::float32s* p) override { return p->size(); } + + // StringStringEntryProto + std::string* StringStringEntryProto__mutable_key(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_key(); } + std::string* StringStringEntryProto__mutable_value(ONNX_NAMESPACE::StringStringEntryProto* p) override { return p->mutable_value(); } + + // StringStringEntryProtos + void StringStringEntryProtos__Clear(ONNX_NAMESPACE::StringStringEntryProtos* p) override { p->Clear(); }; + ONNX_NAMESPACE::StringStringEntryProto* StringStringEntryProtos__Add(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->Add(); } + int StringStringEntryProtos__size(ONNX_NAMESPACE::StringStringEntryProtos* p) override { return p->size(); } + ONNX_NAMESPACE::StringStringEntryProto& StringStringEntryProtos__at(ONNX_NAMESPACE::StringStringEntryProtos* p, int index) override { return p->at(index); }; #if !defined(DISABLE_OPTIONAL_TYPE) // TypeProto_Optional (wrapped) @@ -374,6 +401,7 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::TensorShapeProto& 
TypeProto_Tensor__shape(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->shape(); } ONNX_NAMESPACE::TensorShapeProto* TypeProto_Tensor__mutable_shape(ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->mutable_shape(); } int32_t TypeProto_Tensor__elem_type(const ONNX_NAMESPACE::TypeProto_Tensor* p) override { return p->elem_type(); } + void TypeProto_Tensor__set_elem_type(ONNX_NAMESPACE::TypeProto_Tensor* p, int32_t value) override { p->set_elem_type(value); }; // TypeProto_SparseTensor (wrapped) #if !defined(DISABLE_SPARSE_TENSORS) @@ -426,9 +454,18 @@ struct ProviderHostImpl : ProviderHost { float AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->floats(i); } const std::string& AttributeProto__strings(const ONNX_NAMESPACE::AttributeProto* p, int i) override { return p->strings(i); } const ONNX_NAMESPACE::int64s& AttributeProto__ints(const ONNX_NAMESPACE::AttributeProto* p) override { return p->ints(); } + const ONNX_NAMESPACE::float32s& AttributeProto__floats(const ONNX_NAMESPACE::AttributeProto* p) override { return p->floats(); } + ONNX_NAMESPACE::int64s* AttributeProto__mutable_ints(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_ints(); } + ONNX_NAMESPACE::float32s* AttributeProto__mutable_floats(ONNX_NAMESPACE::AttributeProto* p) override { return p->mutable_floats(); } + void AttributeProto__add_ints(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { p->add_ints(value); }; + void AttributeProto__add_floats(ONNX_NAMESPACE::AttributeProto* p, float value) override { p->add_floats(value); }; + void AttributeProto__add_strings(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { p->add_strings(value); }; + int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) override { return p->i(); } float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) override { return p->f(); } + const ONNX_NAMESPACE::TensorProto& AttributeProto__t(const ONNX_NAMESPACE::AttributeProto* p) override { return p->t(); } void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_s(value); } + void AttributeProto__set_f(ONNX_NAMESPACE::AttributeProto* p, const float& value) override { return p->set_f(value); } void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { return p->set_i(value); } const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) override { return p->s(); } void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); } @@ -450,6 +487,7 @@ struct ProviderHostImpl : ProviderHost { ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_value_info(); } ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_initializer(); } ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) override { return p->add_node(); } + std::string* GraphProto__mutable_name(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_name(); } ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); } void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; } @@ -467,6 +505,7 @@ struct ProviderHostImpl : ProviderHost { 
ONNX_NAMESPACE::GraphProto* ModelProto__mutable_graph(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_graph(); } void ModelProto__set_ir_version(ONNX_NAMESPACE::ModelProto* p, int64_t value) override { p->set_ir_version(value); } + ONNX_NAMESPACE::StringStringEntryProtos* ModelProto__mutable_metadata_props(ONNX_NAMESPACE::ModelProto* p) override { return p->mutable_metadata_props(); }; // NodeProto (wrapped) std::unique_ptr NodeProto__construct() override { return std::make_unique(); } @@ -481,19 +520,34 @@ struct ProviderHostImpl : ProviderHost { void TensorProto__operator_delete(ONNX_NAMESPACE::TensorProto* p) override { delete p; } void TensorProto__operator_assign(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto& v) override { *p = v; } bool TensorProto__has_name(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_name(); } + void TensorProto__set_name(ONNX_NAMESPACE::TensorProto* p, const ::std::string& name) override { p->set_name(name); } + const ::std::string& TensorProto__name(const ONNX_NAMESPACE::TensorProto* p) override { return p->name(); } int TensorProto__dims_size(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims_size(); } const ONNX_NAMESPACE::int64s& TensorProto__dims(const ONNX_NAMESPACE::TensorProto* p) override { return p->dims(); } + void TensorProto__add_dims(ONNX_NAMESPACE::TensorProto* p, int64_t value) override { p->add_dims(value); } bool TensorProto__has_data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_data_location(); } int TensorProto__data_location(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_location(); } bool TensorProto__has_raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->has_raw_data(); } const std::string& TensorProto__raw_data(const ONNX_NAMESPACE::TensorProto* p) override { return p->raw_data(); } + std::string* TensorProto__mutable_raw_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_raw_data(); } + int32_t TensorProto__data_type(const ONNX_NAMESPACE::TensorProto* p) override { return p->data_type(); } + void TensorProto__set_data_type(ONNX_NAMESPACE::TensorProto* p, int32_t type) override { p->set_data_type(type); } bool TensorProto_DataType_IsValid(int value) override { return ONNX_NAMESPACE::TensorProto::DataType_IsValid(value); } void TensorProto__CopyFrom(ONNX_NAMESPACE::TensorProto* p, const ONNX_NAMESPACE::TensorProto* other) override { p->CopyFrom(*other); } + ONNX_NAMESPACE::StringStringEntryProtos* TensorProto__mutable_external_data(ONNX_NAMESPACE::TensorProto* p) override { return p->mutable_external_data(); }; + void TensorProto__clear_float_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_float_data(); } + void TensorProto__clear_int32_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int32_data(); } + void TensorProto__clear_string_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_string_data(); } + void TensorProto__clear_int64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_int64_data(); } + void TensorProto__clear_double_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_double_data(); } + void TensorProto__clear_uint64_data(ONNX_NAMESPACE::TensorProto* p) override { p->clear_uint64_data(); } // TensorProtos (wrapped) ONNX_NAMESPACE::TensorProto* TensorProtos__Add(ONNX_NAMESPACE::TensorProtos* p) override { return p->Add(); } + int TensorProtos__size(ONNX_NAMESPACE::TensorProtos* p) override { return p->size(); } + ONNX_NAMESPACE::TensorProto& 
TensorProtos__at(ONNX_NAMESPACE::TensorProtos* p, int index) override { return p->at(index); }; // TensorShapeProto_Dimension (wrapped) int TensorShapeProto_Dimension__value_case(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->value_case(); } @@ -503,6 +557,8 @@ struct ProviderHostImpl : ProviderHost { void TensorShapeProto_Dimension__set_dim_value(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, int64_t value) override { return p->set_dim_value(value); } bool TensorShapeProto_Dimension__has_dim_value(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_value(); } bool TensorShapeProto_Dimension__has_dim_param(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) override { return p->has_dim_param(); } + const std::string& TensorShapeProto_Dimension__denotation(const ONNX_NAMESPACE::TensorShapeProto_Dimension* p) const override { return p->denotation(); } + void TensorShapeProto_Dimension__set_denotation(ONNX_NAMESPACE::TensorShapeProto_Dimension* p, const std::string& value) override { return p->set_denotation(value); } // TensorShapeProto_Dimensions (wrapped) std::unique_ptr TensorShapeProto_Dimensions__begin(const ONNX_NAMESPACE::TensorShapeProto_Dimensions* p) override { @@ -531,6 +587,90 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } + static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { + auto* shape = ctx.getAttribute("shape"); + auto* data_type = ctx.getAttribute("data_type"); + int32_t elemType = 0; + if (data_type->s() == "float32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + } else if (data_type->s() == "int8") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT8; + } else if (data_type->s() == "uint8") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } else if (data_type->s() == "int32") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT32; + } else if (data_type->s() == "int64") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } else if (data_type->s() == "int1") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_BOOL; + } else if (data_type->s() == "bfloat16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16; + } else if (data_type->s() == "float16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; + } else if (data_type->s() == "uint16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT16; + } else if (data_type->s() == "int16") { + elemType = ONNX_NAMESPACE::TensorProto_DataType_INT16; + } else { + return; + } + ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); + if (shape != nullptr) { + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } else { + // set scalar type. 
+ ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); + } + } + + static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0); + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); + } + + static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) { + auto num_inputs = ctx.getNumInputs(); + + // Run inferencing on the subgraph + auto* graphInferencer = ctx.getGraphAttributeInferencer("body"); + + std::vector input_data; + std::vector subgraph_input_types; + for (size_t i = 0; i < num_inputs; ++i) { + input_data.push_back(ctx.getInputData(i)); + subgraph_input_types.push_back(ctx.getInputType(i)); + } + + auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data); + for (size_t i = 0, end = output_types.size(); i < end; ++i) { + *ctx.getOutputType(i) = *output_types[i]; + } + } + void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override { + auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + const auto& domain_to_version_map = domain_instance.Map(); + if (domain_to_version_map.find(domain) == domain_to_version_map.end()) { + domain_instance.AddDomainToVersion(domain, 1, 1000); + } + auto schema = CreateSchema(domain, {op}); + switch (type) { + case 1: + schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference); + break; + case 2: + schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference); + break; + case 3: + schema.TypeAndShapeInferenceFunction(xir_shape_infer); + break; + default: + break; + } + ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION); + } + // ConfigOptions (wrapped) std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) override { return p->GetConfigEntry(config_key); @@ -762,6 +902,9 @@ struct ProviderHostImpl : ProviderHost { void Node__ToProto(const Node* p, ONNX_NAMESPACE::NodeProto& proto, bool update_subgraphs = false) override { p->ToProto(proto, update_subgraphs); } const NodeAttributes& Node__GetAttributes(const Node* p) noexcept override { return p->GetAttributes(); } + void Node__AddAttribute(Node* p, const ::std::string& attr_name, const ONNX_NAMESPACE::GraphProto& value) override { + p->AddAttribute(attr_name, value); + } size_t Node__GetInputEdgesCount(const Node* p) noexcept override { return p->GetInputEdgesCount(); } size_t Node__GetOutputEdgesCount(const Node* p) noexcept override { return p->GetOutputEdgesCount(); } @@ -770,13 +913,19 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr Node__OutputNodesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputNodesBegin()); } std::unique_ptr Node__OutputNodesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputNodesEnd()); } - + std::unique_ptr Node__InputEdgesBegin(const Node* p) noexcept override { + return std::make_unique(p->InputEdgesBegin()); + } + std::unique_ptr Node__InputEdgesEnd(const Node* p) noexcept override { + return std::make_unique(p->InputEdgesEnd()); + } std::unique_ptr Node__OutputEdgesBegin(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesBegin()); } std::unique_ptr Node__OutputEdgesEnd(const Node* p) noexcept override { return std::make_unique(p->OutputEdgesEnd()); } void Node__ForEachDef(const Node* p, std::function func, bool include_missing_optional_defs) override { p->ForEachDef(func, 
std::move(include_missing_optional_defs)); } const std::unordered_map>& Node__GetAttributeNameToMutableSubgraphMap(Node* p) noexcept override { return p->GetAttributeNameToMutableSubgraphMap(); } std::unordered_map> Node__GetAttributeNameToSubgraphMap(const Node* p) const override { return p->GetAttributeNameToSubgraphMap(); } + int Node__NodeType(const Node* p) const noexcept override { return int(p->NodeType()); } // NodeArg (wrapped) const std::string& NodeArg__Name(const NodeArg* p) noexcept override { return p->Name(); } @@ -785,6 +934,7 @@ struct ProviderHostImpl : ProviderHost { const NodeArgInfo& NodeArg__ToProto(const NodeArg* p) noexcept override { return p->ToProto(); } bool NodeArg__Exists(const NodeArg* p) const noexcept override { return p->Exists(); } const ONNX_NAMESPACE::TypeProto* NodeArg__TypeAsProto(const NodeArg* p) noexcept override { return p->TypeAsProto(); } + Status NodeArg__OverrideTypesHelper(NodeArg* p, const ONNX_NAMESPACE::TypeProto& input_type, int32_t input_tensor_elem_type, int32_t current_tensor_elem_type, bool override_types) override { return p->OverrideTypesHelper(input_type, input_tensor_elem_type, current_tensor_elem_type, override_types); }; // NodeAttributes (wrapped) std::unique_ptr NodeAttributes__construct() override { return std::make_unique(); } @@ -807,12 +957,20 @@ struct ProviderHostImpl : ProviderHost { } void NodeAttributes__insert(NodeAttributes* p, const NodeAttributes& v) override { return p->insert(v.begin(), v.end()); } void NodeAttributes__emplace(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->emplace(k, v); } + void NodeAttributes__insert_or_assign(NodeAttributes* p, const std::string& k, const ONNX_NAMESPACE::AttributeProto& v) override { p->insert_or_assign(k, v); } void NodeAttributes__reserve(NodeAttributes* p, size_t size) override { p->reserve(size); } // Model (wrapped) + std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, + const logging::Logger& logger) override { + return std::make_unique(model_proto, model_path, nullptr, logger); + } void Model__operator_delete(Model* p) override { delete p; } Graph& Model__MainGraph(Model* p) override { return p->MainGraph(); } std::unique_ptr Model__ToProto(Model* p) override { return std::make_unique(p->ToProto()); } + std::unique_ptr Model__ToGraphProtoWithExternalInitializers(Model* p, const std::string& external_file_name, const PathString& file_path, size_t initializer_size_threshold) override { return std::make_unique(p->ToGraphProtoWithExternalInitializers(external_file_name, file_path, initializer_size_threshold)); }; + const ModelMetaData& Model__MetaData(const Model* p) const noexcept override { return p->MetaData(); }; + Status Model__Load(const PathString& file_path, /*out*/ ONNX_NAMESPACE::ModelProto& model_proto) override { return Model::Load(file_path, model_proto); } // Graph (wrapped) std::unique_ptr Graph__CreateGraphViewer(const Graph* p) override { return std::make_unique(*p); } @@ -832,6 +990,12 @@ struct ProviderHostImpl : ProviderHost { void Graph__SetOutputs(Graph* p, gsl::span outputs) override { p->SetOutputs(outputs); } const std::vector& Graph__GetInputs(const Graph* p) noexcept override { return p->GetInputs(); } + std::vector Graph__Nodes(const Graph* p) override { + auto& node_reference = p->Nodes(); + std::vector nodes(p->NumberOfNodes(), nullptr); + std::transform(node_reference.begin(), node_reference.end(), nodes.begin(), [](const Node& n) { 
return &n; }); + return nodes; + } bool Graph__GetInitializedTensor(const Graph* p, const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) override { return p->GetInitializedTensor(tensor_name, value); } const Node* Graph__ParentNode(const Graph* p) const override { return p->ParentNode(); } @@ -841,6 +1005,40 @@ struct ProviderHostImpl : ProviderHost { const Path& Graph__ModelPath(const Graph* p) const override { return p->ModelPath(); } const std::vector& Graph__GetInputsIncludingInitializers(const Graph* p) const noexcept override { return p->GetInputsIncludingInitializers(); } bool Graph__IsSubgraph(const Graph* p) override { return p->IsSubgraph(); } + const Node* Graph__GetProducerNode(const Graph* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } + const Model& Graph__GetModel(const Graph* p) override { return p->GetModel(); } + void Graph__ReverseDFSFrom(const Graph* p, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& comp, + const std::function& stop) const override { + p->ReverseDFSFrom(from, enter, leave, comp, stop); + } + Graph& Graph__SetGraphResolveNeeded(Graph* p) override { return p->SetGraphResolveNeeded(); } + void Graph__RemoveInitializedTensor(Graph* p, const std::string& tensor_name) override { p->RemoveInitializedTensor(tensor_name); } + + std::vector Graph__GetConsumerNodes(const Graph* p, const std::string& node_arg_name) const override { + return p->GetConsumerNodes(node_arg_name); + } + void Graph__AddEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) override { + p->AddEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void Graph__RemoveEdge(Graph* p, NodeIndex src_node_index, NodeIndex dst_node_index, int src_arg_index, + int dst_arg_index) override { + p->RemoveEdge(src_node_index, dst_node_index, src_arg_index, dst_arg_index); + } + void Graph__RemoveNode(Graph* p, NodeIndex index) override { p->RemoveNode(index); } + Node& Graph__FuseSubGraph(Graph* p, const IndexedSubGraph& sub_graph, const std::string& fused_node_name) override { + return p->FuseSubGraph(sub_graph, fused_node_name); + } + void Graph__UpdateProducerNode(Graph* p, const std::string& node_arg_name, NodeIndex node_index) override { + p->UpdateProducerNode(node_arg_name, node_index); + } + const ONNX_NAMESPACE::TensorProto* Graph__GetConstantInitializer(const Graph* p, const std::string& name, bool check_outer_scope) const override { + return p->GetConstantInitializer(name, check_outer_scope); + } + const InitializedTensorSet& Graph__GetAllInitializedTensors(const Graph* p) override { return p->GetAllInitializedTensors(); } int Graph__MaxNodeIndex(const Graph* p) const noexcept override { return p->MaxNodeIndex(); } Node* Graph__GetNode(Graph* p, NodeIndex node_index) noexcept override { return p->GetNode(node_index); } const Node* Graph__GetNode(const Graph* p, NodeIndex node_index) const override { return p->GetNode(node_index); } @@ -885,11 +1083,14 @@ struct ProviderHostImpl : ProviderHost { void GraphViewer__ToProto(const GraphViewer* p, ONNX_NAMESPACE::GraphProto& graph_proto, bool include_initializers, bool include_outer_scope_args) noexcept override { GraphViewerToProto(*p, graph_proto, include_initializers, include_outer_scope_args); } + const Node* GraphViewer__GetProducerNode(const GraphViewer* p, const std::string& node_arg_name) const override { return p->GetProducerNode(node_arg_name); } 
// Path (wrapped) PathString Path__ToPathString(const Path* p) noexcept override { return p->ToPathString(); } const std::vector& Path__GetComponents(const Path* p) noexcept override { return p->GetComponents(); } bool Path__IsEmpty(const Path* p) noexcept override { return p->IsEmpty(); } + std::unique_ptr Path__construct() override { return std::make_unique(); } + void Path__operator_delete(ONNX_NAMESPACE::Path* p) override { delete p; }; // OpKernel (direct) const Node& OpKernel__Node(const OpKernel* p) override { return p->OpKernel::Node(); } @@ -1280,6 +1481,7 @@ static ProviderLibrary s_library_rocm(LIBRARY_PREFIX ORT_TSTR("onnxruntime_provi #endif ); static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_dnnl") LIBRARY_EXTENSION); +static ProviderLibrary s_library_vitisai(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_vitisai") LIBRARY_EXTENSION); static ProviderLibrary s_library_openvino(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_openvino") LIBRARY_EXTENSION); static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX ORT_TSTR("onnxruntime_providers_tensorrt") LIBRARY_EXTENSION #ifndef _WIN32 @@ -1308,6 +1510,7 @@ static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX ORT_TSTR("onnxruntime_p void UnloadSharedProviders() { s_library_dnnl.Unload(); + s_library_vitisai.Unload(); s_library_openvino.Unload(); s_library_tensorrt.Unload(); s_library_cuda.Unload(); @@ -1524,6 +1727,10 @@ std::shared_ptr DnnlProviderFactoryCreator::Create(co return s_library_dnnl.Get().CreateExecutionProviderFactory(dnnl_options); } +std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { + return s_library_vitisai.Get().CreateExecutionProviderFactory(&provider_options); +} + ProviderInfo_OpenVINO* GetProviderInfo_OpenVINO() { return reinterpret_cast(s_library_openvino.Get().GetInfo()); } @@ -2416,3 +2623,34 @@ ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProvid ORT_UNUSED_PARAMETER(ptr); #endif } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + API_IMPL_BEGIN + onnxruntime::ProviderOptions provider_options; + for (size_t i = 0; i != num_keys; ++i) { + if (provider_options_keys[i] == nullptr || provider_options_keys[i][0] == '\0' || + provider_options_values[i] == nullptr || provider_options_values[i][0] == '\0') { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Provider options key/value cannot be empty"); + } + + // arbitrary length to validate the key/value. adjust if/when needed. + // TODO: are any other input validation checks required here (and in the other functions that process + // provider options)? 
+ if (strlen(provider_options_keys[i]) > 1024 || strlen(provider_options_values[i]) > 1024) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Maximum string length for a provider options key/value is 1024."); + } + + provider_options[provider_options_keys[i]] = provider_options_values[i]; + } + auto factory = onnxruntime::VitisAIProviderFactoryCreator::Create(provider_options); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_VitisAI: Failed to load shared library"); + } + + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END +} diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 964355956b4ab..ade1d96d617fb 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -148,12 +148,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options, &(options->value))); #else status = create_not_supported_status(); -#endif - } else if (strcmp(provider_name, "VitisAI") == 0) { -#if defined(USE_VITISAI) - options->provider_factories.push_back(VitisAIProviderFactoryCreator::Create(provider_options)); -#else - status = create_not_supported_status(); #endif } else { ORT_UNUSED_PARAMETER(options); @@ -499,4 +493,14 @@ ORT_API_STATUS_IMPL(OrtApis::GetROCMProviderOptionsAsString, ORT_API(void, OrtApis::ReleaseROCMProviderOptions, _Frees_ptr_opt_ OrtROCMProviderOptions* ptr) { ORT_UNUSED_PARAMETER(ptr); } + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(provider_options_keys); + ORT_UNUSED_PARAMETER(provider_options_values); + ORT_UNUSED_PARAMETER(num_keys); + return CreateNotEnabledStatus("VitisAI"); +} #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 8e13982ca6861..9c36eb635ffcf 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -982,7 +982,7 @@ std::unique_ptr CreateExecutionProviderInstance( return onnxruntime::TVMProviderFactoryCreator::Create(info)->CreateProvider(); #endif } else if (type == kVitisAIExecutionProvider) { -#if USE_VITISAI +#ifdef USE_VITISAI const auto it = provider_options_map.find(type); if (it == provider_options_map.end()) { LOGS_DEFAULT(FATAL) << "cannot find provider options for VitisAIExecutionProvider"; diff --git a/setup.py b/setup.py index e94165fdf9b05..67d34b065ad03 100644 --- a/setup.py +++ b/setup.py @@ -298,6 +298,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_shared.so"]) libs.extend(["libonnxruntime_providers_dnnl.so"]) libs.extend(["libonnxruntime_providers_openvino.so"]) + libs.extend(["libonnxruntime_providers_vitisai.so"]) libs.append(providers_cuda_or_rocm) libs.append(providers_tensorrt_or_migraphx) libs.append(providers_cann) @@ -310,6 +311,7 @@ def finalize_options(self): libs.extend(["libonnxruntime_providers_dnnl.dylib"]) libs.extend(["libonnxruntime_providers_tensorrt.dylib"]) libs.extend(["libonnxruntime_providers_cuda.dylib"]) + libs.extend(["libonnxruntime_providers_vitisai.dylib"]) if nightly_build: 
libs.extend(["libonnxruntime_pywrapper.dylib"]) else: @@ -320,6 +322,7 @@ def finalize_options(self): libs.extend(["onnxruntime_providers_tensorrt.dll"]) libs.extend(["onnxruntime_providers_openvino.dll"]) libs.extend(["onnxruntime_providers_cuda.dll"]) + libs.extend(["onnxruntime_providers_vitisai.dll"]) # DirectML Libs libs.extend(["DirectML.dll"]) if nightly_build: From eb0ce86db8eaa5a207dbd433ba8d0630dc60db4c Mon Sep 17 00:00:00 2001 From: Linnea May Date: Thu, 1 Feb 2024 10:26:37 -0800 Subject: [PATCH 019/207] [DML] Resize 18 & 19 (#19071) ### Description Register resize-18 and -19, which will be lit up automatically when dml feature level bumps up to 6300. It's worth noting that DML has a different implementation for antialias than does ORT CPU. DML does iterative downsampling whenever the scale factor is less than 0.5. This is equivalent to performing resize with a variable-sized input window (also equivalent to mip mapping). ORT takes a different approach, using the same convolution approach as PIL. The two implementations approach each other in certain cases (with iota-generated data) but they usually aren't perfectly equivalent. ### Motivation and Context --------- Co-authored-by: Linnea May --- .../src/Operators/DmlOperatorResize.cpp | 54 ++++++++--- .../src/Operators/OperatorRegistration.cpp | 8 +- .../dml/OperatorAuthorHelper/Attributes.h | 1 + .../OperatorAuthorHelper/OperatorHelper.cpp | 92 +++++++++--------- .../dml/OperatorAuthorHelper/OperatorHelper.h | 50 +++++++++- .../OperatorAuthorHelper/OperatorVersions.h | 2 + .../providers/cpu/tensor/resize_op_test.cc | 97 ++++++++++++++++++- 7 files changed, 242 insertions(+), 62 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index f332fac9d3a09..b7cceb1d1d998 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -9,11 +9,12 @@ namespace Dml constexpr NameAndIndex coordinateTransformationModes[] = { {"half_pixel", 0}, - {"pytorch_half_pixel", 1}, - {"align_corners", 2}, - {"asymmetric", 3}, - {"tf_half_pixel_for_nn", 4}, - {"tf_crop_and_resize", 5}, + {"half_pixel_symmetric", 1}, + {"pytorch_half_pixel", 2}, + {"align_corners", 3}, + {"asymmetric", 4}, + {"tf_half_pixel_for_nn", 5}, + {"tf_crop_and_resize", 6}, }; constexpr NameAndIndex nearestNeighborRoundingModes[] = @@ -50,7 +51,7 @@ void ComputePixelOffsetsAndScales( uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue; ML_CHECK_VALID_ARGUMENT( - !regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/, + !regionOfInterest.empty() || coordinateTransformationModeValue != 6 /*tf_crop_and_resize*/, "Resize expects 'roi' tensor for 'tf_crop_and_resize' mode." 
); @@ -88,6 +89,18 @@ void ComputePixelOffsetsAndScales( break; case 1: + // coordinate_transformation_mode is "half_pixel_symmetric", + // adjustment = output_width_int / output_width + // center = input_width / 2 + // offset = center * (1 - adjustment) + // x_original = (x + 0.5) / scale - (0.5 - offset) + // x_original = (x + 0.5) / scale - (0.5 - [(input_width / 2) * (1 - (output_width_int / output_width))]) + // output_width can be fractional when calculated with scale factor + inputPixelOffset = 0.5f - float((inputDimensions[i] / 2.0f) * (1.0f - outputDimensions[i] / (scales[i] * inputDimensions[i]))); + outputPixelOffset = -0.5; + break; + + case 2: // if coordinate_transformation_mode is "pytorch_half_pixel", // x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0 if (inputDimensions[i] <= 1) @@ -104,7 +117,7 @@ void ComputePixelOffsetsAndScales( } break; - case 2: + case 3: // if coordinate_transformation_mode is "align_corners", // x_original = x_resized * (length_original - 1) / (length_resized - 1) inputPixelOffset = 0.0; @@ -121,7 +134,7 @@ void ComputePixelOffsetsAndScales( } break; - case 3: + case 4: // if coordinate_transformation_mode is "asymmetric", // x_original = x_resized / scale inputPixelOffset = 0.0; @@ -129,7 +142,7 @@ void ComputePixelOffsetsAndScales( // Keep existing scales. break; - case 4: + case 5: // if coordinate_transformation_mode is "tf_half_pixel_for_nn", // x_original = (x_resized + 0.5) / scale inputPixelOffset = 0.0; @@ -137,7 +150,7 @@ void ComputePixelOffsetsAndScales( // Keep existing scales. break; - case 5: + case 6: // if coordinate_transformation_mode is "tf_crop_and_resize", // x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) // : 0.5 * (start_x + end_x) * (length_original - 1) @@ -177,7 +190,7 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper public: // Resample a multidimensional image to a new size. DmlOperatorResize(const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t opsetVersion) - : DmlOperator(kernelCreationContext), + : DmlOperator(kernelCreationContext), ResizeHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription(), opsetVersion) { ML_CHECK_VALID_ARGUMENT(!m_scales.empty(), "Resize/Upsample expect scales, either a 2nd input tensors or 'scales' attribute."); @@ -250,6 +263,11 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "NEAREST"); DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode); + +#if DML_TARGET_VERSION >= 0x6300 + const int antialiased = kernelCreationContext.GetOptionalAttribute(AttrName::Antialiased, 0); +#endif + // Map ONNX to DML's mode using offsets and rounding direction. // These offsets are in addition to the coordinate transform offsets. 
DML_AXIS_DIRECTION roundingDirection = DML_AXIS_DIRECTION_DECREASING; @@ -289,7 +307,12 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); +#if DML_TARGET_VERSION >= 0x6300 + DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {}; + operatorDesc.Antialiased = static_cast(antialiased); +#else DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {}; +#endif operatorDesc.InputTensor = inputDescs.data(); operatorDesc.OutputTensor = outputDescs.data(); operatorDesc.InterpolationMode = interpolationMode; @@ -298,8 +321,11 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper operatorDesc.DimensionCount = gsl::narrow_cast(paddedScales.size()); operatorDesc.InputPixelOffsets = inputPixelOffsets.data(); operatorDesc.OutputPixelOffsets = outputPixelOffsets.data(); - +#if DML_TARGET_VERSION >= 0x6300 + DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc }; +#else DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc }; +#endif SetDmlOperatorDesc(opDesc, kernelCreationContext); } }; @@ -342,6 +368,10 @@ void CALLBACK QueryResize(IMLOperatorSupportQueryContextPrivate* context, bool* DML_OP_DEFINE_CREATION_FUNCTION(Resize10, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Resize13, VersionedKernel); +#if DML_TARGET_VERSION >= 0x6300 +DML_OP_DEFINE_CREATION_FUNCTION(Resize18, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Resize19, VersionedKernel); +#endif DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample10, VersionedKernel); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 7b53a1102c5a7..9c136ed8c9484 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -508,6 +508,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu); #if DML_TARGET_VERSION >= 0x6300 DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); +DML_OP_EXTERN_CREATION_FUNCTION(Resize18); +DML_OP_EXTERN_CREATION_FUNCTION(Resize19); #endif DML_OP_EXTERN_CREATION_FUNCTION(Shape); @@ -600,6 +602,7 @@ constexpr static std::array supportedTypeListSigned constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32}; constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */}; constexpr static std::array supportedTypeListResize13 = supportedTypeListResize11; +constexpr static std::array supportedTypeListResize18 = supportedTypeListResize11; constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 }; constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; constexpr static std::array supportedTypeListRoiAlign = {SupportedTensorDataTypes::Float16to32, 
SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64 }; @@ -973,7 +976,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 10, Resize, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported, requiredConstantCpuInputs(1) /*scales*/)}, {REG_INFO_VER( 11, Resize, typeNameListTwo, supportedTypeListResize11, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, {REG_INFO_VER( 13, Resize, typeNameListTwo, supportedTypeListResize13, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, - +#if DML_TARGET_VERSION >= 0x6300 + {REG_INFO_VER( 18, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO_VER( 19, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, +#endif // Activation Functions {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 13, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 9c5d021f52b36..287deaa513f64 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -12,6 +12,7 @@ namespace AttrName static constexpr const char* AllowZero = "allowzero"; static constexpr const char* Alpha = "alpha"; static constexpr const char* AlignCorners = "align_corners"; + static constexpr const char* Antialiased = "antialias"; static constexpr const char* AutoPad = "auto_pad"; static constexpr const char* Axes = "axes"; static constexpr const char* Axis = "axis"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 83c6748fadd35..317f5ebcbc3e1 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -56,6 +56,18 @@ namespace OperatorHelper } } + template + void ExpandToAxes(/*inout*/ std::vector& originalValues, gsl::span axes, std::vector expanded) + { + assert(originalValues.size() == axes.size()); + // Fill in roi and scales/sizes + for (size_t i = 0; i < axes.size(); i++) + { + expanded[axes[i]] = originalValues[i]; + } + originalValues = std::move(expanded); + } + float CastFloat16ToFloat32(uint16_t input) { // Promote float16m10e5s1 to float32m23e8s1. 
@@ -144,50 +156,6 @@ namespace OperatorHelper } #pragma warning(pop) - void ReadCpuLocalTensorIntoInt32( - const MLOperatorTensor& tensor, - std::vector& result - ) - { - result.clear(); - ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); - - const std::vector& tensorDimensions = tensor.GetShape(); - const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); - - switch (tensor.GetTensorDataType()) - { - case MLOperatorTensorDataType::Int32: - { - const int32_t* data = tensor.GetData(); - result.assign(data, data + elementCount); - } - break; - - case MLOperatorTensorDataType::Int64: - { - const int64_t* data = tensor.GetData(); - result.reserve(elementCount); - - // Use clamped cast rather than static_cast/narrow_cast, - // because it's not uncommon for a model to specify a - // 64-bit INTMAX constant as a sentinel value to mean - // the largest possible value (even though the actual - // dimension values come nowhere close to that, far - // less than 32-bit INTMAX). - for (auto d : gsl::make_span(data, data + elementCount)) - { - result.push_back(clamp_cast(d)); - } - } - break; - - default: - ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); - break; - } - } - void ReadCpuLocalTensorIntoFloat32( const MLOperatorTensor& tensor, std::vector& result @@ -2461,7 +2429,8 @@ namespace OperatorHelper { auto& attributes = kernelInformation.GetAttributes(); m_inputDimensions = shapeInformation.GetInputTensorShape(0); - std::vector outputSizes; + std::vector outputSizes; + std::vector axes; if (opsetVersion >= 11) { @@ -2478,7 +2447,38 @@ namespace OperatorHelper if (kernelInformation.IsInputValid(3)) { MLOperatorTensor outputSizesTensor = kernelInformation.GetConstantInputTensor(3); - ReadCpuLocalTensorIntoInt32(outputSizesTensor, /*out*/ outputSizes); + ReadCpuLocalTensorIntoInt32(outputSizesTensor, /*out*/ outputSizes); + } + + axes = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Axes); + // Handle possible axes input + if (opsetVersion >= 18 && !axes.empty()) + { + uint32_t dimCount = gsl::narrow_cast(m_inputDimensions.size()); + HandleEmptyAxes(/*inout*/ axes, m_inputDimensions, false); + HandleNegativeAxes(/*inout*/ axes, dimCount); + + // Taken from https://github.com/onnx/onnx/blob/3d69db8fd16873d68e7033479467f9478562a12d/onnx/reference/ops/op_resize.py#L303 + if (!m_scales.empty()) + { + std::vector defaultScales(dimCount, 1.0f); + ExpandToAxes(/*inout*/ m_scales, axes, defaultScales); + } + if (!outputSizes.empty()) + { + ExpandToAxes(/*inout*/ outputSizes, axes, m_inputDimensions); + } + if (!m_regionOfInterest.empty()) + { + std::vector defaultRois(dimCount, 0.0f); + defaultRois.resize(dimCount * 2, 1.0f); + size_t numAxes = axes.size(); + for (size_t i = 0; i < axes.size(); i++) + { + defaultRois[axes[i]] = m_regionOfInterest[i]; + defaultRois[axes[i + dimCount]] = m_regionOfInterest[i + numAxes]; + } + } } } else if (opsetVersion >= 9) diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index d4b44f6fa8a9d..1b2521a86613f 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -120,10 +120,54 @@ double CastToFloat64(MLOperatorTensorDataType tensorDataType, const void* p); void ReadScalarTensorData(const MLOperatorTensor& tensor, /*out*/ void* data, size_t dataByteSize); int64_t 
ReadScalarTensorCastToInt64(const MLOperatorTensor& tensor); double ReadScalarTensorCastToFloat64(const MLOperatorTensor& tensor); - -void ReadCpuLocalTensorIntoInt32(const MLOperatorTensor& tensor, std::vector& result); void ReadCpuLocalTensorIntoFloat32(const MLOperatorTensor& tensor, std::vector& result); +template +void ReadCpuLocalTensorIntoInt32( + const MLOperatorTensor& tensor, + std::vector& result + ) +{ + result.clear(); + ML_CHECK_VALID_ARGUMENT(tensor.IsCpuData(), "Tensor must be CPU Tensor."); + + const std::vector& tensorDimensions = tensor.GetShape(); + const uint32_t elementCount = ComputeElementCountFromDimensions(tensorDimensions); + + switch (tensor.GetTensorDataType()) + { + case MLOperatorTensorDataType::Int32: + { + result.resize(elementCount); + const int32_t* data = tensor.GetData(); + std::transform(data, data + elementCount, result.begin(), [](auto v) {return static_cast(v); }); + } + break; + + case MLOperatorTensorDataType::Int64: + { + const int64_t* data = tensor.GetData(); + result.reserve(elementCount); + + // Use clamped cast rather than static_cast/narrow_cast, + // because it's not uncommon for a model to specify a + // 64-bit INTMAX constant as a sentinel value to mean + // the largest possible value (even though the actual + // dimension values come nowhere close to that, far + // less than 32-bit INTMAX). + for (auto d : gsl::make_span(data, data + elementCount)) + { + result.push_back(clamp_cast(d)); + } + } + break; + + default: + ML_INVALID_ARGUMENT("Expecting CPU local tensor of type int32 or int64."); + break; + } +} + class EdgeShapes { public: @@ -1613,6 +1657,8 @@ using ShapeInferenceHelper_Tile = TileHelper; using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; using ShapeInferenceHelper_Resize11 = VersionedOpsetHelper; using ShapeInferenceHelper_Resize13 = VersionedOpsetHelper; +using ShapeInferenceHelper_Resize18 = VersionedOpsetHelper; +using ShapeInferenceHelper_Resize19 = VersionedOpsetHelper; using ShapeInferenceHelper_OneHot = OneHotHelper; using ShapeInferenceHelper_Sqrt = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 57cb009b72ebc..e725ba085113d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -408,11 +408,13 @@ namespace OperatorHelper static const int sc_sinceVer_Split = 18; static const int sc_sinceVer_LpPool = 18; static const int sc_sinceVer_Col2Im = 18; + static const int sc_sinceVer_Resize = 18; } namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Resize = 19; static const int sc_sinceVer_Pad = 19; static const int sc_sinceVer_Cast = 19; static const int sc_sinceVer_CastLike = 19; diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index f473c98ca713e..10f02349a24d5 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -1870,6 +1870,8 @@ void TestAntialiasing(std::map attributes, test.AddAttribute("extrapolation_value", std::stof(v)); } else if (k == "roi") { roi = parse_attr(v, 0.0f); + } else if (k == "antialias") { + test.AddAttribute("antialias", std::stoll(v)); } else { throw std::invalid_argument("Unknown attribute"); } @@ -1894,6 +1896,9 @@ void 
TestAntialiasing(std::map attributes, } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + } std::vector X(16); std::iota(X.begin(), X.end(), 1.f); @@ -1912,7 +1917,6 @@ TEST(ResizeOpTest, Antialias_Bilinear_ExcludeOutside) { 12.1f, 13.3f, 14.5f}; TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); } - TEST(ResizeOpTest, Antialias_Bilinear_Scale_Is_All_1) { std::vector X(3 * 4 * 5 * 6); std::iota(X.begin(), X.end(), 1.f); @@ -2009,6 +2013,9 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { } TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + } std::vector X(16 * 4); std::iota(X.begin(), X.end(), 0.f); std::vector Y = {5.7272725f, 6.9545455f, 8.181818f, 10.636364f, 11.863636f, @@ -2030,6 +2037,9 @@ TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) { } TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + } std::vector X(16 * 4 * 4); std::iota(X.begin(), X.end(), 0.f); { @@ -2118,6 +2128,9 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { } TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + } std::vector X(256); std::iota(X.begin(), X.end(), 0.0f); @@ -2231,5 +2244,87 @@ TEST(ResizeOpTest, Antialias_Use_Extrapolation) { }, {4, 4, 4}, X, {3, 3, 3}, Y); } + +TEST(ResizeOpTest, Antialias_Large_half_pixel) { + std::vector X{0.f, 1.f, 2.f, 3.f, 4.f, 5.f}; + std::vector Y = {1.f, 4.f}; + std::vector roi{}; + std::vector scales{}; + std::vector output_shape{1, 1, 2, 1}; + + OpTester test("Resize", 18); + + test.AddAttribute("exclude_outside", 0LL); + test.AddAttribute("antialias", 1LL); + test.AddAttribute("mode", "linear"); + + test.AddInput("X", {1, 1, 6, 1}, X); + test.AddInput("roi", {int64_t(roi.size())}, roi); + test.AddInput("", {0}, scales); + test.AddInput("sizes", {4}, output_shape); + + // Have absolute tolerance because ort is slightly different results. + // DML implementation is equivalent to resize with variable input window size while ORT using a convolution approach. + // Absolute error is for ORT CPU. 
+ test.AddOutput("Y", output_shape, Y, false, /*rel_error*/ 0.0f, /*abs_error*/ 0.12f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); +} + +// Test without anti-aliasing for better comparison with DirectML +TEST(ResizeOpTest, Axes_and_Scale_18) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {3.5f, 4.8333335f, 6.1666665f, 8.833333f, 10.166667f, 11.5f, 14.166667f, + 15.5f, 16.833334f, 24.833334f, 26.166666f, 27.5f, 30.166666f, 31.5f, + 32.833332f, 35.5f, 36.833332f, 38.166668f, 46.166668f, 47.5f, 48.833332f, + 51.5f, 52.833332f, 54.166668f, 56.833332f, 58.166668f, 59.5}; + std::vector roi{}; + std::vector scales{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}; + std::vector output_shape{1, 1, 3, 3, 3}; + std::vector axes{2, 3, 4}; + + OpTester test("Resize", 18); + + test.AddAttribute("exclude_outside", 0LL); + test.AddAttribute>("axes", axes); + test.AddAttribute("antialias", 0LL); + test.AddAttribute("mode", "linear"); + + test.AddInput("X", {1, 1, 4, 4, 4}, X); + test.AddInput("roi", {int64_t(roi.size())}, roi); + test.AddInput("scales", {int64_t(scales.size())}, scales, true); + + test.AddOutput("Y", output_shape, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); +} + +TEST(ResizeOpTest, Axes_and_Size_18) { + std::vector X(16 * 4); + std::iota(X.begin(), X.end(), 0.f); + std::vector Y = {3.5f, 4.8333335f, 6.1666665f, 8.833333f, 10.166667f, 11.5f, 14.166667f, + 15.5f, 16.833334f, 24.833334f, 26.166666f, 27.5f, 30.166666f, 31.5f, + 32.833332f, 35.5f, 36.833332f, 38.166668f, 46.166668f, 47.5f, 48.833332f, + 51.5f, 52.833332f, 54.166668f, 56.833332f, 58.166668f, 59.5}; + std::vector roi{}; + std::vector scales{}; + std::vector output_shape{1, 1, 3, 3, 3}; + std::vector axes{2, 3, 4}; + + OpTester test("Resize", 18); + + test.AddAttribute("exclude_outside", 0LL); + test.AddAttribute>("axes", axes); + test.AddAttribute("antialias", 0LL); + test.AddAttribute("mode", "linear"); + + test.AddInput("X", {1, 1, 4, 4, 4}, X); + test.AddInput("roi", {int64_t(roi.size())}, roi); + test.AddInput("", {0}, scales); + test.AddInput("sizes", {3}, {3, 3, 3}); + + test.AddOutput("Y", output_shape, Y); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); +} + } // namespace test } // namespace onnxruntime From 0fa88bc8106acae37aa1876065e16fbf320b5b01 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 1 Feb 2024 15:04:29 -0800 Subject: [PATCH 020/207] Multi-partition support for context binary cache feature (#18865) ### Description Multi-partition support for context binary cache feature 1. In QNNEP create the list of EPContext nodes if ep_context_enable is enabled, so that it can dump the model with multiple partitions 2. Extend context loading part to support multiple EPContext nodes ### Motivation and Context It only support single partition before this changes. There's graph partition limitation for context cache feature after this change. 
--- .../core/framework/graph_partitioner.cc | 112 ++- .../qnn/builder/onnx_ctx_model_helper.cc | 204 +++--- .../qnn/builder/onnx_ctx_model_helper.h | 71 +- .../qnn/builder/qnn_backend_manager.cc | 15 +- .../qnn/builder/qnn_backend_manager.h | 3 +- .../core/providers/qnn/builder/qnn_model.cc | 3 +- .../providers/qnn/qnn_execution_provider.cc | 153 ++-- .../test/providers/qnn/qnn_basic_test.cc | 88 --- .../test/providers/qnn/qnn_ep_context_test.cc | 657 ++++++++++++++++++ .../test/providers/qnn/qnn_test_utils.h | 2 +- .../test/providers/qnn/simple_op_htp_test.cc | 375 ---------- 11 files changed, 919 insertions(+), 764 deletions(-) create mode 100644 onnxruntime/test/providers/qnn/qnn_ep_context_test.cc diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 07b465c80745a..90ee8a46f66a9 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -645,6 +645,10 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); } + if (all_ep_context_nodes.size() < 1) { + return Status::OK(); + } + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { for (auto& node : all_ep_context_nodes) { if (node_name == node->Name()) { @@ -656,76 +660,70 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers onnxruntime::PathString context_cache_path; PathString model_pathstring = graph.ModelPath().ToPathString(); - if (all_ep_context_nodes.size() > 0) { - if (!ep_context_path.empty()) { - context_cache_path = ToPathString(ep_context_path); - } else if (!model_pathstring.empty()) { - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); - } - { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { #ifdef _WIN32 - std::wifstream fs(context_cache_path); + std::wifstream fs(context_cache_path); #else - std::ifstream fs(context_cache_path); + std::ifstream fs(context_cache_path); #endif - ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); - } + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } - Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - graph.DomainToVersionMap(), {}, logger); - auto& ep_graph = ep_context_model.MainGraph(); - ep_graph.SetDescription(graph.Description()); - - // Set inputs outputs explicitly to make sure the order is same as the user model. 
- auto inputs = graph.GetInputs(); - auto outputs = graph.GetOutputs(); - - InlinedVector ep_graph_inputs; - ep_graph_inputs.reserve(inputs.size()); - for (auto& input : inputs) { - auto input_arg = graph.GetNodeArg(input->Name()); - auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); - ep_graph_inputs.push_back(&ep_graph_input_arg); - } + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); - InlinedVector ep_graph_outputs; - ep_graph_outputs.reserve(outputs.size()); - for (auto& output : outputs) { - auto output_arg = graph.GetNodeArg(output->Name()); - auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); - ep_graph_outputs.push_back(&ep_graph_output_arg); - } + // Set inputs outputs explicitly to make sure the order is same as the user model. + auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); - ep_graph.SetInputs(ep_graph_inputs); - ep_graph.SetOutputs(ep_graph_outputs); + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } - for (const auto& node : graph.Nodes()) { - // the fused node and EPContext node has same node name - auto ep_context_node = get_ep_context_node(node.Name()); - // Use EpContext node created by the EPs if name matched, otherwise use node from original model - if (ep_context_node.first) { - ep_graph.AddNode(*ep_context_node.second); - } else { - ep_graph.AddNode(node); - } - } + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } - // handle initializers - for (const auto& input : graph.GetInputsIncludingInitializers()) { - const ONNX_NAMESPACE::TensorProto* initializer = nullptr; - if (graph.GetInitializedTensor(input->Name(), initializer)) { - // There initializer could have duplicates so make sure we only add once - const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; - if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { - ep_graph.AddInitializedTensor(*initializer); - } - } + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); } + } - ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + // handle initializers + for (const auto& initialized_tensor : graph.GetAllInitializedTensors()) { + if (ep_graph.GetNodeArg(initialized_tensor.first) != nullptr) { + ep_graph.AddInitializedTensor(*initialized_tensor.second); + } } + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, 
context_cache_path)); + return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 5d3f406f50612..c2e71081b898e 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -12,34 +12,60 @@ namespace onnxruntime { namespace qnn { -Status IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, - bool& is_qnn_ctx_model) { - is_qnn_ctx_model = false; - for (const auto& fused_node_graph : fused_nodes_and_graphs) { - const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); - // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type - int count = 0; - for (const auto& node : graph_viewer.Nodes()) { - if (EPCONTEXT_OP == node.OpType()) { - is_qnn_ctx_model = true; +bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer) { + // It's an Onnx model with Qnn context cache binary if it has a node with EPContext type and the source is QNN or QNNExecutionProvider. + for (const auto& node : graph_viewer.Nodes()) { + if (EPCONTEXT_OP == node.OpType()) { + NodeAttrHelper node_helper(node); + std::string cache_source = node_helper.Get(SOURCE, ""); + + std::transform(cache_source.begin(), + cache_source.end(), + cache_source.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + if (cache_source == "qnnexecutionprovider" || cache_source == "qnn") { + return true; } - ++count; } - ORT_RETURN_IF(is_qnn_ctx_model && count > 1, "Fused graph should only has 1 single EPContext node."); } - return Status::OK(); + return false; } -bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer) { - // It's an Onnx model with Qnn context cache binary if it only has a node with EPContext type - for (const auto& node : graph_viewer.Nodes()) { - if (EPCONTEXT_OP == node.OpType()) { +bool IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs) { + for (const auto& fused_node_graph : fused_nodes_and_graphs) { + const onnxruntime::GraphViewer& graph_viewer(fused_node_graph.filtered_graph); + bool has_qnn_ep_context_node = GraphHasEpContextNode(graph_viewer); + if (has_qnn_ep_context_node) { return true; } } return false; } +Status GetMainContextNode(const std::vector& fused_nodes_and_graphs, + QnnBackendManager* qnn_backend_manager, + const logging::Logger& logger, + int& main_context_pos, + std::unordered_map>& qnn_models) { + main_context_pos = -1; + for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) { + const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph); + const auto& ep_context_node = graph_viewer.Nodes().begin(); + ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node."); + qnn_models.emplace(ep_context_node->Name(), + std::make_unique(logger, qnn_backend_manager)); + NodeAttrHelper node_helper(*ep_context_node); + int64_t is_main_context = node_helper.Get(MAIN_CONTEXT, static_cast(0)); + if (1 == is_main_context) { + main_context_pos = static_cast(i); + } + } + + ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1"); + return Status::OK(); +} + Status CreateNodeArgs(const std::vector& names, const std::unordered_map& tensor_info_table, std::vector& node_args, @@ -60,32 +86,18 @@ Status CreateNodeArgs(const std::vector& names, return Status::OK(); } -Status 
GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger) { - using namespace onnxruntime; - std::shared_ptr model; - ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger)); - const auto& graph = model->MainGraph(); - return GetEpContextFromGraph(GraphViewer(graph), - ctx_onnx_model_path, - qnn_backend_manager, - qnn_model); -} - -Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model) { - const auto& node = graph_viewer.Nodes().begin(); - NodeAttrHelper node_helper(*node); +Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, + const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + std::unordered_map>& qnn_models) { + ORT_RETURN_IF_NOT(EPCONTEXT_OP == main_context_node.OpType(), "Should only filter in the EPContext node."); + NodeAttrHelper node_helper(main_context_node); bool is_embed_mode = node_helper.Get(EMBED_MODE, true); if (is_embed_mode) { const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), - qnn_model); + qnn_models); } std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path(); @@ -133,23 +145,16 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, cache_file.close(); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), - qnn_model); + qnn_models); } -Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, +Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, - bool is_qnn_ctx_model, - bool is_ctx_cache_file_exist, QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger) { - Status status; - if (is_qnn_ctx_model) { - status = GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model); - } else if (is_ctx_cache_file_exist) { - status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger); - } + std::unordered_map>& qnn_models) { + Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models); + // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. 
", status.ErrorMessage()); } @@ -157,88 +162,37 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, return Status::OK(); } -Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::string& graph_partition_name, - std::string& cache_source, - const logging::Logger& logger) { - using namespace onnxruntime; - std::shared_ptr model; - ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger)); - const auto& graph = GraphViewer(model->MainGraph()); - const auto& node = graph.Nodes().begin(); - NodeAttrHelper node_helper(*node); - model_name = graph.Name(); - model_description = graph.Description(); - graph_partition_name = node_helper.Get(PARTITION_NAME, ""); - cache_source = node_helper.Get(SOURCE, ""); - - return Status::OK(); -} - -bool IsContextCacheFileExists(const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path) { - // Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default +// Figure out the real context cache file path +// return true if context cache file exists +bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, + const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + onnxruntime::PathString& context_cache_path) { + // always try the path set by user first, it's the only way to set it if load model from memory if (!customer_context_cache_path.empty()) { context_cache_path = ToPathString(customer_context_cache_path); - } else if (!model_pathstring.empty()) { - context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } else if (!model_pathstring.empty()) { // model loaded from file + if (is_qnn_ctx_model) { + // it's a context cache model, just use the model path + context_cache_path = model_pathstring; + } else if (!model_pathstring.empty()) { + // this is not a normal Onnx model, no customer path, create a default path for generation: model_path + _ctx.onnx + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } } return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path); } -Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path, - const std::string& model_name, - const std::string& model_description, - const std::string& graph_partition_name, - const logging::Logger& logger) { - std::string model_name_from_ctx_cache; - std::string model_description_from_ctx_cache; - std::string graph_partition_name_from_ctx_cache; - std::string cache_source; - auto status = GetMetadataFromEpContextModel(context_cache_path, - model_name_from_ctx_cache, - model_description_from_ctx_cache, - graph_partition_name_from_ctx_cache, - cache_source, - logger); - if (!status.IsOK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel."); - } - - // The source attribute from the skeleton onnx file indicate whether it's generated from QNN toolchain or ORT - if (cache_source != kQnnExecutionProvider) { - LOGS(logger, VERBOSE) << "Context binary cache is not generated by Ort."; - return Status::OK(); - } - - if (model_name != model_name_from_ctx_cache || - model_description != model_description_from_ctx_cache || - graph_partition_name != graph_partition_name_from_ctx_cache) { - std::string message = onnxruntime::MakeString("Metadata 
mismatch. onnx: ", - model_name, " ", model_description, " ", graph_partition_name, - " vs epcontext: ", - model_name_from_ctx_cache, " ", - model_description_from_ctx_cache, " ", - graph_partition_name_from_ctx_cache); - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, message); - } - - return Status::OK(); -} - -Status GenerateCtxCacheOnnxModel(Model* model, - unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, - bool qnn_context_embed_mode, - const logging::Logger& logger) { +Status CreateEPContextNodes(Model* model, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const onnxruntime::PathString& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger) { auto& graph = model->MainGraph(); using namespace ONNX_NAMESPACE; diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index ba6fe23ecd56e..b1360b4e576fa 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -28,59 +28,44 @@ static const std::string EP_SDK_VER = "ep_sdk_version"; static const std::string PARTITION_NAME = "partition_name"; static const std::string SOURCE = "source"; -Status IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, - bool& is_qnn_ctx_model); +bool GraphHasEpContextNode(const onnxruntime::GraphViewer& graph_viewer); -bool IsQnnCtxModel(const onnxruntime::GraphViewer& graph_viewer); +bool IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs); + +Status GetMainContextNode(const std::vector& fused_nodes_and_graphs, + QnnBackendManager* qnn_backend_manager, + const logging::Logger& logger, + int& main_context_pos, + std::unordered_map>& qnn_models); Status CreateNodeArgs(const std::vector& names, const std::unordered_map& tensor_info_table, std::vector& node_args, onnxruntime::Graph& graph); -bool IsContextCacheFileExists(const std::string& customer_context_cache_path, - const onnxruntime::PathString& model_pathstring, - onnxruntime::PathString& context_cache_path); - -Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const logging::Logger& logger); +bool ValidateContextCacheFilePath(bool is_qnn_ctx_model, + const std::string& customer_context_cache_path, + const onnxruntime::PathString& model_pathstring, + onnxruntime::PathString& context_cache_path); -Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer, - const onnxruntime::PathString& ctx_onnx_model_path, - QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model); +Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, + const onnxruntime::PathString& ctx_onnx_model_path, + QnnBackendManager* qnn_backend_manager, + std::unordered_map>& qnn_models); -Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer, +Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, - bool is_qnn_ctx_model, - bool is_ctx_cache_file_exist, QnnBackendManager* qnn_backend_manager, - QnnModel& qnn_model, - const 
logging::Logger& logger); - -Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path, - const std::string& model_name, - const std::string& model_description, - const std::string& graph_partition_name, - const logging::Logger& logger); - -Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path, - std::string& model_name, - std::string& model_description, - std::string& graph_partition_name, - std::string& cache_source, - const logging::Logger& logger); - -Status GenerateCtxCacheOnnxModel(Model* model, - unsigned char* buffer, - uint64_t buffer_size, - const std::string& sdk_build_version, - const std::vector& fused_nodes_and_graphs, - const std::unordered_map>& qnn_models, - const onnxruntime::PathString& context_cache_path, - bool qnn_context_embed_mode, - const logging::Logger& logger); + std::unordered_map>& qnn_models); + +Status CreateEPContextNodes(Model* model, + unsigned char* buffer, + uint64_t buffer_size, + const std::string& sdk_build_version, + const std::vector& fused_nodes_and_graphs, + const std::unordered_map>& qnn_models, + const onnxruntime::PathString& context_cache_path, + bool qnn_context_embed_mode, + const logging::Logger& logger); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 973b81d337c81..5f0b87c7cb9d7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -517,7 +517,8 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 return context_buffer; } -Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) { +Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + std::unordered_map>& qnn_models) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || nullptr == qnn_sys_interface_.systemContextFree; @@ -550,8 +551,9 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t graphs_info = binary_info->contextBinaryInfoV2.graphs; } - ORT_RETURN_IF(graph_count > 1, "Load from Qnn cached context only support 1 sub-graph."); - ORT_RETURN_IF(graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); + ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); + LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count << ", EPContext node count: " << qnn_models.size(); + ORT_RETURN_IF(graph_count != qnn_models.size(), "Graph count from QNN context not equal to EPContext node count."); ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, "Invalid function pointer for contextCreateFromBinary."); @@ -571,7 +573,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t // More work to support multiple partition, how to map the graph name in compile to qnn graph name // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile - ORT_RETURN_IF_ERROR(qnn_model.DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); + for (uint32_t i = 0; i < graph_count; ++i) { + std::string graph_name(graphs_info[i].graphInfoV1.graphName); + auto qnn_model_pos = qnn_models.find(graph_name); + 
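+      // Each graph in the cached context binary is matched by name to an EPContext node; its graph info is then deserialized into that node's QnnModel.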
ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); + ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + } qnn_sys_interface_.systemContextFree(sys_ctx_handle); sys_ctx_handle = nullptr; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index f7b8947ab84bb..36375522b5a0a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -87,7 +87,8 @@ class QnnBackendManager { std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); - Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model); + Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + std::unordered_map>& qnn_models); Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 869d9326d9232..314cab4a36ca9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -97,7 +97,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, std::unordered_map node_unit_map; std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); - const auto& graph_name = graph_viewer.Name(); + // This name must be same with the EPContext node name + const auto& graph_name = fused_node.Name(); ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node)); QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger_, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 5f4e2e62f063e..b58f6e10df94c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -150,6 +150,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio LOGS_DEFAULT(VERBOSE) << "User specified context cache embed mode: " << qnn_context_embed_mode_; context_cache_path_cfg_ = session_options->config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + LOGS_DEFAULT(VERBOSE) << "User specified context cache path: " << context_cache_path_cfg_; } static const std::string BACKEND_PATH = "backend_path"; @@ -318,14 +319,27 @@ std::unordered_set QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, - bool load_from_cached_context, + bool is_qnn_ctx_model, const logging::Logger& logger) const { std::unordered_set supported_nodes{}; - // Enable Qnn context cache requires the whole graph partitioned to Qnn EP - // Blindly filter in all nodes if context cache is enabled - if (load_from_cached_context) { + // Filter in the EPContext node for QNN + if (is_qnn_ctx_model) { for (const auto& node : graph_viewer.Nodes()) { - supported_nodes.insert(&node); + NodeAttrHelper node_helper(node); + std::string cache_source = node_helper.Get(qnn::SOURCE, ""); + + std::transform(cache_source.begin(), + cache_source.end(), + cache_source.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) { + 
LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index() + << "] name: [" << node.Name() + << "] Operator type: [EPContext" + << "] index: [" << node.Index() << "]"; + supported_nodes.insert(&node); + } } return supported_nodes; } @@ -410,22 +424,11 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } const auto& logger = *GetLogger(); - bool load_from_cached_context = false; - bool is_qnn_ctx_model = qnn::IsQnnCtxModel(graph_viewer); - if (is_qnn_ctx_model) { - load_from_cached_context = true; - } + bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer); - // This is for case: QDQ model + Onnx Qnn context cache model - if (context_cache_enabled_ && !is_qnn_ctx_model) { - onnxruntime::PathString context_cache_path; - load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_, - graph_viewer.ModelPath().ToPathString(), - context_cache_path); - } - - // Load from cached context will load the QnnSystem lib and skip the Qnn context creation - auto rt = qnn_backend_manager_->SetupBackend(logger, load_from_cached_context); + // It will load the QnnSystem lib if is_qnn_ctx_model=true, and + // delay the Qnn context creation to Compile() using the cached context binary + auto rt = qnn_backend_manager_->SetupBackend(logger, is_qnn_ctx_model); if (Status::OK() != rt) { LOGS(logger, ERROR) << "QNN SetupBackend failed " << rt.ErrorMessage(); return result; @@ -443,7 +446,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer std::tie(node_unit_holder, node_unit_map) = GetAllNodeUnits(graph_viewer); const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), - load_from_cached_context, logger); + is_qnn_ctx_model, logger); // Helper function that returns a string that lists all unsupported nodes. // Ex: { name: mul_123, type: Mul }, {}, ... @@ -496,7 +499,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer if (partition && partition->sub_graph) { nodes_in_partition = partition->sub_graph->nodes.size(); - if (nodes_in_partition == 1) { + if (nodes_in_partition == 1 && !is_qnn_ctx_model) { const Node* node = graph_viewer.GetNode(partition->sub_graph->nodes[0]); if (!node) { @@ -516,7 +519,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer result.push_back(std::move(partition)); num_of_supported_nodes += nodes_in_partition; } - } + } // for } const size_t num_of_partitions = result.size(); @@ -527,7 +530,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer // Print list of unsupported nodes to the ERROR logger if the CPU EP // has been disabled for this inference session. 
- if (disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) { + if (!is_qnn_ctx_model && disable_cpu_ep_fallback_ && num_nodes_in_graph != num_of_supported_nodes) { LOGS(logger, ERROR) << "Unsupported nodes in QNN EP: " << get_unsupported_node_names(); } @@ -618,64 +621,76 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { const auto& logger = *GetLogger(); - Node& fused_node = fused_nodes_and_graphs[0].fused_node; - const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[0].filtered_graph); - bool is_qnn_ctx_model = false; - ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model)); + bool is_qnn_ctx_model = qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs); onnxruntime::PathString context_cache_path; - bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_, - graph_viewer.ModelPath().ToPathString(), - context_cache_path); - const std::string& model_name = graph_viewer.GetGraph().Name(); - const std::string& model_description = graph_viewer.GetGraph().Description(); - const std::string& graph_meta_id = fused_node.Name(); - if (fused_nodes_and_graphs.size() == 1 && !is_qnn_ctx_model && is_ctx_file_exist) { - ORT_RETURN_IF_ERROR(qnn::ValidateWithContextFile(context_cache_path, - model_name, - model_description, - graph_meta_id, - logger)); - } - - if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) { - ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); - std::unique_ptr qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); - // Load and execute from cached context if exist - ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer, + bool is_ctx_file_exist = false; + if (is_qnn_ctx_model || context_cache_enabled_) { + const onnxruntime::GraphViewer& graph_viewer_0(fused_nodes_and_graphs[0].filtered_graph); + is_ctx_file_exist = qnn::ValidateContextCacheFilePath(is_qnn_ctx_model, + context_cache_path_cfg_, + graph_viewer_0.ModelPath().ToPathString(), + context_cache_path); + } + + ORT_RETURN_IF(is_ctx_file_exist && !is_qnn_ctx_model && context_cache_enabled_, + "The inference session is created from normal ONNX model. And an EP context model file is provided and existed. 
", + "Please remove the EP context model manually if you want to re-generate it."); + + if (is_qnn_ctx_model) { + // Table, the node name is the graph_meta_id (old) created from user model which used to generate the EP context model + // for this session (created from an EP context model), the graph_meta_id is new + std::unordered_map> qnn_models; + + int main_context_pos = -1; + ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(), + logger, main_context_pos, qnn_models)); + + const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); + // Create QNN context from the cached binary, deserialize the QNN graph from the binary + ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, context_cache_path, - is_qnn_ctx_model, - is_ctx_file_exist, qnn_backend_manager_.get(), - *(qnn_model.get()), - logger)); - ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); - ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); - - // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] - // the name here should be same with context->node_name in compute_info - qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); + qnn_models)); + + for (auto fused_node_and_graph : fused_nodes_and_graphs) { + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + const auto& ep_context_node = graph_viewer.Nodes().begin(); + const Node& fused_node = fused_node_and_graph.fused_node; + const std::string& graph_meta_id = fused_node.Name(); + std::string key = ep_context_node->Name(); + ORT_RETURN_IF(qnn_models.find(key) == qnn_models.end(), key + " key name not exist in table qnn_models."); + auto qnn_model = std::move(qnn_models[key]); + ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node)); + ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); + + // fused node name is QNNExecutionProvider_QNN_[hash_id]_[id] + // the name here must be same with context->node_name in compute_info + qnn_models_.emplace(graph_meta_id, std::move(qnn_model)); + + ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); + } - ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger)); return Status::OK(); } ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger)); - if (context_cache_enabled_ && !is_qnn_ctx_model) { - ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); + // Generate QNN context model if it's QDQ model + context_cache_enabled=true + not exist already + if (!is_qnn_ctx_model && context_cache_enabled_ && !is_ctx_file_exist) { + // All partitioned graph share single QNN context, included in the same context binary uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); - ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(), - context_buffer.get(), - buffer_size, - qnn_backend_manager_->GetSdkVersion(), - fused_nodes_and_graphs, - qnn_models_, - context_cache_path, - qnn_context_embed_mode_, - logger)); + ORT_RETURN_IF_ERROR(qnn::CreateEPContextNodes(qnn_ep_context_model_.get(), + context_buffer.get(), + buffer_size, + qnn_backend_manager_->GetSdkVersion(), + fused_nodes_and_graphs, + qnn_models_, + context_cache_path, + qnn_context_embed_mode_, + logger)); } 
return Status::OK(); } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index c50b1002fa8c8..4e1aef2c40b2b 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -613,94 +613,6 @@ static GetTestModelFn BuildCastAddTestCase() { }; } -// Test that models with 2 inputs which has different data type can still generate the context binary -TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - // Add kMSDomain to cover contrib op like Gelu - const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; - - auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - - onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, - logging_manager.DefaultLogger()); - Graph& graph = model.MainGraph(); - ModelTestBuilder helper(graph); - BuildCastAddTestCase()(helper); - helper.SetGraphOutputs(); - ASSERT_STATUS_OK(model.MainGraph().Resolve()); - - // Serialize the model to a string. - std::string model_data; - model.ToProto().SerializeToString(&model_data); - - const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); - - const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); - - so.AppendExecutionProvider("QNN", provider_options); - - Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - -// Generate context cache model from the ONNX models with 2 inputs. -// The generated model should have same input order. -// The input ONNX model is created in the way that the model inputs order -// is different with the order in the graph (topological order). -// It cause issue if the generated model doesn't set the inputs/outputs explicitly. 
-TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - // Add kMSDomain to cover contrib op like Gelu - const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; - - auto& logging_manager = DefaultLoggingManager(); - logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); - - const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; - Ort::SessionOptions so; - so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); - so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); - - so.AppendExecutionProvider("QNN", provider_options); - - Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - std::shared_ptr model; - ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); - auto inputs = model->MainGraph().GetInputs(); - EXPECT_TRUE(inputs.size() == 2); - EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); - EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); - - // clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - // A repro of QC case 06838696, accuracy issue for Cast + Op (quantized) // the value pair(1, 0.00392156886) at index #1 don't match, // which is -0.996078 from 1 diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc new file mode 100644 index 0000000000000..b1f3b52e77553 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -0,0 +1,657 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/inference_session.h" + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +// in test_main.cc +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +// Create a model with Case + Add (quantized) +// input1 -> Add -> Q -> DQ \ +// FusedMatMul -> Q -> DQ -> output +// input2 -> Q -> DQ / +static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) { + return [single_ep_node](ModelTestBuilder& builder) { + // Creat non-quantized Add node1 + NodeArg* input1 = MakeTestInput(builder, TestInputDef({2, 2}, false, {0, 1, 0, 1})); + NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, {0, 0, 0, 0})); + + auto* add1_output = builder.MakeIntermediate(); + builder.AddNode("FusedMatMul", {input1, add1_ini_input2}, {add1_output}, kMSDomain); + + // Create quantized Add node2 + std::vector data = {0.0f, 0.0f, 1.0f, 0.0f}; + gsl::span data_range = gsl::make_span(data); + QuantParams q_parameter = GetDataQuantParams(data_range); + auto* add2_input1_qdq = AddQDQNodePair(builder, add1_output, q_parameter.scale, q_parameter.zero_point); + + NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, data)); + auto* add2_input2_qdq = AddQDQNodePair(builder, add2_input2, q_parameter.scale, q_parameter.zero_point); + + auto* add2_output = builder.MakeIntermediate(); + + builder.AddNode("Add", {add2_input1_qdq, add2_input2_qdq}, {add2_output}); + + if (single_ep_node) { + // add_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, add2_output, q_parameter.scale, q_parameter.zero_point); + } else { + auto* add3_input1_qdq = AddQDQNodePair(builder, add2_output, q_parameter.scale, q_parameter.zero_point); + NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, {0, 0, 0, 0})); + + auto* add3_output = builder.MakeIntermediate(); + builder.AddNode("FusedMatMul", {add3_input1_qdq, add3_ini_input2}, {add3_output}, kMSDomain); + + // Create quantized Add node4 + auto* add4_input1_qdq = AddQDQNodePair(builder, add3_output, q_parameter.scale, q_parameter.zero_point); + + NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, data)); + auto* add4_input2_qdq = AddQDQNodePair(builder, add4_input2, q_parameter.scale, q_parameter.zero_point); + + auto* add4_output = builder.MakeIntermediate(); + + builder.AddNode("Add", {add4_input1_qdq, add4_input2_qdq}, {add4_output}); + // add_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, add4_output, q_parameter.scale, q_parameter.zero_point); + } + }; +} + +void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + 
logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + BuildGraphWithQAndNonQ(single_ep_node)(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string context_binary_file = "./qnn_context_binary_multi_partition_test.onnx"; + std::remove(context_binary_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + int ep_context_node_count = 0; + int non_ep_context_node_count = 0; + std::shared_ptr ctx_model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), ctx_model, nullptr, DefaultLoggingManager().DefaultLogger())); + auto& ctx_graph = ctx_model->MainGraph(); + for (auto& node : ctx_graph.Nodes()) { + if (node.OpType() == "EPContext") { + ++ep_context_node_count; + } else { + ++non_ep_context_node_count; + } + } + + int expected_node_count = single_ep_node ? 1 : 2; + ASSERT_EQ(ep_context_node_count, expected_node_count); + ASSERT_EQ(non_ep_context_node_count, expected_node_count); + + Ort::SessionOptions so2; + // context file path is required if it's non-embed mode and the model is loaded from memory + so2.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so2.AppendExecutionProvider("QNN", provider_options); + + std::string ctx_model_data; + ctx_model->ToProto().SerializeToString(&ctx_model_data); + Ort::Session session2(*ort_env, ctx_model_data.data(), ctx_model_data.size(), so2); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +// Test that models with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary +// The generated Onnx model has 1 Add node and 1 EPContext node +TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { + bool single_ep_node = true; + QnnContextBinaryMultiPartitionTestBody(single_ep_node); +} + +// Test that models with 2 non-quantized Add nodes and 2 quantized Add nodes can still generate the context binary +// The generated Onnx model has 2 Add nodes and 1 EPContext nodes +TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) { + bool single_ep_node = false; + QnnContextBinaryMultiPartitionTestBody(single_ep_node); +} + +// Create a model with Case + Add (quantized) +// cast_input -> Cast -> Q -> DQ \ +// Add -> Q -> DQ -> output +// input2 -> Q -> DQ / +static GetTestModelFn BuildCastAddTestCase() { + return [](ModelTestBuilder& builder) { + // Creat Cast node int32 -> float32 + NodeArg* cast_input = MakeTestInput(builder, TestInputDef({2, 3}, false, {0, 1, 0, 1, 0, 1})); + + auto* cast_output = builder.MakeIntermediate(); + Node& cast_node = builder.AddNode("Cast", {cast_input}, {cast_output}); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + + // Create Add node + std::vector data = {0.0f, 0.0f, 1.0f, 0.0f, 1.0f, 0.0f}; + gsl::span data_range = gsl::make_span(data); 
+ QuantParams q_parameter = GetDataQuantParams(data_range); + auto* add_input1_qdq = AddQDQNodePair(builder, cast_output, q_parameter.scale, q_parameter.zero_point); + + NodeArg* add_input2 = MakeTestInput(builder, TestInputDef({2, 3}, false, data)); + auto* add_input2_qdq = AddQDQNodePair(builder, add_input2, q_parameter.scale, q_parameter.zero_point); + + auto* add_output = builder.MakeIntermediate(); + + builder.AddNode("Add", {add_input1_qdq, add_input2_qdq}, {add_output}); + + // add_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, add_output, q_parameter.scale, q_parameter.zero_point); + }; +} + +// Test that models with 2 inputs which has different data type can still generate the context binary +TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + onnxruntime::Model model("QNN_EP_TestModel", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + BuildCastAddTestCase()(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(model.MainGraph().Resolve()); + + // Serialize the model to a string. + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + const auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + + const std::string context_binary_file = "./qnn_context_binary_int32_fp32_inputs_test.onnx"; + std::remove(context_binary_file.c_str()); + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +// Generate context cache model from the ONNX models with 2 inputs. +// The generated model should have same input order. +// The input ONNX model is created in the way that the model inputs order +// is different with the order in the graph (topological order). +// It cause issue if the generated model doesn't set the inputs/outputs explicitly. 
+TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + auto inputs = model->MainGraph().GetInputs(); + EXPECT_TRUE(inputs.size() == 2); + EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); + EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +// Run QDQ model on HTP 3 times +// 1st run will generate the Qnn context cache onnx file +// 2nd run directly loads and run from Qnn context cache model +TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheEmbedModeTest) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./qnn_context_binary_test.onnx"; + std::remove(context_binary_file.c_str()); + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+ // 1st run will generate the Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // 2nd run directly loads and run from Qnn context cache model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + context_binary_file); + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +// Run QDQ model on HTP 3 times +// 1st run will generate the Onnx skeleton file + Qnn context cache binary file +// 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file +TEST_F(QnnHTPBackendTests, QnnContextBinaryCacheNonEmbedModeTest) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./testdata/qnn_context_cache_non_embed.onnx"; + std::string qnn_ctx_bin = "./testdata/qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); + + std::remove(context_binary_file.c_str()); + std::remove(qnn_ctx_bin.c_str()); + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); + + std::unordered_map session_option_pairs2; + // Need to set the context file path since TestQDQModelAccuracy load the model from memory + session_option_pairs2.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + // 2nd run directly loads and run from Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + context_binary_file, + session_option_pairs2); + + // load the model from file + std::vector buffer; + { + std::ifstream file(context_binary_file, std::ios::binary | std::ios::ate); + if (!file) + ORT_THROW("Error reading model"); + buffer.resize(narrow(file.tellg())); + file.seekg(0, std::ios::beg); + if (!file.read(buffer.data(), buffer.size())) + ORT_THROW("Error reading model"); + } + + Ort::SessionOptions so; // No need to set the context file path in so since it's load from file + so.AppendExecutionProvider("QNN", provider_options); +#ifdef _WIN32 + std::wstring ctx_model_file(context_binary_file.begin(), context_binary_file.end()); +#else + std::string ctx_model_file(context_binary_file.begin(), context_binary_file.end()); +#endif + Ort::Session session(*ort_env.get(), ctx_model_file.c_str(), so); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0); +} + +// Run QDQ model on HTP 2 times +// 1st run will generate the Onnx skeleton file + Qnn context cache binary file +// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH +TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::remove(context_binary_file.c_str()); + std::remove(context_bin.string().c_str()); + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_bin)); + // Delete the Qnn context cache binary file + EXPECT_TRUE(std::filesystem::remove(context_bin)); + + // loads and run from Onnx skeleton file + Qnn context cache binary file + onnx::ModelProto model_proto; + onnxruntime::Model qnn_ctx_model; + // Load the QNN context cache model from path specified + ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto)); + std::string qnn_ctx_model_data; + model_proto.SerializeToString(&qnn_ctx_model_data); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + std::string provider_type = kCpuExecutionProvider; + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { + const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + std::vector shape = {2, 3}; + NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); + auto* graph_output = helper.MakeOutput(shape); + Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); + ep_context_node.AddAttribute("embed_mode", static_cast(0)); + // The .. in the path will cause INVALID_GRAPH + ep_context_node.AddAttribute("ep_cache_context", external_bin_path); + ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); + ep_context_node.AddAttribute("source", "QNN"); + helper.SetGraphOutputs(); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + return model_data; +} + +// Create a model with EPContext node. Set the node property ep_cache_context has ".." 
+// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. Set the node property ep_cache_context has absolute path +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { +#if defined(_WIN32) + std::string external_ctx_bin_path = "D:/qnn_context.bin"; +#else + std::string external_ctx_bin_path = "/data/qnn_context.bin"; +#endif + std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Create a model with EPContext node. 
Set the node property ep_cache_context to empty string +// Verify that it return INVALID_GRAPH status +TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { + std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); +} + +// Run QDQ model on HTP with 2 inputs +// 1st run will generate the Qnn context cache onnx file +// 2nd run directly loads and run from Qnn context cache model +TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; + std::remove(context_binary_file.c_str()); + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + + const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); + const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Add"; + + // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs. + // 1st run will generate the Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // 2nd run directly loads and run from Qnn context cache model + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + context_binary_file); + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index bfe5bab318313..f4febd99ddae7 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -361,7 +361,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe model_proto.SerializeToString(&qnn_ctx_model_data); // Run QNN context cache model on QNN EP and collect outputs. 
InferenceModel(qnn_ctx_model_data, "qnn_ctx_model_logger", qnn_options, - expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep); + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs, is_qnn_ep, session_option_pairs); } else { // Run QDQ model on QNN EP and collect outputs. // Only need to apply the extra session options to this QDQ model inference on QNN EP diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 1e938ae9e334b..2f3b0e84a123e 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -723,381 +723,6 @@ TEST_F(QnnHTPBackendTests, SpaceToDepthOp_U16) { true); // Use com.microsoft domain for Q/DQ ops } -// Run QDQ model on HTP 3 times -// 1st run will generate the Qnn context cache onnx file -// 2nd run will load and run from QDQ model + Qnn context cache model -// 3rd run directly loads and run from Qnn context cache model -TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_binary_test.onnx"; - - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. - // 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - // 2nd run loads and run from QDQ model + Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // 3rd run directly loads and run from Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - -// Run QDQ model on HTP 3 times -// 1st run will generate the Onnx skeleton file + Qnn context cache binary file -// 2nd run will loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file -// 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file -TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = 
"libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. - // 1st run will generate the Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // Check the Qnn context cache binary file is generated - std::string qnn_ctx_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - EXPECT_TRUE(std::filesystem::exists(qnn_ctx_bin)); - - // 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // 3rd run directly loads and run from Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); - ASSERT_EQ(std::remove(qnn_ctx_bin.c_str()), 0); -} - -// Run QDQ model on HTP 2 times -// 1st run will generate the Onnx skeleton file + Qnn context cache binary file -// Then delete the context bin file to make the 2nd sesssion.Initialize() return the status with code INVALID_GRAPH -TEST_F(QnnHTPBackendTests, ContextBinaryCache_InvalidGraph) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); - - const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Atan"; - - // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
- // 1st run will generate the Onnx skeleton file + Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Check the Onnx skeleton file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - // Check the Qnn context cache binary file is generated - std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; - EXPECT_TRUE(std::filesystem::exists(context_bin)); - // Delete the Qnn context cache binary file - EXPECT_TRUE(std::filesystem::remove(context_bin)); - - // loads and run from Onnx skeleton file + Qnn context cache binary file - onnx::ModelProto model_proto; - onnxruntime::Model qnn_ctx_model; - // Load the QNN context cache model from path specified - ASSERT_STATUS_OK(qnn_ctx_model.Load(ToPathString(context_binary_file), model_proto)); - std::string qnn_ctx_model_data; - model_proto.SerializeToString(&qnn_ctx_model_data); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - std::string provider_type = kCpuExecutionProvider; - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); - - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - -std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { - const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; - auto& logging_manager = DefaultLoggingManager(); - onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, - logging_manager.DefaultLogger()); - Graph& graph = model.MainGraph(); - ModelTestBuilder helper(graph); - std::vector shape = {2, 3}; - NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, true, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); - auto* graph_output = helper.MakeOutput(shape); - Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); - ep_context_node.AddAttribute("embed_mode", static_cast(0)); - // The .. in the path will cause INVALID_GRAPH - ep_context_node.AddAttribute("ep_cache_context", external_bin_path); - ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); - ep_context_node.AddAttribute("source", "QNN"); - helper.SetGraphOutputs(); - std::string model_data; - model.ToProto().SerializeToString(&model_data); - - return model_data; -} - -// Create a model with EPContext node. Set the node property ep_cache_context has ".." 
-// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryRelativePathTest) { - std::string model_data = CreateQnnCtxModelWithNonEmbedMode("../qnn_context.bin"); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Create a model with EPContext node. Set the node property ep_cache_context has absolute path -// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryAbsolutePathTest) { -#if defined(_WIN32) - std::string external_ctx_bin_path = "D:/qnn_context.bin"; -#else - std::string external_ctx_bin_path = "/data/qnn_context.bin"; -#endif - std::string model_data = CreateQnnCtxModelWithNonEmbedMode(external_ctx_bin_path); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Create a model with EPContext node. Set the node property ep_cache_context to a file not exist -// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryFileNotExistTest) { - std::string model_data = CreateQnnCtxModelWithNonEmbedMode("qnn_context_not_exist.bin"); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Create a model with EPContext node. 
Set the node property ep_cache_context to empty string -// Verify that it return INVALID_GRAPH status -TEST_F(QnnHTPBackendTests, QnnContextBinaryFileEmptyStringTest) { - std::string model_data = CreateQnnCtxModelWithNonEmbedMode(""); - - SessionOptions so; - so.session_logid = "qnn_ctx_model_logger"; - RunOptions run_options; - run_options.run_tag = so.session_logid; - - InferenceSessionWrapper session_object{so, GetEnvironment()}; - - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); - ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); - // Verify the return status with code INVALID_GRAPH - ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::INVALID_GRAPH); -} - -// Run QDQ model on HTP with 2 inputs -// 1st run will generate the Qnn context cache onnx file -// 2nd run will load and run from QDQ model + Qnn context cache model -// 3rd run directly loads and run from Qnn context cache model -TEST_F(QnnHTPBackendTests, ContextBinary2InputsTest) { - ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif - const std::string context_binary_file = "./qnn_context_binary_2inputs_test.onnx"; - std::unordered_map session_option_pairs; - session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); - session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); - - const TestInputDef input_def1({1, 2, 3}, false, -10.0f, 10.0f); - const TestInputDef input_def2({1, 2, 3}, false, -10.0f, 10.0f); - const std::string op_type = "Add"; - - // Runs model with DQ-> Add-> Q and compares the outputs of the CPU and QNN EPs. 
- // 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // Make sure the Qnn context cache binary file is generated - EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); - - // 2nd run loads and run from QDQ model + Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - "", // context model file path, not required for this inference - session_option_pairs); - - // 3rd run directly loads and run from Qnn context cache model - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - BuildQDQOpTestCase(op_type, {input_def1, input_def2}, {}, {}), - provider_options, - 14, - ExpectedEPNodeAssignment::All, - QDQTolerance(), - logging::Severity::kERROR, - context_binary_file); - // Clean up - ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); -} - TEST_F(QnnHTPBackendTests, QuantAccuracyTest) { ProviderOptions provider_options; From 319481898c6db9b1268ed23e9480bdfeb0b53bf5 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Thu, 1 Feb 2024 15:25:33 -0800 Subject: [PATCH 021/207] Give a triton library missing warning instead of silently turn off (#19276) ### Description When USE_ORTMODULE_TRITON is set to 1 but there's no triton library, triton function is silently turned off. This adds a warning --- orttraining/orttraining/python/training/ortmodule/options.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index df3b078788d16..539859a0d58a6 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -379,6 +379,9 @@ def _override_from_env_vars(self): import triton # noqa: F401 except ImportError: pass + self._logger.warning( + "triton library missing. Please install triton with `pip install triton`. Triton feature will be off." + ) else: self.enable_triton = True From 1bdd7d949905e487e15b8e3bc30d3cd197313bdd Mon Sep 17 00:00:00 2001 From: He Li Date: Fri, 2 Feb 2024 07:39:03 +0800 Subject: [PATCH 022/207] Update oneDNN to v3.0.1 in order to support gcc 13 (#19344) ### Description Update the dependency of `oneDNN` to v3.0.1, which fixes a minor bug hindering gcc 13. ### Motivation and Context Referring to [oneDNN-1548](https://github.com/oneapi-src/oneDNN/issues/1548). - When building with `--use_dnnl` using gcc 13.x, it will fail due to this upstream issue. - This is fixed in `v3.0.1` [tag](https://github.com/oneapi-src/oneDNN/tree/v3.0.1) by [this commit](https://github.com/oneapi-src/oneDNN/commit/1d7971ce488da657e23f08488cdb6ef8e484c5e8). 
--- cmake/external/dnnl.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake index d7b70640781d0..9eb5fed7a1af6 100644 --- a/cmake/external/dnnl.cmake +++ b/cmake/external/dnnl.cmake @@ -2,7 +2,7 @@ include (ExternalProject) set(DNNL_URL https://github.com/oneapi-src/onednn.git) # If DNNL_TAG is updated, check if MKLML_VERSION and platform.cmake.patch need to be updated. -set(DNNL_TAG v3.0) +set(DNNL_TAG v3.0.1) if(WIN32) set(DNNL_SHARED_LIB dnnl.dll) From 8a2646ce604f7178ee5c57de5d66e7cca7356eb2 Mon Sep 17 00:00:00 2001 From: ironman Date: Fri, 2 Feb 2024 07:52:20 +0800 Subject: [PATCH 023/207] Metrics - llama-2 - Add package name and version to engine of onnxruntime (#19325) ### Description ### Motivation and Context --- .../python/tools/transformers/models/llama/benchmark_all.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index a8b84729b46be..c6d550d47cf4c 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -253,7 +253,7 @@ def save_results(results, filename): # Save results to csv with standard format records = [] for _, row in df.iterrows(): - if row["Engine"] == "optimum-ort": + if row["Engine"] in ["optimum-ort", "onnxruntime"]: record = BenchmarkRecord( row["Model Name"], row["Precision"], "onnxruntime", row["Device"], ort_pkg_name, ort_pkg_version ) From 13ad922e7ffd4eeeff5ca7199c4d5e7bf9703849 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 1 Feb 2024 16:18:14 -0800 Subject: [PATCH 024/207] Improve MatMulNBits test (#19378) ### Description The test creates millions of threads. This change is to avoid that by using an existing thread pool. 
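To make "using an existing thread pool" concrete: the idea is to borrow the intra-op pool that the shared test environment already owns instead of constructing a fresh `ThreadPool` inside the helper. A minimal sketch of that pattern, assuming `ort_env` is the test binary's global `std::unique_ptr<Ort::Env>`; the helper name `BorrowIntraOpThreadPool` is invented for illustration, while the accessor calls follow the diff below.

```cpp
// Sketch only: reuse the process-wide environment's intra-op thread pool
// rather than creating and destroying a ThreadPool on every call.
#include <memory>
#include "core/session/onnxruntime_cxx_api.h"
#include "core/session/ort_env.h"

extern std::unique_ptr<Ort::Env> ort_env;  // assumed: created once in the test main()

static onnxruntime::concurrency::ThreadPool* BorrowIntraOpThreadPool() {
  auto& ortenv = **ort_env.get();                         // environment behind the C++ wrapper
  return ortenv.GetEnvironment().GetIntraOpThreadPool();  // pool owned by that environment
}

// Callers pass the borrowed pointer straight to MlasQuantizeBlockwise /
// MlasDequantizeBlockwise; nothing here owns or tears down the pool.
```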
### Motivation and Context --- onnxruntime/test/contrib_ops/matmul_4bits_test.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index e0ed32630277e..2ad20eafc2ef1 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -14,6 +14,8 @@ #include "test/optimizer/graph_transform_test_builder.h" #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/ort_env.h" #include "core/util/qmath.h" #include @@ -21,12 +23,13 @@ #include "gtest/gtest.h" #include "gmock/gmock.h" +extern std::unique_ptr ort_env; namespace onnxruntime { + namespace test { static constexpr int QBits = 4; - void QuantizeDequantize(std::vector& raw_vals, std::vector& quant_vals, std::vector& scales, @@ -34,9 +37,8 @@ void QuantizeDequantize(std::vector& raw_vals, int32_t N, int32_t K, int32_t block_size) { - OrtThreadPoolParams to; - auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, - concurrency::ThreadPoolType::INTRA_OP); + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); MlasQuantizeBlockwise( quant_vals.data(), @@ -48,7 +50,7 @@ void QuantizeDequantize(std::vector& raw_vals, K, N, N, - tp.get()); + tp); // Note that input1_f_vals is NxK after dequant MlasDequantizeBlockwise( @@ -60,7 +62,7 @@ void QuantizeDequantize(std::vector& raw_vals, true, // columnwise quantization K, // number of rows N, // number of columns - tp.get()); + tp); } void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, From 3a2ab1963a195fe8df59e6220d3f191e5dfe80ee Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 2 Feb 2024 09:59:00 +0800 Subject: [PATCH 025/207] [js/webgpu] Refactor createTensorShapeVariables (#18883) --- .../jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 3 +-- .../webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts | 3 +-- .../webgpu/ops/3rd-party/conv_backprop_webgpu.ts | 2 +- .../webgpu/ops/3rd-party/matmul_packed_webgpu.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 13 ++++++++++--- js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 8 ++------ js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/expand.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 3 +-- js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 4 +--- js/web/lib/wasm/jsep/webgpu/ops/pad.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/pool.ts | 4 ++-- js/web/lib/wasm/jsep/webgpu/ops/reduce.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 7 ++----- js/web/lib/wasm/jsep/webgpu/ops/slice.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 5 ++--- js/web/lib/wasm/jsep/webgpu/ops/tile.ts | 6 ++---- js/web/lib/wasm/jsep/webgpu/ops/transpose.ts | 7 ++----- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 7 ++----- 22 files changed, 40 insertions(+), 64 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index bc39bd94e3072..fc2146068de70 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ 
b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -195,8 +195,7 @@ export const createConv2DMatMulProgramInfo = {type: DataType.int32, data: attributes.strides}, {type: DataType.int32, data: attributes.dilations} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index d18f8586dd071..b5b6a2a15cd8c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -204,8 +204,7 @@ export const createConv2DTransposeMatMulProgramInfo = {type: DataType.int32, data: pads} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, inputs[1].dims)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ba6776e9d8c94..846ad49c5222b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -269,7 +269,7 @@ export const createConvTranspose2DProgramInfo = {type: DataType.uint32, data: filterDims}, {type: DataType.uint32, data: dilations}, {type: DataType.uint32, data: effectiveFilterDims}, {type: DataType.int32, data: pads}, {type: DataType.uint32, data: inputChannelsPerGroup}, {type: DataType.uint32, data: outputChannelsPerGroup}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(inputs[1].dims) + ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims) ]; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index d9a8d59f731de..8abc27a24861d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -453,9 +453,7 @@ export const createMatmulProgramInfo = {type: DataType.int32, data: dimInner} ]; appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShapeTemp), - ...createTensorShapeVariables(bShapeTemp)); + programUniforms.push(...createTensorShapeVariables(outerDims, aShapeTemp, bShapeTemp)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; const hasBias = inputs.length > 2; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 51f0c76ed8824..a094fffe239c4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -180,9 +180,7 @@ const 
createBinaryOpProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, programUniforms: [ {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims), - ...createTensorShapeVariables(b.dims), - ...createTensorShapeVariables(outputShape), + ...createTensorShapeVariables(a.dims, b.dims, outputShape) ], }), }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 3de57d5ac7f7c..516094d0ef87b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -259,9 +259,16 @@ export const tensorTypeToWsglValueType = (type: DataType, components: 1|2|3|4 = return typeof mappedType === 'string' ? mappedType : mappedType[1]; }; -export const createTensorShapeVariables = (dims: readonly number[]): ProgramUniform[] => dims.length === 0 ? - [] : - [{type: DataType.uint32, data: dims}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dims)}]; +export const createTensorShapeVariables = (...dims: ReadonlyArray): ProgramUniform[] => { + const programUniforms: ProgramUniform[] = []; + dims.forEach(dim => { + if (dim.length !== 0) { + programUniforms.push( + {type: DataType.uint32, data: dim}, {type: DataType.uint32, data: ShapeUtil.computeStrides(dim)}); + } + }); + return programUniforms; +}; /** * A helper function to get maximum vector size for specified data length diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 3c2c3cc4e046c..8495f9040a1b6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -35,9 +35,7 @@ export const createGroupedConvProgramInfo = {type: DataType.uint32, data: outputChannelsPerGroup} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), - ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); @@ -134,9 +132,7 @@ export const createGroupedConvVectorizeProgramInfo = {type: DataType.int32, data: [attributes.pads[0], attributes.pads[1]]} ]; appendActivationUniformsData(attributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(xShape), ...createTensorShapeVariables(wShape), - ...createTensorShapeVariables(outputShapeInShader)); + programUniforms.push(...createTensorShapeVariables(xShape, wShape, outputShapeInShader)); const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts index fb17202cd042f..6080301d9946b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts @@ -55,7 +55,7 @@ const createCumsumProgramInfo = dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axis}, - ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape) + 
...createTensorShapeVariables(inputShape, inputShape) ] }), diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index f8fdb63160380..80ee906423e19 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -84,10 +84,8 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => ${assignment}`; }; - const programUniforms: ProgramUniform[] = [ - {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape) - ]; + const programUniforms: ProgramUniform[] = + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)]; return { name: 'Expand', shaderCache: {hint: `${outputShape.length}`, inputDependencies: ['rank']}, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts index a2d4e3d28f7c5..4ab6c175a67e2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts @@ -51,9 +51,7 @@ const createGatherElementsProgramInfo = {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, {type: DataType.uint32, data: axis} ]; - programUniforms.push(...createTensorShapeVariables(inputShape)); - programUniforms.push(...createTensorShapeVariables(indicesShape)); - programUniforms.push(...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputShape, indicesShape, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank']; // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index f2c71a9cd4188..5c31e6dd86c00 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -35,8 +35,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.int32, data: axisDimLimit}, - {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(inputs[1].dims), ...createTensorShapeVariables(outputShape) + {type: DataType.uint32, data: axis}, ...createTensorShapeVariables(inputs[0].dims, inputs[1].dims, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 2096b898b5d40..2f652dbd310ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -26,7 +26,7 @@ const createInstanceNormProgramInfo = const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'type', 'type']; const programUniforms: ProgramUniform[] = [{type: DataType.uint32, data: normSize}, {type: DataType.uint32, data: normPackedSize}]; - programUniforms.push(...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)); + programUniforms.push(...createTensorShapeVariables(inputShape, inputShape)); const getShaderSource = (shaderHelper: ShaderHelper) => { const x = inputVariable('x', inputs[0].dataType, inputShape.length, components); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts 
b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index b263451b99134..0c533974e2b26 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -34,9 +34,7 @@ export const createNaiveMatmulProgramInfo = {type: DataType.uint32, data: K} ]; appendActivationUniformsData(activationAttributes, programUniforms); - programUniforms.push( - ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape), - ...createTensorShapeVariables(bShape)); + programUniforms.push(...createTensorShapeVariables(outerDims, aShape, bShape)); if (hasBias) { programUniforms.push(...createTensorShapeVariables(inputs[2].dims)); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index 9f5e60773f080..236fc29fdf1ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -158,7 +158,7 @@ const createPadProgramInfo = (inputs: readonly TensorView[], attributes: PadAttr programUniforms.push({type: inputs[0].dataType, data: attributes.value}); } - programUniforms.push(...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(inputs[0].dims, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; const getShaderSource = (shaderHelper: ShaderHelper) => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 70b8acc3146a0..4e933573b9137 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -298,7 +298,7 @@ const createAveragePoolProgramInfo = } const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(outputShape, adjustedAttributes); - programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; return { name, @@ -370,7 +370,7 @@ const createMaxPoolProgramInfo = const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank']; const [programUniforms, uniforms, hasPads, pwStartEndNotZero, phStartEndNotZero] = getUniformAndPadInfo(outputShape, adjustedAttributes); - programUniforms.push(...createTensorShapeVariables(input.dims), ...createTensorShapeVariables(outputShape)); + programUniforms.push(...createTensorShapeVariables(input.dims, outputShape)); return { name, shaderCache: diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 123eb38a1fb93..e8205ba6fd928 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -100,10 +100,8 @@ export const createReduceProgramInfo = getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape) - ] + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputShape, outputShape)] }), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index edfd856aeb850..2c6b537de1f00 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts 
@@ -642,11 +642,8 @@ const createResizeProgramInfo = outputs: [{dims: outputShape, dataType: inputTensor.dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, programUniforms: [ - {type: DataType.uint32, data: outputSize}, - {type: DataType.float, data: scales}, - {type: DataType.float, data: roi}, - ...createTensorShapeVariables(inputShape), - ...createTensorShapeVariables(outputShape), + {type: DataType.uint32, data: outputSize}, {type: DataType.float, data: scales}, + {type: DataType.float, data: roi}, ...createTensorShapeVariables(inputShape, outputShape) ] }) }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 6baa634f69f82..a5e71f30e5966 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -157,7 +157,7 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice const programUniforms: ProgramUniform[] = [ {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: starts}, {type: DataType.int32, data: signs}, {type: DataType.uint32, data: steps}, - ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape) + ...createTensorShapeVariables(inputs[0].dims, outputShape) ]; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 0b703de2ffa1c..14d6f37927590 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -83,9 +83,8 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split outputs[i] = outputVariable(`output${i}`, dataType, outputShape); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } - programUniforms.push({type: DataType.uint32, data: sizeInSplitAxis}); - programUniforms.push(...createTensorShapeVariables(inputShape)); - outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape))); + programUniforms.push( + {type: DataType.uint32, data: sizeInSplitAxis}, ...createTensorShapeVariables(inputShape, ...outputShapes)); const getShaderSource = (shaderHelper: ShaderHelper) => ` ${ shaderHelper.registerUniform('input_size', 'u32') diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index b080767d2faac..f9728575fe072 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -79,10 +79,8 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf getRunData: () => ({ outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }), getShaderSource, }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 920da04398832..7ae801222b875 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -65,11 +65,8 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu return { outputs: [{dims: outputShape, dataType: inputs[0].dataType}], dispatchGroup: {x: 
Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ - {type: DataType.uint32, data: outputSize}, - ...createTensorShapeVariables(inputs[0].dims), - ...createTensorShapeVariables(outputShape), - ], + programUniforms: + [{type: DataType.uint32, data: outputSize}, ...createTensorShapeVariables(inputs[0].dims, outputShape)], }; }, getShaderSource, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index 51e8f56c229bd..cfee07a9239d7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -97,11 +97,8 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo => getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}, - programUniforms: [ - {type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC), - ...createTensorShapeVariables(dimsA), ...createTensorShapeVariables(dimsB), - ...createTensorShapeVariables(outputShape) - ], + programUniforms: + [{type: DataType.uint32, data: vecSize}, ...createTensorShapeVariables(dimsC, dimsA, dimsB, outputShape)], }), }; }; From b71be3c1e31abba44f95ad6bc568ab0499b34eaa Mon Sep 17 00:00:00 2001 From: zz002 Date: Fri, 2 Feb 2024 15:00:50 +0800 Subject: [PATCH 026/207] [VitisAI] Resolving compilation errors when using USE_VITISAI (#19368) ### Description Resolving compilation errors when using USE_VITISAI ### Motivation and Context There will be compilation errors when USE_VITISAI is enabled This is in addition to the #19058 Co-authored-by: Zhenze Wang --- onnxruntime/core/providers/provider_factory_creators.h | 4 ++++ .../core/providers/vitisai/vitisai_provider_factory.cc | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/provider_factory_creators.h b/onnxruntime/core/providers/provider_factory_creators.h index 6a4ab6a3d2113..42a58097e1635 100644 --- a/onnxruntime/core/providers/provider_factory_creators.h +++ b/onnxruntime/core/providers/provider_factory_creators.h @@ -78,6 +78,10 @@ #include "core/providers/tvm/tvm_provider_factory_creator.h" #endif +#if defined(USE_VITISAI) +#include "core/providers/vitisai/vitisai_provider_factory_creator.h" +#endif + #if defined(USE_XNNPACK) #include "core/providers/xnnpack/xnnpack_provider_factory_creator.h" #endif diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 5895e1973f231..dc34419ef936f 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -10,8 +10,6 @@ #include "./vitisai_execution_provider.h" #include "core/framework/execution_provider.h" -#include "core/session/abi_session_options_impl.h" - using namespace onnxruntime; namespace onnxruntime { From 09d5c1b56f94c9df216ef7de020b920fc52a0c14 Mon Sep 17 00:00:00 2001 From: petermcaughan Date: Thu, 1 Feb 2024 23:53:32 -0800 Subject: [PATCH 027/207] Fix DEBUG_GENERATION build (#19383) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Currently, ORT will fail a build when the flag DEBUG_GENERATION is set to 1 (used to debug BeamSearch and GreedySearch) in [console_dumper.h](https://github.com/microsoft/onnxruntime/blob/3b63d85c253c50099c70ba0db6c141b842bc7cda/onnxruntime/contrib_ops/cpu/utils/console_dumper.h#L12) with the 
following error: `onnxruntime/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h:270:15: error: ‘DumpScores’ was not declared in this scope` This is because it is defined in `logits_processor.cc`, and a debugging artifact was passed in an earlier PR where this function is called from `logits_processor.h` before it is defined [[link](https://github.com/microsoft/onnxruntime/blob/3a2ab1963a195fe8df59e6220d3f191e5dfe80ee/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h#L270)]. Builds with the flag have been broken since that PR was merged. This PR moves DumpScores() definition from `logits_processor.cc` to `logits_processor.h` so that all debug statements can be used correctly in `logits_processor.cc` and `logits_processor.h` and build succeeds with this debug flag. --------- Co-authored-by: Peter McAughan --- .../cpu/transformers/logits_processor.cc | 36 ------------------- .../cpu/transformers/logits_processor.h | 4 --- 2 files changed, 40 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index f39f090c78b0c..c74e9160cc43f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -17,14 +17,6 @@ namespace onnxruntime { namespace contrib { namespace transformers { -#ifdef DEBUG_GENERATION -template -void DumpScores(const char* name, const NextTokenScores& next_token_scores) { - std::cout << name << std::endl; - ORT_UNUSED_PARAMETER(next_token_scores); -} -#endif - // Interface for all scorers for beam search or beam sample. template MinLengthLogitsProcessor::MinLengthLogitsProcessor(int min_length, int eos_token_id) @@ -36,10 +28,6 @@ void MinLengthLogitsProcessor::Process(const ISequences* sequences, if (sequences->GetSequenceLength() < min_length_) { next_token_scores.SetScore(eos_token_id_, std::numeric_limits::lowest()); } - -#ifdef DEBUG_GENERATION - DumpScores("MinLengthLogitsProcessor", next_token_scores); -#endif } template @@ -68,10 +56,6 @@ void RepetitionPenaltyLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = (score < 0 ? 
score * penalty_ : score / penalty_); } } - -#ifdef DEBUG_GENERATION - DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores); -#endif } template @@ -109,10 +93,6 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, beam_token_scores[word_id] = std::numeric_limits::lowest(); } } - -#ifdef DEBUG_GENERATION - DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores); -#endif } template @@ -136,10 +116,6 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } - -#ifdef DEBUG_GENERATION - DumpScores("VocabMaskLogitsProcessor", next_token_scores); -#endif } template @@ -171,10 +147,6 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } } } - -#ifdef DEBUG_GENERATION - DumpScores("PrefixVocabMaskLogitsProcessor", next_token_scores); -#endif } template @@ -193,10 +165,6 @@ void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, *p /= temperature_; ++p; } - -#ifdef DEBUG_GENERATION - DumpScores("TemperatureLogitsProcessor", next_token_scores); -#endif } template @@ -218,10 +186,6 @@ void PresencePenaltyLogitsProcessor::Process(const ISequences*, for (size_t i = 0; i < next_token_scores.scores.size(); i++) { *p -= presence_mask_[i] * presence_penalty_; } - -#ifdef DEBUG_GENERATION - DumpScores("PresencePenaltyLogitsProcessor", next_token_scores); -#endif } void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 4688ff272cee9..03d4e89ac20fe 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -265,10 +265,6 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } } - -#ifdef DEBUG_GENERATION - DumpScores("TimestampLogitsProcessor", next_token_scores); -#endif } private: From a2eb967008c0a9443d267b41c387e0600202fb3a Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 2 Feb 2024 00:22:16 -0800 Subject: [PATCH 028/207] Fix Split index bugs uncovered by QNN SDK 2.19 (#19381) ### Description - When converting ONNX split sizes to QNN split indices, do not include the split at index 0. QNN 2.19 assumes index 0 is implicit and throws a validation error if provided. - Fix bug when using an ONNX Split operator with a `num_outputs` attribute that does not evenly divide into `shape[axis]`. The ONNX spec states that the last chunk should be smaller, but QNN EP made the last chunk larger. - Fix bug when using an ONNX Split operator with a `split` input. QNN EP was incorrectly passing the split sizes as split indices without conversion. ### Motivation and Context QNN SDK 2.19 updated validation criteria for Split operators. QNN EP was previously passing a split index that should have been implicit. Also, discovered a bugs when using `num_outputs` attribute and `split` input. 
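To make the two conversions concrete before the diff: QNN expects split *indices* (start offsets along the axis) with the index at 0 left implicit, whereas ONNX provides either explicit split *lengths* or a `num_outputs` count whose last chunk may be smaller. A self-contained sketch of both rules (illustration only, not the EP source; the function names are invented):

```cpp
// Illustration of the two conversions described above (not the EP code itself).
#include <cmath>
#include <cstdint>
#include <iostream>
#include <vector>

// ONNX split lengths -> QNN split indices, dropping the implicit index 0.
// e.g. lengths {3, 3, 1} on an axis of size 7 become indices {3, 6}.
std::vector<uint32_t> LengthsToIndices(const std::vector<uint32_t>& lengths) {
  std::vector<uint32_t> indices;
  uint32_t offset = 0;
  for (size_t i = 0; i < lengths.size(); ++i) {
    if (i > 0) indices.push_back(offset);  // index 0 is implicit in QNN
    offset += lengths[i];
  }
  return indices;
}

// num_outputs -> chunk step using ceil(), so the *last* chunk is the smaller one.
// e.g. shape[axis] = 7, num_outputs = 3 -> step 3 -> chunk sizes {3, 3, 1}.
uint32_t ChunkStep(uint32_t axis_size, uint32_t num_outputs) {
  return static_cast<uint32_t>(std::ceil(static_cast<float>(axis_size) / num_outputs));
}

int main() {
  for (uint32_t idx : LengthsToIndices({3, 3, 1})) std::cout << idx << ' ';  // prints: 3 6
  std::cout << '\n' << ChunkStep(7, 3) << '\n';                              // prints: 3
}
```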
--- .../qnn/builder/opbuilder/split_op_builder.cc | 40 ++++++++++++------ .../test/providers/qnn/split_op_test.cc | 41 +++++++++++++++---- 2 files changed, 61 insertions(+), 20 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index f4b0d1ff59175..9849a05db329c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -55,6 +55,19 @@ Status SplitOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +// Converts an ONNX list of split lengths to a QNN list of split indices. +// Note that the first split index at 0 is implicit (QNN SDK >= 2.19 will raise a validation error if included). +static void ConvertSplitLengthsToSplitIndices(gsl::span split_lengths, + std::vector& split_indices) { + uint32_t split_it = 0; + for (size_t i = 0; i < split_lengths.size(); ++i) { + if (i > 0) { // Do not include the 0th split index. + split_indices.push_back(split_it); + } + split_it += SafeInt(split_lengths[i]); + } +} + Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::vector&& input_names, @@ -79,22 +92,15 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr const int64_t* tensor_data = reinterpret_cast(unpacked_tensor.data()); size_t tensor_byte_size = unpacked_tensor.size(); size_t size = tensor_byte_size / sizeof(int64_t); - split_index.push_back(0); // QNN need the start index of each range and starts from 0 - std::transform(tensor_data, tensor_data + size, std::back_inserter(split_index), - [](int64_t item) { return SafeInt(item); }); - split_index.pop_back(); + ConvertSplitLengthsToSplitIndices({tensor_data, size}, split_index); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN doesn't support dynamic split"); } } else { NodeAttrHelper node_helper(node_unit); if (node_helper.HasAttr("split")) { - auto split = node_helper.Get("split", std::vector{0}); - uint32_t split_it = 0; - for (size_t i = 0; i < split.size(); ++i) { - split_index.push_back(split_it); - split_it += split[i]; - } + auto split_lengths = node_helper.Get("split", std::vector{0}); + ConvertSplitLengthsToSplitIndices(split_lengths, split_index); } } @@ -105,11 +111,19 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr "Cannot get shape"); ORT_ENFORCE(static_cast(input_shape.size()) > axis_value, "axis not valid!"); ORT_RETURN_IF_NOT(input_shape.at(axis_value) > 0, "Shape value not valid!"); - auto num_outputs = node_unit.Outputs().size(); - auto step = SafeInt(input_shape.at(axis_value) / num_outputs); + + // ONNX spec states that if not evenly divisible by `num_outputs`, the last chunk is smaller. + // Therefore, we have to use ceil() when computing shape[axis] / num_outputs. 
+ // See: core/providers/cpu/tensor/split.cc::PrepareForCompute() + const float num_outputs = static_cast(node_unit.Outputs().size()); + const float split_dim_size = static_cast(input_shape[axis_value]); + const uint32_t step = SafeInt(std::ceil(split_dim_size / num_outputs)); uint32_t split_it = 0; + for (size_t i = 0; i < num_outputs; ++i) { - split_index.push_back(split_it); + if (i > 0) { // 0th split index is implicit (QNN >= 2.19 raises validation error if included) + split_index.push_back(split_it); + } split_it += step; } } diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc index 57e4b211777bb..6dc721edb421e 100644 --- a/onnxruntime/test/providers/qnn/split_op_test.cc +++ b/onnxruntime/test/providers/qnn/split_op_test.cc @@ -302,19 +302,46 @@ TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { // Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute // and 'split' input. TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { + // Split 6 into 3 outputs of lengths [2, 2, 2] + TestInputDef input_def({6, 2}, false, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f}); + // Use 'split' input (initializer). - RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), - {2, 2}, // split - 0, // axis - -1, // num_outputs - 18, // opset + RunQDQSplitOpTestOnHTP(input_def, + {2, 2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset ExpectedEPNodeAssignment::All); // Use 'num_outputs' attribute. - RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + RunQDQSplitOpTestOnHTP(input_def, + {}, // split (use num_outputs instead) + 0, // axis + 3, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Split opset 18 on HTP backend. Use an uneven split (last chunk should be smaller). +TEST_F(QnnHTPBackendTests, Split_NonEqual_Axis0_Opset18) { + // Split 7 into 3 outputs of lengths [3, 3, 1] + TestInputDef input_def({7, 2}, false, + {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f}); + + // Use a `split` input with uneven split lengths. + RunQDQSplitOpTestOnHTP(input_def, + {3, 3, 1}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use a `num_outputs` attribute that does not evenly divide into shape[axis]. + RunQDQSplitOpTestOnHTP(input_def, {}, // split (use num_outputs instead) 0, // axis - 2, // num_outputs + 3, // num_outputs 18, // opset ExpectedEPNodeAssignment::All); } From 9139bdda02d555417dcd5a97af411225bd35021b Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Fri, 2 Feb 2024 16:34:51 +0800 Subject: [PATCH 029/207] [ROCm] CK implementation support causal mask (#18943) Use `MaskingSpecialization::MaskOutUpperTriangle` to support causal mask in ck implementation. 
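For readers unfamiliar with the CK masking specializations: a causal (unidirectional) mask prevents each query position from attending to later key positions, i.e. it masks the upper triangle of the query-by-key score matrix so that softmax assigns those positions zero weight, which is what `MaskOutUpperTriangle` selects. Note that the diff below only takes this path when `sequence_length == total_sequence_length`, so the score matrix is square. A toy sketch of the effect (illustration only, unrelated to the CK device kernels):

```cpp
// Toy illustration of causal ("mask out upper triangle") attention masking;
// scores is an M x N matrix of raw attention scores (query rows, key columns).
#include <limits>
#include <vector>

void ApplyCausalMask(std::vector<std::vector<float>>& scores) {
  for (size_t i = 0; i < scores.size(); ++i) {
    // Key positions j > i lie in the future of query position i, so force
    // their scores to -inf; softmax then gives them zero probability.
    for (size_t j = i + 1; j < scores[i].size(); ++j) {
      scores[i][j] = -std::numeric_limits<float>::infinity();
    }
  }
}
```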
--- .../impl.cuh | 31 ++- .../impl_fp16.cu | 21 ++ .../impl_fp16_biased.cu | 21 ++ .../impl_fp16_biased_biased.cu | 21 ++ ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 214 ++++++++++-------- .../kernels/gemm_softmax_gemm_permute_test.py | 152 +++++++++++-- .../kernels/rocm/gemm_softmax_gemm_permute.cu | 17 +- 7 files changed, 364 insertions(+), 113 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh index 0599318a4022d..be8508670e4b1 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh @@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; @@ -141,6 +141,35 @@ std::vector, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu index 181e47f012c99..2e32a6594d164 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + 
ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu index 1577bdf397fa5..91da8d9e1f9a8 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu index 14de59234356b..b08123be18977 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 78983ac95e672..54dda4bfa6d2c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -732,122 +732,154 @@ class 
GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { +template +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); using Nop = ck::tensor_operation::element_wise::PassThrough; using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); + if constexpr (USE_BIAS) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer == nullptr, "biased version only support input with bias"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer != nullptr, "non-biased version only support input without bias"); + } + if constexpr (USE_MASK) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), + "mask type is not supported, got ", params->attention->mask_type); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer == nullptr, "masked version only support input with mask"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); + } + + auto attn = params->attention; + const int& G0 = attn->batch_size; + const int& G1 = attn->num_heads; + const int& M = attn->sequence_length; + const int& N = attn->total_sequence_length; + const int& K = attn->head_size; + const int& O = attn->v_head_size; + { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); + } + + auto [qs, ks, vs] = GetQkvStrides(attn); + std::vector q_buffer_lengths = {G0, G1, M, K}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); + std::vector k_buffer_lengths = {G0, G1, N, K}; + std::vector k_buffer_strides = ks.template ForBNSHCoord>(); + std::vector v_buffer_lengths = {G0, G1, O, N}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); + std::vector out_buffer_lengths = {G0, G1, M, O}; + std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 + + std::array bias_buffers{}; + std::array, kNumBiasBuffer> bias_lengths{}; + std::array, kNumBiasBuffer> bias_strides{}; + if constexpr (USE_BIAS) { + bias_buffers[0] = const_cast(params->bias_buffer); + bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + bias_strides[0] = {G1 * M * N, M * N, N, 1}; + } + if constexpr (USE_MASK) { + bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; + bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + if (params->mask_index_dims.size() == 2) { // [B,T] + bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; + } else if (params->mask_index_dims.size() == 3) { // [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else { + ORT_ENFORCE(false, "Unreachable"); + } + } + + auto arg = impl->MakeArgumentPointer( + params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, + bias_buffers, // Gemm1 bias, as attention mask + {}, // Gemm2 bias + q_buffer_lengths, q_buffer_strides, + k_buffer_lengths, k_buffer_strides, + 
v_buffer_lengths, v_buffer_strides, + out_buffer_lengths, out_buffer_strides, + bias_lengths, bias_strides, + {}, + {}, + Nop{}, + Nop{}, + Acc0ElementOp{params->scale}, + Nop{}, + Nop{}); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + + if constexpr (USE_MASK) { + ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); + } + + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); +} + +template +auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; using D0DataType = typename ck::detail::tuple_concat< std::conditional_t, ck::Tuple<>>, std::conditional_t, ck::Tuple<>>>::type; - constexpr static auto MaskingSpec = + constexpr static auto MaskingSpecMaskDisabled = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + constexpr static auto MaskingSpecMaskOutUpperTriangle = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + + std::vector>>> + ret; - std::vector>>> ret; for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) { + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { auto type_string = impl->GetTypeString(); auto invoker = impl->MakeInvokerPointer(); auto op = [impl = std::move(impl), invoker = std::move(invoker)]( const GemmSoftmaxGemmPermuteParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } + params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = ks.template 
ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { + auto type_string = impl->GetTypeString(); - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->sequence_length != params->attention->total_sequence_length, + "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); + + return GetArgAndRunInvoker(impl, invoker, params); }; ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); } + return ret; } #endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py index 6e1e431842a56..802d924c27b62 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py @@ -44,6 +44,7 @@ def get_ck_binding_name(dtype, biased: bool, masked: bool): num_heads = [8, 12] head_sizes = [64] biaseds = [False, True] +causals = [False] mask_dims 
= [0, 2, 3, 4] @@ -81,8 +82,57 @@ def maybe_pack_q_k_v_bnsh_for_device_on_host(q, k, v, dtype, qkv_format): raise NotImplementedError +def _make_causal_mask( + seqence_length, + total_sequence_length, + dtype: np.dtype, +): + """ + Make causal mask used for Attention with attribute unidirectional == 1. + The mask is a upper triangular matrix with shape [sequence_length, total_sequence_length]. + Putting a 1 indicates that the token at this position should be masked. + For Example: + sequence_length = 5, total_sequence_length = 5, + mask: [[0. 1. 1. 1. 1.] + [0. 0. 1. 1. 1.] + [0. 0. 0. 1. 1.] + [0. 0. 0. 0. 1.] + [0. 0. 0. 0. 0.]] + seqence_length = 5, total_seqence_length = 3, + mask: [[1. 1. 1.] + [1. 1. 1.] + [0. 1. 1.] + [0. 0. 1.] + [0. 0. 0.]] + seqence_length = 5, total_seqence_length = 7, + mask: [[0. 0. 0. 1. 1. 1. 1.] + [0. 0. 0. 0. 1. 1. 1.] + [0. 0. 0. 0. 0. 1. 1.] + [0. 0. 0. 0. 0. 0. 1.] + [0. 0. 0. 0. 0. 0. 0.]] + """ + mask = np.full((seqence_length, seqence_length), 1) + mask_cond = np.arange(mask.shape[-1]) + mask = np.where(mask_cond < (mask_cond + 1).reshape(mask.shape[-1], 1), 0, mask) + + mask = mask.astype(dtype) + + if total_sequence_length - seqence_length > 0: + mask = np.concatenate( + [np.zeros((seqence_length, total_sequence_length - seqence_length), dtype=dtype), mask], axis=-1 + ) + + if total_sequence_length - seqence_length < 0: + mask = mask[:, -total_sequence_length:] + + correct_mask = np.full((seqence_length, total_sequence_length), 1) + for i in range(seqence_length): + correct_mask[i][:] = sum(mask[i]) != total_sequence_length + return mask, correct_mask + + def _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format + f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ): v_head_size = head_size q_shape = [batch, num_heads, seqlen, head_size] @@ -123,6 +173,8 @@ def _test_gemm_softmax_gemm_permute( pre_softmax_attn_scores = pre_softmax_attn_scores * scale if attn_bias is not None: pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias + + correct_causal_mask = np.full((seqlen, total_seqlen), 1) if attn_mask is not None: filter_value = -10000.0 if mask_dim == 4: @@ -131,7 +183,18 @@ def _test_gemm_softmax_gemm_permute( else: converted_mask = (1 - attn_mask.reshape(mask_shape_broadcasted)) * filter_value pre_softmax_attn_scores = pre_softmax_attn_scores + converted_mask + if causal: + filter_value = np.finfo(dtype).min + causal_mask, correct_causal_mask = _make_causal_mask(seqlen, total_seqlen, pre_softmax_attn_scores.dtype) + causal_mask = np.broadcast_to(causal_mask, pre_softmax_attn_scores.shape) * filter_value + pre_softmax_attn_scores = pre_softmax_attn_scores + causal_mask attn_scores = softmax(pre_softmax_attn_scores, axis=-1) + + # apply mask to attn_scores to correct softmax result, in c++ implementation, if all values in a row are masked, + # the softmax result in this row will be filled with 0. 
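A compact NumPy sketch of the same idea follows. It is illustrative only (the names `causal_mask_sketch` and `row_is_valid` are not part of the patch); it reproduces the `_make_causal_mask` docstring examples above and the per-row validity flag that the reference computation multiplies into `attn_scores`:

```python
import numpy as np

def causal_mask_sketch(seq_len: int, total_seq_len: int, dtype=np.float32):
    # Entry (i, j) is 1 (masked) when key position j lies in the future of
    # query position i, i.e. j > i + (total_seq_len - seq_len).
    mask = np.triu(np.ones((seq_len, total_seq_len), dtype=dtype),
                   k=1 + total_seq_len - seq_len)
    # A fully masked row leaves softmax with no valid position; the C++
    # implementation fills such rows with 0, so the reference keeps a per-row
    # validity flag and zeroes those rows after the softmax.
    row_is_valid = mask.sum(axis=-1) != total_seq_len
    return mask, row_is_valid

mask, valid = causal_mask_sketch(5, 3)
print(mask)   # matches the sequence_length = 5, total_sequence_length = 3 example above
print(valid)  # [False False  True  True  True] -> first two rows are fully masked
```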
+ correct_causal_mask = np.broadcast_to(correct_causal_mask, pre_softmax_attn_scores.shape) + attn_scores = attn_scores * correct_causal_mask + attn = matmul(attn_scores, v) ref = np.swapaxes(attn, 2, 1) # permute 0213 @@ -154,6 +217,7 @@ def _test_gemm_softmax_gemm_permute( head_size, mask_dim, scale, + causal, qkv_format, dev_q, dev_k, @@ -202,12 +266,26 @@ def _test_gemm_softmax_gemm_permute( @pytest.mark.parametrize("total_seqlen", total_seqlens) @pytest.mark.parametrize("seqlen", seqlens) @pytest.mark.parametrize("batch", [16]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", ["float16", "float32"]) -def test_gemm_softmax_gemm_permute_generic(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): +def test_gemm_softmax_gemm_permute_generic( + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim +): f = getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -218,14 +296,26 @@ def test_gemm_softmax_gemm_permute_generic(dtype, batch, seqlen, total_seqlen, n @pytest.mark.parametrize("total_seqlen", [128]) @pytest.mark.parametrize("seqlen", [64]) @pytest.mark.parametrize("batch", [16]) +@pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("dtype", ["float16", "float32"]) def test_gemm_softmax_gemm_permute_generic_nested_tunable( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim ): f = getattr(ke, "GemmSoftmaxGemmPermuteGenericNestedTunable_" + dtype_to_suffix(dtype)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -237,12 +327,24 @@ def test_gemm_softmax_gemm_permute_generic_nested_tunable( @pytest.mark.parametrize("total_seqlen", total_seqlens) @pytest.mark.parametrize("seqlen", seqlens) @pytest.mark.parametrize("batch", batches) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", dtypes) -def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): +def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim): f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -253,12 +355,26 @@ def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, @pytest.mark.parametrize("total_seqlen", [128]) @pytest.mark.parametrize("seqlen", [64]) @pytest.mark.parametrize("batch", [16]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", ["float16"]) -def 
test_gemm_softmax_gemm_permute_tunable(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): +def test_gemm_softmax_gemm_permute_tunable( + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim +): f = getattr(ke, "GemmSoftmaxGemmPermuteTunable_" + dtype_to_suffix(dtype)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -278,16 +394,17 @@ def test_gemm_softmax_gemm_permute_tunable(dtype, batch, seqlen, total_seqlen, n @pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") @pytest.mark.parametrize("mask_dim", [0], ids=get_mask_dim_id) @pytest.mark.parametrize("biased", [False], ids=get_biased_id) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("batch, seqlen, total_seqlen, nhead, head_size, qkv_format_name", stabel_diffusion_configs) @pytest.mark.parametrize("dtype", dtypes) def test_gemm_softmax_gemm_permute_ck_sd( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, qkv_format_name + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim, qkv_format_name ): qkv_format = getattr(ke.qkv_format, qkv_format_name) f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, qkv_format + f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, causal, qkv_format ) @@ -316,7 +433,7 @@ def report(self): def profile_gemm_softmax_gemm_permute_func( - f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format + f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ): v_head_size = head_size q_shape = [batch, num_heads, seqlen, head_size] @@ -369,6 +486,7 @@ def profile_gemm_softmax_gemm_permute_func( head_size, mask_dim, scale, + causal, qkv_format, dev_q, dev_k, @@ -402,10 +520,10 @@ def profile_gemm_softmax_gemm_permute_func( def profile_with_args( - dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format, *, sort=False + dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, causal, mask_dim, scale, qkv_format, *, sort=False ): with ke.benchmark(sort): - args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format) + args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format) if qkv_format == ke.qkv_format.Q_K_V_BNSH: profile_gemm_softmax_gemm_permute_func( getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)), *args @@ -429,6 +547,7 @@ def profile(): nhead, head_size, biased=False, + causal=False, mask_dim=0, qkv_format=getattr(ke.qkv_format, qkv_format_name), scale=0.125, @@ -436,7 +555,7 @@ def profile(): ) print() - for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, mask_dims): + for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, causals, mask_dims): profile_with_args(*args, qkv_format=ke.qkv_format.Q_K_V_BNSH, scale=0.125, sort=True) print() @@ -455,6 +574,7 @@ def profile(): 
group.add_argument("head_size", type=int) group.add_argument("biased", type=int, choices=[0, 1], default=0) group.add_argument("mask_dim", type=int, choices=[0, 2, 3, 4], default=2, help="0 for mask disabled") + group.add_argument("causal", type=int, choices=[0, 1], default=0) group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)") group.add_argument( "--qkv_format", @@ -471,6 +591,7 @@ def profile(): profile() else: args = parser.parse_args() + print(args) profile_with_args( args.dtype, args.batch, @@ -479,6 +600,7 @@ def profile(): args.num_heads, args.head_size, args.biased, + args.causal, args.mask_dim, args.scale, getattr(ke.qkv_format, args.qkv_format), diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu index 5e60bad776d4a..7068fc8fd0ebc 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu @@ -28,6 +28,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -51,7 +52,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { attn_.v_hidden_size = attn_.hidden_size; // Q,K,V hidden size must agree now attn_.v_head_size = attn_.head_size; // Q,K,V hidden size must agree now attn_.num_heads = num_heads; - attn_.is_unidirectional = false; + attn_.is_unidirectional = causal; attn_.past_present_share_buffer = false; attn_.do_rotary = false; attn_.mask_filter_value = -10000.0f; @@ -148,6 +149,7 @@ class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplor int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -156,7 +158,7 @@ class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplor std::optional& attn_mask, DeviceArray& out) : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->SetWorkspace(GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_)); } @@ -187,6 +189,7 @@ class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermute int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -195,7 +198,7 @@ class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermute std::optional& attn_mask, DeviceArray& out) : GemmSoftmaxGemmPermuteGeneric(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->params_.TuningContext()->EnableTunableOpAndTuning(); } @@ -214,6 +217,7 @@ class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -222,7 +226,7 @@ class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer std::optional& attn_mask, DeviceArray& out) : 
IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->SetWorkspace(GemmSoftmaxGemmPermuteTunableOp::GetWorkspaceNumBytes(&this->attn_)); @@ -275,6 +279,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -283,7 +288,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor std::optional& attn_mask, DeviceArray& out) : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->SetWorkspace(std::max( GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_), @@ -311,7 +316,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor #define REGISTER_COMMON(name, type, ...) \ py::class_>(m, name) \ .def(py::init, int64_t, int64_t, int64_t, \ - float, contrib::AttentionQkvFormat, \ + float, bool, contrib::AttentionQkvFormat, \ DeviceArray&, \ std::optional&, \ std::optional&, \ From efc17e79de8c1a62eb419d19576ccb90b371b0d0 Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Fri, 2 Feb 2024 18:04:06 +0800 Subject: [PATCH 030/207] [js/webgpu] Fix the undefined push error (#19366) ### Description This PR fixes below errors when enable webgpu profiling: ``` TypeError: Cannot read properties of undefined (reading 'push') ``` --- js/web/lib/wasm/jsep/backend-webgpu.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 4b544595d76bb..98990a6fe477b 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -530,8 +530,10 @@ export class WebGpuBackend { }; this.pendingKernels.push(pendingKernelInfo); - const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); - sessionPendingKernels!.push(pendingKernelInfo); + if (this.sessionStatus === 'capturing') { + const sessionPendingKernels = this.capturedPendingKernels.get(this.currentSessionId!); + sessionPendingKernels!.push(pendingKernelInfo); + } } this.programManager.run(artifact, inputDatas, outputDatas, normalizedDispatchGroup, uniformBufferBinding); From 50806a7dd515a5378043117eb462362e773f6a8b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 2 Feb 2024 09:05:57 -0800 Subject: [PATCH 031/207] [js/web] support external data in npm test (#19377) ### Description support external data in npm test. This allows test runner to detect whether an external data is available in the test folder, and if it is, load it as external data automatically. this feature does not parse every model to figure out whether the model has external data. the following comments in code explained how to determine whether should parse the model file. ```js // for performance consideration, we do not parse every model. when we think it's likely to have external // data, we will parse it. We think it's "likely" when one of the following conditions is met: // 1. 
any file in the same folder has the similar file name as the model file // (e.g., model file is "model_abc.onnx", and there is a file "model_abc.pb" or "model_abc.onnx.data") // 2. the file size is larger than 1GB ``` --- js/web/script/test-runner-cli.ts | 49 +++++++++++++++++++++++++++++++- js/web/test/test-runner.ts | 11 +++---- js/web/test/test-types.ts | 1 + 3 files changed, 55 insertions(+), 6 deletions(-) diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index d56792c6e3595..9105c02412e34 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -12,6 +12,7 @@ import * as os from 'os'; import * as path from 'path'; import {inspect} from 'util'; +import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {bufferToBase64} from '../test/test-shared'; import {Test} from '../test/test-types'; @@ -264,10 +265,12 @@ async function main() { let modelUrl: string|null = null; let cases: Test.ModelTestCase[] = []; + let externalData: Array<{data: string; path: string}>|undefined; npmlog.verbose('TestRunnerCli.Init.Model', `Start to prepare test data from folder: ${testDataRootFolder}`); try { + const maybeExternalDataFiles: Array<[fileNameWithoutExtension: string, size: number]> = []; for (const thisPath of fs.readdirSync(testDataRootFolder)) { const thisFullPath = path.join(testDataRootFolder, thisPath); const stat = fs.lstatSync(thisFullPath); @@ -282,6 +285,8 @@ async function main() { } else { throw new Error('there are multiple model files under the folder specified'); } + } else { + maybeExternalDataFiles.push([path.parse(thisPath).name, stat.size]); } } else if (stat.isDirectory()) { const dataFiles: string[] = []; @@ -307,6 +312,34 @@ async function main() { if (modelUrl === null) { throw new Error('there are no model file under the folder specified'); } + // for performance consideration, we do not parse every model. when we think it's likely to have external + // data, we will parse it. We think it's "likely" when one of the following conditions is met: + // 1. any file in the same folder has the similar file name as the model file + // (e.g., model file is "model_abc.onnx", and there is a file "model_abc.pb" or "model_abc.onnx.data") + // 2. the file size is larger than 1GB + const likelyToHaveExternalData = maybeExternalDataFiles.some( + ([fileNameWithoutExtension, size]) => + path.basename(modelUrl!).startsWith(fileNameWithoutExtension) || size >= 1 * 1024 * 1024 * 1024); + if (likelyToHaveExternalData) { + const model = onnx.ModelProto.decode(fs.readFileSync(path.join(testDataRootFolder, path.basename(modelUrl!)))); + const externalDataPathSet = new Set(); + for (const initializer of model.graph!.initializer!) { + if (initializer.externalData) { + for (const data of initializer.externalData) { + if (data.key === 'location') { + externalDataPathSet.add(data.value!); + } + } + } + } + externalData = []; + const externalDataPaths = [...externalDataPathSet]; + for (const dataPath of externalDataPaths) { + const fullPath = path.resolve(testDataRootFolder, dataPath); + const url = path.join(TEST_DATA_BASE, path.relative(TEST_ROOT, fullPath)); + externalData.push({data: url, path: dataPath}); + } + } } catch (e) { npmlog.error('TestRunnerCli.Init.Model', `Failed to prepare test data. 
Error: ${inspect(e)}`); throw e; @@ -340,9 +373,23 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); npmlog.verbose('TestRunnerCli.Init.Model', ` Backend: ${backend}`); npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); + if (externalData) { + npmlog.verbose('TestRunnerCli.Init.Model', ` External data: ${externalData.length}`); + for (const data of externalData) { + npmlog.verbose('TestRunnerCli.Init.Model', ` - ${data.path}`); + } + } npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; + return { + name: path.basename(testDataRootFolder), + platformCondition, + modelUrl, + backend, + cases, + ioBinding, + externalData + }; } function tryLocateModelTestFolder(searchPattern: string): string { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 442cb1bcf1f34..b01d474788f25 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -138,8 +138,8 @@ async function loadTensors( async function initializeSession( modelFilePath: string, backendHint: ort.InferenceSession.ExecutionProviderConfig, ioBindingMode: Test.IOBindingMode, - profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + profile: boolean, externalData: ort.InferenceSession.SessionOptions['externalData'], + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -153,7 +153,8 @@ async function initializeSession( executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile, - preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + externalData }; let session: ort.InferenceSession; @@ -246,8 +247,8 @@ export class ModelTestContext { const executionProviderConfig = modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || 'webnn') : modelTest.backend!; const session = await initializeSession( - modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, testOptions?.sessionOptions || {}, - this.cache); + modelTest.modelUrl, executionProviderConfig, modelTest.ioBinding, profile, modelTest.externalData, + testOptions?.sessionOptions || {}, this.cache); const initEnd = now(); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index cd008e82e570b..14b9fd7c005ab 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -65,6 +65,7 @@ export declare namespace Test { export interface ModelTest { name: string; modelUrl: string; + externalData?: InferenceSession.SessionOptions['externalData']; backend?: string; // value should be populated at build time ioBinding: IOBindingMode; platformCondition?: PlatformCondition; From ccbe264a39cf05acaebfa326edc9b9039c1771fe Mon Sep 17 00:00:00 2001 From: Jiajia Qin Date: Sat, 3 Feb 2024 01:06:38 +0800 Subject: [PATCH 032/207] [js/webgpu] Add LeakyRelu activation for fusedConv (#19369) ### Description This PR 1) adds LeakyRelu activation for fusedConv; 2) makes `vec4` value work with `float32` uniforms attributes. 
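For point 1, the fused LeakyRelu is emitted as a WGSL `select`, with the scalar attribute arriving as an `f32` uniform and being cast to the tensor's base type before it touches the vectorized value. The following is an illustrative Python rendering of the string built by `getActivationSnippet` in `fuse-utils.ts` (shown in the diff below); it is a sketch, not code added by this patch:

```python
def leaky_relu_snippet(value_type: str, base_type: str = "f32") -> str:
    # Mirrors the WGSL emitted for LeakyRelu: negative lanes get alpha * value,
    # non-negative lanes pass through unchanged. uniforms.alpha is always an
    # f32 uniform, so it is cast to the tensor's base type (e.g. f16) first.
    return (f"value = select({base_type}(uniforms.alpha) * value, value, "
            f"value >= {value_type}(0.0));")

print(leaky_relu_snippet("vec4<f16>", "f16"))
# value = select(f16(uniforms.alpha) * value, value, value >= vec4<f16>(0.0));
```

The same base-type cast is what point 2 addresses for the other fused activations.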
For example: `clamp(value, vec4(uniforms.clip_min), vec4(uniforms.clip_max)` will throw compilation errors since `uniforms.clip_min` and `uniforms.clip_max` are `f32`, not `f16`. So we need to change it to `clamp(value, vec4(f16(uniforms.clip_min)), vec4(f16(uniforms.clip_max))`. The problem above was introduced when we made the activation attributes uniforms instead of constants. BTW, after adding LeakyRelu, the `realesrgan-t256` model passes. --- .../webgpu/ops/3rd-party/conv2d_mm_webgpu.ts | 2 +- .../ops/3rd-party/matmul_packed_webgpu.ts | 3 +- .../lib/wasm/jsep/webgpu/ops/conv-grouped.ts | 8 +- js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts | 47 +++--- js/web/lib/wasm/jsep/webgpu/ops/matmul.ts | 5 +- js/web/test/data/ops/fused-conv.jsonc | 144 ++++++++++++++++++ 6 files changed, 184 insertions(+), 25 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index fc2146068de70..24006d393592a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -130,7 +130,7 @@ const conv2dCommonSnippet = isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); const bType = isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); - const applyActivation = getActivationSnippet(attributes, resType); + const applyActivation = getActivationSnippet(attributes, resType, dataType); const userCode = ` fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { ${isChannelsLast ? sampleX : sampleW} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 8abc27a24861d..29c7941e6bd30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -479,7 +479,8 @@ export const createMatmulProgramInfo = const uniforms: UniformsArrayType = [{name: 'dim_a_outer', type: 'i32'}, {name: 'dim_b_outer', type: 'i32'}, {name: 'dim_inner', type: 'i32'}]; appendActivationUniforms(activationAttributes, uniforms); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const declareFunctions = matMulReadWriteFnSource( components, hasBias, applyActivation, [batchDims, A, B, output], [outerDimsA, outerDimsB, outerDims], isChannelsLast); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 8495f9040a1b6..7d424305c715f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types'; -import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common'; +import {createTensorShapeVariables, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; import {calculateOutputShape, ConvAttributes} from './conv'; import {appendActivationUniforms,
appendActivationUniformsData, getActivationSnippet} from './fuse-utils'; @@ -45,7 +45,8 @@ export const createGroupedConvProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShape.length); - const applyActivation = getActivationSnippet(attributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); const x = inputVariable('x', inputs[0].dataType, xShape.length); const w = inputVariable('w', inputs[1].dataType, wShape.length); const inputVars = [x, w]; @@ -136,7 +137,8 @@ export const createGroupedConvVectorizeProgramInfo = const xNumber = (outputNumber - 1) * attributes.strides[1] + wShape[1]; const getShaderSource = (shaderHelper: ShaderHelper) => { const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const applyActivation = getActivationSnippet(attributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(attributes, output.type.value, baseType); const x = inputVariable('x', inputs[0].dataType, xShape.length, components); const w = inputVariable('w', inputs[1].dataType, wShape.length, components); const inputVars = [x, w]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts index 60067c014613b..6e66abacf3471 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/fuse-utils.ts @@ -15,24 +15,28 @@ export interface InternalActivationAttributes { readonly beta?: number; } -export const getActivationSnippet = (attributes: InternalActivationAttributes, valueType: string): string => { - switch (attributes.activation) { - case 'Relu': - return `value = max(value, ${valueType}(0.0));`; - case 'Sigmoid': - return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; - case 'Clip': - return `value = clamp(value, ${valueType}(uniforms.clip_min), ${valueType}(uniforms.clip_max));`; - case 'HardSigmoid': - return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${valueType}(uniforms.alpha) * value + ${ - valueType}(uniforms.beta)));`; - case '': - return ''; - // TODO: adding other activations that can be fused. - default: - throw new Error(`Unsupported activation ${attributes.activation}`); - } -}; +export const getActivationSnippet = + (attributes: InternalActivationAttributes, valueType: string, baseType = 'f32'): string => { + switch (attributes.activation) { + case 'Relu': + return `value = max(value, ${valueType}(0.0));`; + case 'Sigmoid': + return `value = (${valueType}(1.0) / (${valueType}(1.0) + exp(-value)));`; + case 'Clip': + return `value = clamp(value, ${valueType}(${baseType}(uniforms.clip_min)), ${valueType}(${ + baseType}(uniforms.clip_max)));`; + case 'HardSigmoid': + return `value = max(${valueType}(0.0), min(${valueType}(1.0), ${baseType}(uniforms.alpha) * value + ${ + baseType}(uniforms.beta)));`; + case 'LeakyRelu': + return `value = select(${baseType}(uniforms.alpha) * value, value, value >= ${valueType}(0.0));`; + case '': + return ''; + // TODO: adding other activations that can be fused. 
+ default: + throw new Error(`Unsupported activation ${attributes.activation}`); + } + }; export const appendActivationUniformsData = (attributes: InternalActivationAttributes, programUniform: ProgramUniform[]) => { @@ -42,6 +46,8 @@ export const appendActivationUniformsData = } else if (attributes.activation === 'HardSigmoid') { programUniform.push( {type: DataType.float, data: attributes.alpha!}, {type: DataType.float, data: attributes.beta!}); + } else if (attributes.activation === 'LeakyRelu') { + programUniform.push({type: DataType.float, data: attributes.alpha!}); } }; @@ -50,6 +56,8 @@ export const appendActivationUniforms = (attributes: InternalActivationAttribute uniforms.push({name: 'clip_max', type: 'f32'}, {name: 'clip_min', type: 'f32'}); } else if (attributes.activation === 'HardSigmoid') { uniforms.push({name: 'alpha', type: 'f32'}, {name: 'beta', type: 'f32'}); + } else if (attributes.activation === 'LeakyRelu') { + uniforms.push({name: 'alpha', type: 'f32'}); } }; @@ -62,6 +70,9 @@ export const parseInternalActivationAttributes = } else if (activation === 'Clip') { const [clipMin, clipMax] = attributes?.activation_params as [number, number] || [MIN_CLIP, MAX_CLIP]; return {activation, clipMax, clipMin}; + } else if (activation === 'LeakyRelu') { + const [alpha] = attributes?.activation_params as [number] || [0.01]; + return {activation, alpha}; } return {activation}; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 0c533974e2b26..1a92d861002fb 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -7,7 +7,7 @@ import {BroadcastUtil, ShapeUtil} from '../../util'; import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; -import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, UniformsArrayType,} from './common'; +import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet, InternalActivationAttributes} from './fuse-utils'; export const createNaiveMatmulProgramInfo = @@ -45,7 +45,8 @@ export const createNaiveMatmulProgramInfo = const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents); const b = inputVariable('b', inputs[1].dataType, bShape.length, components); const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components); - const applyActivation = getActivationSnippet(activationAttributes, output.type.value); + const baseType = tensorTypeToWsglStorageType(output.type.tensor); + const applyActivation = getActivationSnippet(activationAttributes, output.type.value, baseType); const inputVariables = [a, b]; let processBias = ''; if (hasBias) { diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index c734d6db9b92a..6a10e3b96a26a 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -286,5 +286,149 @@ ] } ] + }, + { + "name": "fused group-conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": 
"ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, -6, 51, 47, -170, -10, 251, 229, 847, 889, 973, 1015], + "dims": [1, 3, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC group-conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, -3.0, 4.0, -5.0, 6.0, 7.0, 8.0, -9.0, -10.0, 11.0, -12.0, 13.0, -14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-162, 63, -158, 33, 281, 85, 105, 337, 455, 177, 515, 609], + "dims": [1, 2, 2, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "fused conv with LeakyRelu", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "NHWC conv with LeakyRelu", + "operator": "Conv", + "attributes": [ + { "name": "activation", "data": "LeakyRelu", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [2.0], "type": "floats" } + ], + "opset": { "domain": "com.ms.internal.nhwc", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, -30, -40, -50, -60, 70, 80, 90], + "dims": [1, 3, 3, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [-540, -860, 390, 430], + "dims": [1, 2, 2, 1], + "type": "float32" + } + ] + } + ] } ] From 18c3acb1983e402248365b9f2babfaf8e2182e81 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 2 Feb 2024 10:16:37 -0800 Subject: [PATCH 033/207] update import in convert_generation.py (#19385) Fix for https://github.com/microsoft/onnxruntime/issues/19376 - Use absolute import instead of relative import for now. 
- Fix some typo --- .../tools/transformers/convert_generation.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 17f0dd0bc6078..a2cdd17e19fa5 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -55,10 +55,6 @@ import torch from benchmark_helper import Precision, setup_logger from fusion_utils import NumpyHelper -from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx -from models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS -from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models -from models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS from onnx import GraphProto, ModelProto, TensorProto from onnx_model import OnnxModel from transformers import ( @@ -73,6 +69,10 @@ ) from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers +from onnxruntime.transformers.models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx +from onnxruntime.transformers.models.gpt2.gpt2_helper import PRETRAINED_GPT2_MODELS +from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models +from onnxruntime.transformers.models.t5.t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS logger = logging.getLogger("") @@ -372,7 +372,7 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: type=int, required=False, default=1, - help="Minimumber of tokens we keep per batch example in the output.", + help="Minimum number of tokens we keep per batch example in the output.", ) beam_parameters_group.add_argument( @@ -466,7 +466,7 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: "--save_test_data", required=False, action="store_true", - help="save test data for onnxruntimer_perf_test tool", + help="save test data for onnxruntime_perf_test tool", ) test_group.set_defaults(save_test_data=False) @@ -1225,7 +1225,7 @@ def find_past_seq_len_usage(subg: GraphProto): tensor_names_to_rename = set() nodes_to_remove = [] - graph_intput_names = {inp.name: index for index, inp in enumerate(subg.input)} + graph_input_names = {inp.name: index for index, inp in enumerate(subg.input)} input_name_to_nodes = {} output_name_to_node = {} @@ -1259,7 +1259,7 @@ def find_past_seq_len_usage(subg: GraphProto): if ( shape_node.op_type == "Shape" and shape_node.input[0] - and shape_node.input[0] in graph_intput_names + and shape_node.input[0] in graph_input_names and ( shape_node.input[0].startswith("past_key_self_") or shape_node.input[0].startswith("past_value_self_") @@ -1423,7 +1423,7 @@ def update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(subg: ModelP if node.op_type == "MultiHeadAttention": old_nodes.extend([node]) - # If not all the MultiheadAttention nodes are fused, this optimization is not applicable + # If not all the MultiHeadAttention nodes are fused, this optimization is not applicable if len(old_nodes) < num_layers: return False From debd1cab10fd82394a416711890dc15fddec88dc Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 3 Feb 2024 09:42:21 +1000 Subject: [PATCH 034/207] Add coremltools 7.1 as a dependency (#19389) ### Description Setup usage of coremltools via dependencies instead of copying files. 
Pull in some changes from https://github.com/microsoft/onnxruntime/pull/19347 in preparation for supporting ML Program and enabling building the ML Model on all platforms to make development and testing of CoreML EP code easier. - Update to coremltools 7.1 - Add patch for changes required for cross platform build of ML Program related code - Generate coreml proto files on all platforms - mainly to test these changes work everywhere, as the proto files will be used on all platforms when #19347 is checked in - rename onnxruntime_coreml_proto target to coreml_proto as it contains purely coreml protobuf code with no ORT related chagnes ### Motivation and Context Improve setup. --- cgmanifests/generated/cgmanifest.json | 10 + cmake/deps.txt | 3 +- .../external/onnxruntime_external_deps.cmake | 13 +- cmake/onnxruntime_providers.cmake | 2 +- cmake/onnxruntime_providers_coreml.cmake | 208 +- cmake/onnxruntime_unittests.cmake | 8 +- .../coremltools/crossplatformbuild.patch | 155 + .../providers/coreml/builders/coreml_spec.h | 2 +- .../coreml/builders/impl/builder_utils.cc | 2 +- .../ArrayFeatureExtractor.proto | 19 - .../BayesianProbitRegressor.proto | 139 - .../mlmodel_format/CategoricalMapping.proto | 38 - .../coreml/mlmodel_format/CustomModel.proto | 30 - .../mlmodel_format/DataStructures.proto | 95 - .../mlmodel_format/DictVectorizer.proto | 36 - .../coreml/mlmodel_format/FeatureTypes.proto | 224 - .../mlmodel_format/FeatureVectorizer.proto | 26 - .../coreml/mlmodel_format/GLMClassifier.proto | 43 - .../coreml/mlmodel_format/GLMRegressor.proto | 28 - .../coreml/mlmodel_format/Gazetteer.proto | 43 - .../coreml/mlmodel_format/Identity.proto | 18 - .../coreml/mlmodel_format/Imputer.proto | 43 - .../ItemSimilarityRecommender.proto | 93 - .../coreml/mlmodel_format/LinkedModel.proto | 42 - .../coreml/mlmodel_format/Model.proto | 322 - .../mlmodel_format/NearestNeighbors.proto | 132 - .../coreml/mlmodel_format/NeuralNetwork.proto | 6531 ----------------- .../NonMaximumSuppression.proto | 187 - .../coreml/mlmodel_format/Normalizer.proto | 38 - .../coreml/mlmodel_format/OneHotEncoder.proto | 41 - .../coreml/mlmodel_format/Parameters.proto | 52 - .../providers/coreml/mlmodel_format/README.md | 16 - .../providers/coreml/mlmodel_format/SVM.proto | 195 - .../coreml/mlmodel_format/Scaler.proto | 34 - .../SoundAnalysisPreprocessing.proto | 60 - .../mlmodel_format/TextClassifier.proto | 43 - .../coreml/mlmodel_format/TreeEnsemble.proto | 161 - .../mlmodel_format/VisionFeaturePrint.proto | 63 - .../coreml/mlmodel_format/WordEmbedding.proto | 35 - .../coreml/mlmodel_format/WordTagger.proto | 75 - .../templates/download-deps.yml | 4 +- 41 files changed, 297 insertions(+), 9012 deletions(-) create mode 100644 cmake/patches/coremltools/crossplatformbuild.patch delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto delete mode 100644 
onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Model.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/README.md delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto delete mode 100644 onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 03e3f84547a68..efd901787fdb7 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -42,6 +42,16 @@ "comments": "abseil_cpp" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "dbb0094fd0cb936469e35320bf37e866ef7a1da4", + "repositoryUrl": "https://github.com/apple/coremltools.git" + }, + "comments": "coremltools" + } + }, { "component": { "type": "git", diff --git a/cmake/deps.txt b/cmake/deps.txt index ba9c2bb73cf7a..cb431f8c77397 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -13,6 +13,7 @@ # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 +coremltools;https://github.com/apple/coremltools/archive/refs/tags/7.1.zip;f1bab0f30966f2e217d8e01207d518f230a1641a cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 
dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -55,4 +56,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 \ No newline at end of file +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 403b4b2c4107a..22d12b128dc1f 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -224,8 +224,6 @@ FetchContent_Declare( URL_HASH SHA1=${DEP_SHA1_mp11} ) -set(JSON_BuildTests OFF CACHE INTERNAL "") -set(JSON_Install OFF CACHE INTERNAL "") set(JSON_BuildTests OFF CACHE INTERNAL "") set(JSON_Install OFF CACHE INTERNAL "") @@ -541,6 +539,17 @@ if(onnxruntime_ENABLE_TRAINING OR (onnxruntime_ENABLE_TRAINING_APIS AND onnxrunt onnxruntime_fetchcontent_makeavailable(cxxopts) endif() +if (onnxruntime_USE_COREML) + FetchContent_Declare( + coremltools + URL ${DEP_URL_coremltools} + URL_HASH SHA1=${DEP_SHA1_coremltools} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/coremltools/crossplatformbuild.patch + ) + # we don't build directly so use Populate. selected files are built from onnxruntime_providers_coreml.cmake + FetchContent_Populate(coremltools) +endif() + message("Finished fetching external dependencies") diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 8d3ea403fb74b..c6c9d8f4894c5 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -67,7 +67,7 @@ if(onnxruntime_USE_CUDA) endif() if(onnxruntime_USE_COREML) if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(PROVIDERS_COREML onnxruntime_providers_coreml onnxruntime_coreml_proto) + set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) else() set(PROVIDERS_COREML onnxruntime_providers_coreml) endif() diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index aa8c35526b274..2ca4a22aca7d2 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -1,107 +1,119 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. - if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) - message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. 
Please build with '--minimal_build extended'") - endif() - - add_compile_definitions(USE_COREML=1) - - # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(COREML_PROTO_ROOT ${PROJECT_SOURCE_DIR}/../onnxruntime/core/providers/coreml/mlmodel_format) - file(GLOB coreml_proto_srcs - "${COREML_PROTO_ROOT}/*.proto" - ) - onnxruntime_add_static_library(onnxruntime_coreml_proto ${coreml_proto_srcs}) - target_include_directories(onnxruntime_coreml_proto PUBLIC $ "${CMAKE_CURRENT_BINARY_DIR}") - target_compile_definitions(onnxruntime_coreml_proto PUBLIC $) - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") - set_target_properties(onnxruntime_coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") - set(_src_sub_dir "coreml/") - onnxruntime_protobuf_generate( - APPEND_PATH - GEN_SRC_SUB_DIR ${_src_sub_dir} - IMPORT_DIRS ${COREML_PROTO_ROOT} - TARGET onnxruntime_coreml_proto - ) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_coreml_proto - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} - ) - endif() - endif() - - # These are shared utils, - # TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML - file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" - ) +if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD) + message(FATAL_ERROR "CoreML EP can not be used in a basic minimal build. 
Please build with '--minimal_build extended'") +endif() - file(GLOB - onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" - ) +add_compile_definitions(USE_COREML=1) - # Add builder source code - file(GLOB_RECURSE - onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" - ) - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" - ) - endif() - - # Add CoreML objective c++ source code - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - file(GLOB - onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" - ) - endif() - - set(onnxruntime_providers_coreml_cc_srcs - ${onnxruntime_providers_coreml_cc_srcs_top} - ${onnxruntime_providers_coreml_cc_srcs_nested} - ${onnxruntime_providers_shared_utils_cc_srcs} +# Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto +set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) +file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") + +onnxruntime_add_static_library(coreml_proto ${coreml_proto_srcs}) +target_include_directories(coreml_proto + PUBLIC $ + "${CMAKE_CURRENT_BINARY_DIR}") +target_compile_definitions(coreml_proto + PUBLIC $) +set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") +set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") +set(_src_sub_dir "coreml_proto/") + +onnxruntime_protobuf_generate( + APPEND_PATH + GEN_SRC_SUB_DIR ${_src_sub_dir} + IMPORT_DIRS ${COREML_PROTO_ROOT} + TARGET coreml_proto +) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS coreml_proto + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR} ) +endif() + +# Add the .proto and generated .cc/.h files to the External/coreml_proto folder in Visual Studio. +# Separate source_group for each as the .proto files are in the repo and the .cc/.h files are generated in the build +# output directory. 
+set_target_properties(coreml_proto PROPERTIES FOLDER "External") +source_group(TREE ${COREML_PROTO_ROOT} PREFIX coreml_proto FILES ${coreml_proto_srcs}) + +# filter to the generated .cc/.h files +get_target_property(coreml_proto_generated_srcs coreml_proto SOURCES) +list(FILTER coreml_proto_generated_srcs INCLUDE REGEX "\.pb\.(h|cc)$") +source_group(TREE ${CMAKE_CURRENT_BINARY_DIR} PREFIX coreml_proto_generated FILES ${coreml_proto_generated_srcs}) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} +# These are shared utils, +# TODO, move this to a separated lib when used by EPs other than NNAPI and CoreML +file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" +) + +file(GLOB + onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.cc" +) + +# Add builder source code +file(GLOB_RECURSE + onnxruntime_providers_coreml_cc_srcs_nested CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" +) +if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") + list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" ) - onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface +endif() + +# Add CoreML objective c++ source code +if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml onnxruntime_coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE onnxruntime_coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml onnxruntime_coreml_proto) - endif() - add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) - - set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) - set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") - target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) - set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) - - if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_providers_coreml - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file +endif() + 
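For context on what the coreml_proto target provides to its dependents, the following is a small, illustrative consumer of the generated bindings. It is a sketch only, not part of the patch: it assumes the protobuf-generated headers land under the coreml_proto/ sub-directory configured by GEN_SRC_SUB_DIR above and that the translation unit links against coreml_proto and the protobuf lite runtime; the spec version and feature name are made-up values, and the accessors are the usual lower-cased names protoc generates.

// Illustrative consumer of the protobuf classes generated from the coremltools .proto files.
#include "coreml_proto/Model.pb.h"

#include <iostream>

int main() {
  CoreML::Specification::Model model;
  model.set_specificationversion(5);  // made-up value for illustration

  // ModelDescription is a nested message; protoc generates mutable_*/set_* accessors for it.
  CoreML::Specification::ModelDescription* description = model.mutable_description();
  description->set_predictedfeaturename("output");  // hypothetical feature name

  std::cout << "spec version: " << model.specificationversion() << "\n";
  return 0;
}

Because the include directory and compile definitions on coreml_proto are PUBLIC, dependents such as onnxruntime_providers_coreml pick them up automatically, which is what allows the coreml_spec.h change further down to switch to the coreml_proto/Model.pb.h include path.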
+set(onnxruntime_providers_coreml_cc_srcs + ${onnxruntime_providers_coreml_cc_srcs_top} + ${onnxruntime_providers_coreml_cc_srcs_nested} + ${onnxruntime_providers_shared_utils_cc_srcs} +) + +source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) +onnxruntime_add_static_library(onnxruntime_providers_coreml + ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} +) +onnxruntime_add_include_to_target(onnxruntime_providers_coreml + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface +) +if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") + onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) + target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto "-framework Foundation" "-framework CoreML") + add_dependencies(onnxruntime_providers_coreml coreml_proto) +endif() +add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) + +set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) +set_target_properties(onnxruntime_providers_coreml PROPERTIES FOLDER "ONNXRuntime") +target_include_directories(onnxruntime_providers_coreml PRIVATE ${ONNXRUNTIME_ROOT} ${coreml_INCLUDE_DIRS}) +set_target_properties(onnxruntime_providers_coreml PROPERTIES LINKER_LANGUAGE CXX) + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_providers_coreml + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6a4551ad94d9e..5b4a007d6b974 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -566,7 +566,7 @@ endif() if(onnxruntime_USE_COREML) if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) else() list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) endif() @@ -675,9 +675,9 @@ endif() if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml onnxruntime_coreml_proto) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml onnxruntime_coreml_proto) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) else() list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) diff --git a/cmake/patches/coremltools/crossplatformbuild.patch b/cmake/patches/coremltools/crossplatformbuild.patch new file mode 100644 index 0000000000000..7f2268f50c82e --- /dev/null +++ 
b/cmake/patches/coremltools/crossplatformbuild.patch @@ -0,0 +1,155 @@ +diff --git a/mlmodel/src/MILBlob/Blob/FileWriter.cpp b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +index adc7bfcf..7b2bf9cc 100644 +--- a/mlmodel/src/MILBlob/Blob/FileWriter.cpp ++++ b/mlmodel/src/MILBlob/Blob/FileWriter.cpp +@@ -8,8 +8,12 @@ + + #include + #include ++ ++// ORT_EDIT: Exclude mmap on Windows. Not used in this file anyway. ++#if !defined(_WIN32) + #include + #include ++#endif + + using namespace MILBlob; + using namespace MILBlob::Blob; +diff --git a/mlmodel/src/MILBlob/Fp16.cpp b/mlmodel/src/MILBlob/Fp16.cpp +index ae1e71a1..77a7161f 100644 +--- a/mlmodel/src/MILBlob/Fp16.cpp ++++ b/mlmodel/src/MILBlob/Fp16.cpp +@@ -5,6 +5,8 @@ + + #include "MILBlob/Fp16.hpp" + ++// ORT_EDIT: Exclude clang specific pragmas from other builds ++#if defined(__clang__) + // fp16 lib code has some conversion warnings we don't want to globally ignore + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wincompatible-pointer-types" +@@ -12,6 +14,9 @@ + #pragma clang diagnostic ignored "-Wconversion" + #include "fp16/fp16.h" + #pragma clang diagnostic pop ++#else ++#include "fp16/fp16.h" ++#endif + + using namespace MILBlob; + +diff --git a/modelpackage/src/ModelPackage.cpp b/modelpackage/src/ModelPackage.cpp +index 8fee56b9..99e0d8d6 100644 +--- a/modelpackage/src/ModelPackage.cpp ++++ b/modelpackage/src/ModelPackage.cpp +@@ -26,7 +26,14 @@ namespace std { + #else + #error "missing required header " + #endif ++ ++// ORT_EDIT: Use UuidCreate on Windows. ++#if defined(_WIN32) ++#pragma comment(lib, "rpcrt4.lib") // UuidCreate ++#include ++#else + #include ++#endif + #include + + #if defined(__cplusplus) +@@ -187,7 +194,10 @@ public: + ModelPackageItemInfo createFile(const std::string& name, const std::string& author, const std::string& description); + }; + ++// ORT_EDIT: pragma only available on APPLE platforms ++#if defined(__APPLE__) + #pragma mark ModelPackageImpl ++#endif + + ModelPackageImpl::ModelPackageImpl(const std::filesystem::path& path, bool createIfNecessary, bool readOnly) + : m_packagePath(path), +@@ -372,6 +382,20 @@ std::filesystem::path ModelPackageImpl::getItemPath(const std::string& name, con + } + + std::string ModelPackageImpl::generateIdentifier() const { ++// ORT_EDIT: Use built-in UUID generation on Windows ++#if defined(_WIN32) ++ UUID uuid; ++ UuidCreate(&uuid); ++ ++ RPC_CSTR uuidStr; ++ UuidToStringA(&uuid, &uuidStr); ++ ++ std::string uuidStrCpp(reinterpret_cast(uuidStr)); ++ ++ RpcStringFreeA(&uuidStr); ++ ++ return uuidStrCpp; ++#else + uuid_t uuid; + + // uuid_unparse generates a 36-character null-terminated string (37 bytes). 
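Condensed, the identifier-generation hunk above amounts to the cross-platform helper sketched below. This is an illustration rather than the patched coremltools code itself: the exact headers (windows.h/rpc.h on Windows, uuid/uuid.h elsewhere) and the uuid_generate call on the non-Windows side are assumptions based on the surrounding context lines.

#include <string>

// Windows uses the RPC UUID APIs; other platforms keep the libuuid path.
#if defined(_WIN32)
#pragma comment(lib, "rpcrt4.lib")  // UuidCreate / UuidToStringA / RpcStringFreeA
#include <windows.h>
#include <rpc.h>                    // assumed headers for the RPC UUID APIs
#else
#include <uuid/uuid.h>              // assumed libuuid header; link with -luuid
#endif

std::string GenerateIdentifier() {
#if defined(_WIN32)
  UUID uuid;
  UuidCreate(&uuid);

  RPC_CSTR uuidStr;
  UuidToStringA(&uuid, &uuidStr);

  std::string result(reinterpret_cast<char*>(uuidStr));
  RpcStringFreeA(&uuidStr);
  return result;
#else
  uuid_t uuid;
  uuid_generate(uuid);  // assumption: mirrors the original generation call in coremltools

  char buf[37];  // uuid_unparse writes 36 characters plus a terminating NUL
  uuid_unparse(uuid, buf);
  return std::string(buf);
#endif
}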
+@@ -383,6 +407,7 @@ std::string ModelPackageImpl::generateIdentifier() const { + uuid_unparse(uuid, buf); + + return std::string(buf); ++#endif + } + + ModelPackageItemInfo ModelPackageImpl::createFile(const std::string& name, const std::string& author, const std::string& description) { +@@ -468,7 +493,13 @@ std::shared_ptr ModelPackageImpl::findItem(const std::stri + auto author = itemInfoEntry->getString(kModelPackageItemInfoAuthorKey); + auto description = itemInfoEntry->getString(kModelPackageItemInfoDescriptionKey); + ++// ORT_EDIT: need to use path.string() on Windows ++#if defined(_WIN32) ++ return std::make_shared(std::make_shared(identifier, path.string(), name, author, description)); ++ ++#else + return std::make_shared(std::make_shared(identifier, path, name, author, description)); ++#endif + } + + std::shared_ptr ModelPackageImpl::findItem(const std::string& name, const std::string& author) const +@@ -514,7 +545,9 @@ void ModelPackageImpl::removeItem(const std::string& identifier) + } + + auto path = m_packageDataDirPath / itemInfoEntry->getString(kModelPackageItemInfoPathKey); +- if (0 != std::remove(path.c_str())) { ++ // ORT_EDIT: std::remove doesn't work on Windows. Use std::filesystem::remove instead. ++ // if (0 != std::remove(path.c_str())) { ++ if (!std::filesystem::remove(path)) { + throw std::runtime_error("Failed to remove file at path: " + path.string()); + } + +@@ -525,13 +558,16 @@ bool ModelPackageImpl::isValid(const std::filesystem::path& path) + { + try { + ModelPackageImpl(path, false, true); +- } catch (std::runtime_error& e) { ++ } catch (std::runtime_error& /*e*/) { // ORT_EDIT: comment out unused variable + return false; + } + return true; + } + ++// ORT_EDIT: pragma only available on APPLE platforms ++#if defined(__APPLE__) + #pragma mark ModelPackage ++#endif + + ModelPackage::ModelPackage(const std::string& packagePath, bool createIfNecessary, bool readOnly) + : m_modelPackageImpl(std::make_shared(packagePath, createIfNecessary, readOnly)) +@@ -544,7 +580,12 @@ ModelPackage::~ModelPackage() + + std::string ModelPackage::path() const + { ++// ORT_EDIT: Windows doesn't automatically convert to std::string as the native format could be char or wchar. ++#if defined(_WIN32) ++ return m_modelPackageImpl->path().string(); ++#else + return m_modelPackageImpl->path(); ++#endif + } + + std::string ModelPackage::setRootModel(const std::string& path, const std::string& name, const std::string& author, const std::string& description) diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index 631bb7e258303..e9cd4af94e5fd 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -9,6 +9,6 @@ #error "This file should only be included when building on Apple platforms." 
#endif -#include "coreml/Model.pb.h" +#include "coreml_proto/Model.pb.h" namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index 3b7bd5c1840cc..ef66e6b877a1f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -11,7 +11,7 @@ #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" -#include "coreml/NeuralNetwork.pb.h" +#include "coreml_proto/NeuralNetwork.pb.h" namespace onnxruntime { namespace coreml { diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto deleted file mode 100644 index 2b83ccbe3574f..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/ArrayFeatureExtractor.proto +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * An array feature extractor. - * - * Given an index, extracts the value at that index from its array input. - * Indexes are zero-based. - */ -message ArrayFeatureExtractor { - repeated uint64 extractIndex = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto deleted file mode 100644 index 9688d87ce48ba..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/BayesianProbitRegressor.proto +++ /dev/null @@ -1,139 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** -* A Bayesian probit regressor. -* -* The probit regression model is superficially similar to the more commonly known -* logistic regression, with sampling distribution of the model given by -* -* P(y=+1|x,w) = Φ(/β) -* -* where w are the set of weights, -* x are the set of features for the given event, -* β is a model hyper-parameter, and -* Φ is the link function, defined to be the CDF of the normal distribution. -* The weights w[i,j] are Gaussian distributed, with mean μ[i,j] and precision 1/(σ[i,j])^2 -* (where i indexes over features and j indexes over the values for the feature). -* The parameter β scales the steepness of the inverse link function. -* -* (see https://en.wikipedia.org/wiki/Probit_model and https://en.wikipedia.org/wiki/Logistic_regression -* for more details on probit model and logistic regression, respectively) -* -* Input: X -* x represents a set of features, each taking on a discrete value (note that continuous values -* would first need to be discretized). x can be represented as a vector where the index i is -* the feature id and x[i] is the feature value. Alternatively, x can be represented as a matrix -* with 2 columns where the first column indicates the feature id and the second column contains -* the feature values, i.e. x[i,0] is the feature id and x[i,1] is the feature value. 
-* -* additional input features: -* - "optimism": apply a mean shift to the probability, i.e. shift regression mean by o*stdev, -* where o is the "optimism" parameter (see additional output features) -* - "samplingScale": for sampling from posterior, multiply standard deviation by this factor -* - "samplingTruncation": for sampling from posterior, truncate sampling distribution at given multiple of std from mean -* -* Output: Y -* probability P(y|x,w) -* -* additional output features: -* - mean (regression output before applying link function) -* - variance (regression output variance before applying link function) -* - pessimistic probability: P(y|x,w) with a mean shift parameterized by "optimism" feature -* - sampled probability: p ~ P(y|x,w) with standard deviation scaling parametrized by "samplingScale" feature -* and distribution truncated at multiple of standard deviation, -* where multiple parameterized by "samplingTruncation" feature. -* -*/ - -message BayesianProbitRegressor { - - /* - * Parameterization of a Gaussian distribution - */ - message Gaussian { - double mean = 1; - double precision = 2; // inverse of the variance - } - - /* - * Weight for a specific feature value - * The weight is represented as a Gaussian distribution - * with a mean and precision (1/variance) to capture - * uncertainty in the weight - */ - message FeatureValueWeight { - uint32 featureValue = 1; - Gaussian featureWeight = 2; - } - - /* - * Feature with associated weights (for different values) - * Each feature has a set of weights for the (discrete) values - * it can take - */ - message FeatureWeight { - uint32 featureId = 1; - repeated FeatureValueWeight weights = 2; - } - - uint32 numberOfFeatures = 1; - - Gaussian bias = 2; // bias term - - /* - * Set of features with associated weights - */ - repeated FeatureWeight features = 3; // feature weights - - /* - * Set this name to be the same as input feature of type multi-array (1D) - * in the model description you want to use as the regression input - */ - string regressionInputFeatureName = 10; - - /* - * Set this name to be the same as optional input feature of type double - * in the model description you want to use as the optimism input - */ - string optimismInputFeatureName = 11; - - /* - * Set this name to be the same as optional input feature of type double - * in the model description you want to use as the samplingScale input - */ - string samplingScaleInputFeatureName = 12; - - /* - * Set this name to be the same as optional input feature of type double - * in the model description you want to use as the samplingBounds input - */ - string samplingTruncationInputFeatureName = 13; - - /* - * name of 'mean' output feature - */ - string meanOutputFeatureName = 20; - - /* - * name of 'variance' output feature - */ - string varianceOutputFeatureName = 21; - - /* - * name of 'pessimistic' output feature - */ - string pessimisticProbabilityOutputFeatureName = 22; - - /* - * name of 'sampled' output feature: samples from the scaled posterior probability distribuiton - */ - string sampledProbabilityOutputFeatureName = 23; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto b/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto deleted file mode 100644 index 23112d074213a..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/CategoricalMapping.proto +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A categorical mapping. - * - * This allows conversion from integers to strings, or from strings to integers. - */ -message CategoricalMapping { - oneof MappingType { - // Conversion from strings to integers - StringToInt64Map stringToInt64Map = 1; - - // Conversion from integer to string - Int64ToStringMap int64ToStringMap = 2; - } - - /** - * The value returned if an input is not contained in the map above. - * If one of these is not set, then an error is raised on an unknown input. - */ - oneof ValueOnUnknown { - // Default output when converting from an integer to a string. - string strValue = 101; - - // Default output when converting from a string to an integer. - int64 int64Value = 102; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto b/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto deleted file mode 100644 index 9a6d36e009ada..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/CustomModel.proto +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** -* A parameterized model whose function is defined in code -*/ -message CustomModel { - - message CustomModelParamValue { - oneof value { - double doubleValue = 10; - string stringValue = 20; - int32 intValue = 30; - int64 longValue = 40; - bool boolValue = 50; - bytes bytesValue = 60; - } - } - - string className = 10; // The name of the class (conforming to MLCustomModel) corresponding to this model - map parameters = 30; - string description = 40; // An (optional) description provided by the model creator. This information is displayed when viewing the model, but does not affect the model's execution on device. -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto b/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto deleted file mode 100644 index 8b120c2d7d102..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/DataStructures.proto +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "FeatureTypes.proto"; - -package CoreML.Specification; - -/** - * A mapping from a string - * to a 64-bit integer. - */ -message StringToInt64Map { - map map = 1; -} - -/** - * A mapping from a 64-bit integer - * to a string. - */ -message Int64ToStringMap { - map map = 1; -} - -/** - * A mapping from a string - * to a double-precision floating point number. - */ -message StringToDoubleMap { - map map = 1; -} - -/** - * A mapping from a 64-bit integer - * to a double-precision floating point number. - */ -message Int64ToDoubleMap { - map map = 1; -} - -/** - * A vector of strings. 
- */ -message StringVector { - repeated string vector = 1; -} - -/** - * A vector of 64-bit integers. - */ -message Int64Vector { - repeated int64 vector = 1; -} - -/** - * A vector of floating point numbers. - */ -message FloatVector { - repeated float vector = 1; -} - -/** - * A vector of double-precision floating point numbers. - */ -message DoubleVector { - repeated double vector = 1; -} - -/** - * A range of int64 values - */ -message Int64Range { - int64 minValue = 1; - int64 maxValue = 2; -} - -/** - * A set of int64 values - */ -message Int64Set { - repeated int64 values = 1; -} - -/** - * A range of double values - */ -message DoubleRange { - double minValue = 1; - double maxValue = 2; -} - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto deleted file mode 100644 index 3f94eeec1745c..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/DictVectorizer.proto +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * Uses an index mapping to convert a dictionary to an array. - * - * The output array will be equal in length to the index mapping vector parameter. - * All keys in the input dictionary must be present in the index mapping vector. - * - * For each item in the input dictionary, insert its value in the output array. - * The position of the insertion is determined by the position of the item's key - * in the index mapping. Any keys not present in the input dictionary, will be - * zero in the output array. - * - * For example: if the ``stringToIndex`` parameter is set to ``["a", "c", "b", "z"]``, - * then an input of ``{"a": 4, "c": 8}`` will produce an output of ``[4, 8, 0, 0]``. - * - */ -message DictVectorizer { - oneof Map { - /// String keys to indexes - StringVector stringToIndex = 1; - - /// Int keys to indexes - Int64Vector int64ToIndex = 2; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto deleted file mode 100644 index 8711ac7de3026..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureTypes.proto +++ /dev/null @@ -1,224 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * The 64-bit integer feature type. - */ -message Int64FeatureType {} - -/** - * The double-precision floating point number feature type. - */ -message DoubleFeatureType {} - -/** - * The string feature type. - */ -message StringFeatureType {} - - -message SizeRange { - uint64 lowerBound = 1; - int64 upperBound = 2; // negative value means unbound otherwise upperbound is included in range -} - -/** - * The image feature type. 
- */ -message ImageFeatureType { - // Assumes raw (decompressed) format - enum ColorSpace { - INVALID_COLOR_SPACE = 0; - GRAYSCALE = 10; // 8 bits per pixel - RGB = 20; // 32 bits per pixel: RGBA with A channel ignored - BGR = 30; // 32 bits per pixel: BGRA with A channel ignored - } - - message ImageSize { - uint64 width = 1; - uint64 height = 2; - } - - message EnumeratedImageSizes { - repeated ImageSize sizes = 1; - } - - message ImageSizeRange { - SizeRange widthRange = 1; - SizeRange heightRange = 2; - } - - // The required or default image size is width x height - // - // If specificationVersion <= 2 or SizeFlexibility is empty, - // width x height is the required fixed image size - // - // If SizeFlexibility is present, width x height indicate a "default" - // image size which must be consistent with the flexibilty specified - - int64 width = 1; - int64 height = 2; - - // For specification version >= 3 you can specify image size flexibility. - - oneof SizeFlexibility { - - // Use enumeratedSizes for a set of distinct fixed sizes - // e.g. portrait or landscape: [80 x 100, 100 x 8] - // - // If the width x height fields above are specified then they must be - // one of the sizes listed. - // - // If width and height are not specified above then the default width - // and height will be enumeratedSizes[0] - // - // Must be non-empty - - EnumeratedImageSizes enumeratedSizes = 21; - - // Use imageSizeRange to allow for ranges of values - // e.g. any image greater than 10 x 20: [10..= 3 you can specify image size flexibility. - - oneof ShapeFlexibility { - - // Use enumeratedShapes for a set of distinct fixed shapes - // - // If the shape field is specified then it must be - // one of the enumerated shapes. - /// - // If shape is not specifed, the "default" shape will be considered - // enumeratedShapes[0] - // - // Must be non-empty - - EnumeratedShapes enumeratedShapes = 21; - - // Use shapeRange to allow the size of each dimension vary within - // indpendently specified ranges - // - // If you specify shape above it must fall in the range - // specified in shapeRanges. It will be treated as the default shape. - // - // If you don't specify shape above then the default shape will - // have shape[d] = shapeRange.sizeRanges[d].lowerBound - - ShapeRange shapeRange = 31; - - } - - oneof defaultOptionalValue { - int32 intDefaultValue = 41; - float floatDefaultValue = 51; - double doubleDefaultValue = 61; - } - -} - -/** - * The dictionary feature type. - */ -message DictionaryFeatureType { - /** - * Key/value type tags, with the following restrictions: - * - ``keyType`` must be a hashable type - * - ``valueType`` is assumed to be a ``double`` - */ - oneof KeyType { - Int64FeatureType int64KeyType = 1; - StringFeatureType stringKeyType = 2; - } -} - -/** - * The Sequence feature type. - */ -message SequenceFeatureType { - - /** - * Currently only categorical int64 and String sequences are supported - */ - oneof Type { - Int64FeatureType int64Type = 1; - StringFeatureType stringType = 3; - } - - // Range of allowed size/length/count of sequence - SizeRange sizeRange = 101; -} - -/** - * A feature, which may be optional. 
- */ -message FeatureType { - oneof Type { - Int64FeatureType int64Type = 1; - DoubleFeatureType doubleType = 2; - StringFeatureType stringType = 3; - ImageFeatureType imageType = 4; - ArrayFeatureType multiArrayType = 5; - DictionaryFeatureType dictionaryType = 6; - SequenceFeatureType sequenceType = 7; - } - - bool isOptional = 1000; -} - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto deleted file mode 100644 index 75eaf14b53669..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/FeatureVectorizer.proto +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A FeatureVectorizer puts one or more features into a single array. - * - * The ordering of features in the output array is determined by - * ``inputList``. - * - * ``inputDimensions`` is a zero based index. - */ -message FeatureVectorizer { - message InputColumn { - string inputColumn = 1; - uint64 inputDimensions = 2; - } - - repeated InputColumn inputList = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto b/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto deleted file mode 100644 index 47f6f4a3c7b8c..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/GLMClassifier.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A generalized linear model classifier. - */ -message GLMClassifier { - message DoubleArray { - repeated double value = 1; - } - - enum PostEvaluationTransform { - Logit = 0; - Probit = 1; /// Only binary classification is supported for probit - } - - enum ClassEncoding { - ReferenceClass = 0; /// First class is the reference class - OneVsRest = 1; /// Also called One vs All - } - - repeated DoubleArray weights = 1; - repeated double offset = 2; - PostEvaluationTransform postEvaluationTransform = 3; - ClassEncoding classEncoding = 4; - - /** - * Required class label mapping. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto b/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto deleted file mode 100644 index 64093c4f156a8..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/GLMRegressor.proto +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A generalized linear model regressor. 
- */ -message GLMRegressor { - message DoubleArray { - repeated double value = 1; - } - - enum PostEvaluationTransform { - NoTransform = 0; - Logit = 1; - Probit = 2; - } - - repeated DoubleArray weights = 1; - repeated double offset = 2; - PostEvaluationTransform postEvaluationTransform = 3; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto deleted file mode 100644 index 6abbffaf623b9..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Gazetteer.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which uses an efficient probabilistic representation -* for assigning labels to a set of strings. -*/ -message Gazetteer { - - /* - * Stores the revision number for the model, revision 2 is available on - * iOS, tvOS 13.0+, macOS 10.15+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Natural Lanaguge framework's efficient representation of a gazetter. - */ - bytes modelParameterData = 100; - - /* - * Stores the set of output class labels - */ - oneof ClassLabels { - StringVector stringClassLabels = 200; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto deleted file mode 100644 index 123a15e59156d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Identity.proto +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * An identity model. - * - * This model returns given inputs as outputs, unchanged. - * Intended to be used for testing purposes. - */ -message Identity { -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto deleted file mode 100644 index 3de280b2f162d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Imputer.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A transformer that replaces missing values with a default value, - * such as a statistically-derived value. - * - * If ``ReplaceValue`` is set, then missing values of that type are - * replaced with the corresponding value. - * - * For example: if ``replaceDoubleValue`` is set to ``NaN`` - * and a single ``NaN`` double value is provided as input, - * then it is replaced by ``imputedDoubleValue``. 
However - * if the input is an array of doubles, then any instances - * of ``NaN`` in the array is replaced with the corresponding - * value in ``imputedDoubleArray``. - */ -message Imputer { - oneof ImputedValue { - double imputedDoubleValue = 1; - int64 imputedInt64Value = 2; - string imputedStringValue = 3; - DoubleVector imputedDoubleArray = 4; - Int64Vector imputedInt64Array = 5; - StringToDoubleMap imputedStringDictionary = 6; - Int64ToDoubleMap imputedInt64Dictionary = 7; - } - - oneof ReplaceValue { - double replaceDoubleValue = 11; - int64 replaceInt64Value = 12; - string replaceStringValue = 13; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto b/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto deleted file mode 100644 index a5a8c11092d36..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/ItemSimilarityRecommender.proto +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * Each tree is a collection of nodes, - * each of which is identified by a unique identifier. - * - * Each node is either a branch or a leaf node. - * A branch node evaluates a value according to a behavior; - * if true, the node identified by ``true_child_node_id`` is evaluated next, - * if false, the node identified by ``false_child_node_id`` is evaluated next. - * A leaf node adds the evaluation value to the base prediction value - * to get the final prediction. - * - * A tree must have exactly one root node, - * which has no parent node. - * A tree must not terminate on a branch node. - * All leaf nodes must be accessible - * by evaluating one or more branch nodes in sequence, - * starting from the root node. - */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - - -/** - * Item Similarity Recommender - * - * The Item Similarity recommender takes as input a list of items and scores, - * then uses that information and a table of item similarities to predict similarity - * scores for all items. By default, the items predicted are most similar to the given - * items but not part of that item set. - * - * The predicted score for a given item k is - * sum_(i in observed items) sim_(k,i) * (score_i - shift_k) - * - * Because only the most similar scores for each item i are stored, - * sim_(k,i) is often zero. - * - * For many models, the score adjustment parameter shift_j is zero -- it's occasionally used - * to counteract global biases for popular items. - * - * - * References: - */ -message ItemSimilarityRecommender { - - /** The items similar to a given base item. - */ - message ConnectedItem { - uint64 itemId = 1; - double similarityScore = 2; - } - - /** The formula for the score of a given model as given above, with shift_k - * parameter given by itemScoreAdjustment, and the similar item list filling in - * all the known sim(k,i) scores for i given by itemID and k given by the itemID parameter in - * the similarItemList. - */ - message SimilarItems { - uint64 itemId = 1; - repeated ConnectedItem similarItemList = 2; - double itemScoreAdjustment = 3; - } - - repeated SimilarItems itemItemSimilarities = 1; - - /** One or none of these are given. If none are given, then the items must number 0, 1, ..., num_items - 1. 
- * If either is given, the length must be exactly num_items. - */ - StringVector itemStringIds = 2; - Int64Vector itemInt64Ids = 3; - - /** Input parameter names specifying different possible inputs to the recommender. - */ - string itemInputFeatureName = 10; /* Required */ - string numRecommendationsInputFeatureName = 11; /* Optional; defaults to all items if not given.*/ - string itemRestrictionInputFeatureName = 12; /* Optional. */ - string itemExclusionInputFeatureName = 13; /* Optional; defaults to input item list if not given. */ - - /** The predicted outputs. At least one of these must be specified. - */ - string recommendedItemListOutputFeatureName = 20; - string recommendedItemScoreOutputFeatureName = 21; - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto b/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto deleted file mode 100644 index b113000e80a8d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/LinkedModel.proto +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; -import public "Parameters.proto"; - -package CoreML.Specification; - -/** - * A model which wraps another (compiled) model external to this one - */ -message LinkedModel { - - oneof LinkType { - // A model located via a file system path - LinkedModelFile linkedModelFile = 1; - } -} - -// Model is referenced by a model file name and search path -message LinkedModelFile { - - // Model file name: e.g. "MyFetureExtractor.mlmodelc" - StringParameter linkedModelFileName = 1; - - // Search path to find the linked model file - // Multiple paths can be searched using the unix-style path separator ":" - // Each path can be relative (to this model) or absolute - // - // An empty string is the same as teh relative search path "." - // which searches in the same location as this model file - // - // There are some special paths which start with $ - // - $BUNDLE_MAIN - Indicates to look in the main bundle - // - $BUNDLE_IDENTIFIER(identifier) - Looks in Bunde with given identifer - StringParameter linkedModelSearchPath = 2; -} - - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto deleted file mode 100644 index 737233f2e3fe7..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Model.proto +++ /dev/null @@ -1,322 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * A Core ML model consists of a specification version - * and a model description, - * and can be any one of the following types: - * - * Neural Networks - * - `NeuralNetwork` - * - * Regressors - * - ``GLMRegressor`` - * - ``SupportVectorRegressor`` - * - ``TreeEnsembleRegressor`` - * - ``NeuralNetworkRegressor`` - * - ``BayesianProbitRegressor`` - * - * Classifiers - * - `NeuralNetworkClassifier` - * - `TreeEnsembleClassifier` - * - `GLMClassifier` - * - `SupportVectorClassifier` - * - `KNearestNeighborsClassifier` - * - * Other models - * - `CustomModel` - * - `TextClassifier` - * - `WordTagger` - * - `Gazetteer` - * - `WordEmbedding` - * - `VisionFeaturePrint` - * - `LinkedModel` - * - `SoundAnalysisPreprocessing` - * - `ItemSimilarityRecommender` - * - * Feature Engineering - * - `Imputer` - * - `Scaler` - * - `Normalizer` - * - `OneHotEncoder` - * - `CategoricalMapping` - * - `FeatureVectorizer` - * - `DictVectorizer` - * - `ArrayFeatureExtractor` - * - `NonMaximumSuppression` - * - * Pipelines - * - `PipelineClassifier` - * - `PipelineRegressor` - * - `Pipeline` - * - * Simple Mathematical Functions - * - `Identity` - */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "VisionFeaturePrint.proto"; -import public "TextClassifier.proto"; -import public "WordTagger.proto"; -import public "Gazetteer.proto"; -import public "WordEmbedding.proto"; -import public "ArrayFeatureExtractor.proto"; -import public "BayesianProbitRegressor.proto"; -import public "CategoricalMapping.proto"; -import public "CustomModel.proto"; -import public "DictVectorizer.proto"; -import public "FeatureTypes.proto"; -import public "FeatureVectorizer.proto"; -import public "GLMRegressor.proto"; -import public "GLMClassifier.proto"; -import public "NearestNeighbors.proto"; -import public "Identity.proto"; -import public "Imputer.proto"; -import public "NeuralNetwork.proto"; -import public "Normalizer.proto"; -import public "OneHotEncoder.proto"; -import public "Scaler.proto"; -import public "NonMaximumSuppression.proto"; -import public "SVM.proto"; -import public "TreeEnsemble.proto"; -import public "Parameters.proto"; -import public "ItemSimilarityRecommender.proto"; -import public "SoundAnalysisPreprocessing.proto"; -import public "LinkedModel.proto"; - -package CoreML.Specification; - -/** - * A pipeline consisting of one or more models. - */ -message Pipeline { - repeated Model models = 1; - - // Optional names given for each model - // If not supplied it defaults to ["model0",..., "model"(models.size()-1)] - // These names can be used to disambiguate the scope / domain of a parameter - repeated string names = 2; -} - -/** - * A classifier pipeline. - */ -message PipelineClassifier { - Pipeline pipeline = 1; -} - -/** - * A regressor pipeline. - */ -message PipelineRegressor { - Pipeline pipeline = 1; -} - -/** - * A feature description, - * consisting of a name, short description, and type. - */ -message FeatureDescription { - string name = 1; - string shortDescription = 2; - FeatureType type = 3; -} - -/** - * Model metadata, - * consisting of a short description, a version string, - * an author, a license, and any other user defined - * key/value meta data. 
- */ -message Metadata { - string shortDescription = 1; - string versionString = 2; - string author = 3; - string license = 4; - map userDefined = 100; -} - -/** - * A description of a model, - * consisting of descriptions of its input and output features. - * Both regressor and classifier models require the name of the - * primary predicted output feature (``predictedFeatureName``). - * Classifier models can specify the output feature containing - * probabilities for the predicted classes - * (``predictedProbabilitiesName``). - */ -message ModelDescription { - repeated FeatureDescription input = 1; - repeated FeatureDescription output = 10; - - // [Required for regressor and classifier models]: the name - // to give to an output feature containing the prediction. - string predictedFeatureName = 11; - - // [Optional for classifier models]: the name to give to an - // output feature containing a dictionary mapping class - // labels to their predicted probabilities. If not specified, - // the dictionary will not be returned by the model. - string predictedProbabilitiesName = 12; - - repeated FeatureDescription trainingInput = 50; - - Metadata metadata = 100; -} - -message SerializedModel { - // Identifier whose content describes the model type of the serialized protocol buffer message. - string identifier = 1; - - // Must be a valid serialized protocol buffer of the above specified type. - bytes model = 2; -} - -/** - * A Core ML model, - * consisting of a specification version, - * a model description, and a model type. - * - * Core ML model compatibility is indicated by - * a monotonically increasing specification version number, - * which is incremented anytime a backward-incompatible change is made - * (this is functionally equivalent to the MAJOR version number - * described by `Semantic Versioning 2.0.0 `_). 
- * - * Specification Versions : OS Availability (Core ML Version) - * - * 1 : iOS 11, macOS 10.13, tvOS 11, watchOS 4 (Core ML 1) - * - Feedforward & Recurrent Neural Networks - * - General Linear Models - * - Tree Ensembles - * - Support Vector Machines - * - Pipelines - * - Feature Engineering - * - * 2 : iOS 11.2, macOS 10.13.2, tvOS 11.2, watchOS 4.2 (Core ML 1.2) - * - Custom Layers for Neural Networks - * - Float 16 support for Neural Network layers - * - * 3 : iOS 12, macOS 10.14, tvOS 12, watchOS 5 (Core ML 2) - * - Flexible shapes and image sizes - * - Categorical sequences - * - Core ML Vision Feature Print, Text Classifier, Word Tagger - * - Non Max Suppression - * - Crop and Resize Bilinear NN layers - * - Custom Models - * - * 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) - * - Updatable models - * - Exact shape / general rank mapping for neural networks - * - Large expansion of supported neural network layers - * - Generalized operations - * - Control flow - * - Dynamic layers - * - See NeuralNetwork.proto - * - Nearest Neighbor Classifier - * - Sound Analysis Prepreocessing - * - Recommender - * - Linked Model - * - NLP Gazeteer - * - NLP WordEmbedding - * - * 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) - * - Model Deployment - * - Model Encryption - * - Unified converter API with PyTorch and Tensorflow 2 Support in coremltools 4 - * - MIL builder for neural networks and composite ops in coremltools 4 - * - New layers in neural network: - * - CumSum - * - OneHot - * - ClampedReLu - * - ArgSort - * - SliceBySize - * - Convolution3D - * - Pool3D - * - Bilinear Upsample with align corners and fractional factors - * - PixelShuffle - * - MatMul with int8 weights and int8 activations - * - Concat interleave - * - See NeuralNetwork.proto - * - Enhanced Xcode model view with interactive previews - * - Enhanced Xcode Playground support for Core ML models - * - */ -message Model { - int32 specificationVersion = 1; - ModelDescription description = 2; - - /* - * Following model types support on-device update: - * - * - NeuralNetworkClassifier - * - NeuralNetworkRegressor - * - NeuralNetwork - * - KNearestNeighborsClassifier - */ - bool isUpdatable = 10; - - // start at 200 here - // model specific parameters: - oneof Type { - // pipeline starts at 200 - PipelineClassifier pipelineClassifier = 200; - PipelineRegressor pipelineRegressor = 201; - Pipeline pipeline = 202; - - // regressors start at 300 - GLMRegressor glmRegressor = 300; - SupportVectorRegressor supportVectorRegressor = 301; - TreeEnsembleRegressor treeEnsembleRegressor = 302; - NeuralNetworkRegressor neuralNetworkRegressor = 303; - BayesianProbitRegressor bayesianProbitRegressor = 304; - - // classifiers start at 400 - GLMClassifier glmClassifier = 400; - SupportVectorClassifier supportVectorClassifier = 401; - TreeEnsembleClassifier treeEnsembleClassifier = 402; - NeuralNetworkClassifier neuralNetworkClassifier = 403; - KNearestNeighborsClassifier kNearestNeighborsClassifier = 404; - - // generic models start at 500 - NeuralNetwork neuralNetwork = 500; - ItemSimilarityRecommender itemSimilarityRecommender = 501; - - // Custom and linked models - CustomModel customModel = 555; - LinkedModel linkedModel = 556; - - // feature engineering starts at 600 - OneHotEncoder oneHotEncoder = 600; - Imputer imputer = 601; - FeatureVectorizer featureVectorizer = 602; - DictVectorizer dictVectorizer = 603; - Scaler scaler = 604; - CategoricalMapping categoricalMapping = 606; - Normalizer normalizer = 607; - 
ArrayFeatureExtractor arrayFeatureExtractor = 609; - NonMaximumSuppression nonMaximumSuppression = 610; - - - // simple mathematical functions used for testing start at 900 - Identity identity = 900; - - // reserved until 1000 - - // CoreML provided models - CoreMLModels.TextClassifier textClassifier = 2000; - CoreMLModels.WordTagger wordTagger = 2001; - CoreMLModels.VisionFeaturePrint visionFeaturePrint = 2002; - CoreMLModels.SoundAnalysisPreprocessing soundAnalysisPreprocessing = 2003; - CoreMLModels.Gazetteer gazetteer = 2004; - CoreMLModels.WordEmbedding wordEmbedding = 2005; - - // Reserved private messages start at 3000 - // These messages are subject to change with no notice or support. - SerializedModel serializedModel = 3000; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto deleted file mode 100644 index 82acd8490374d..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/NearestNeighbors.proto +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -import public "DataStructures.proto"; -import public "Parameters.proto"; - -/** - * A k-Nearest-Neighbor classifier - */ -message KNearestNeighborsClassifier { - - /** - * The "core" nearest neighbor model attributes. - */ - NearestNeighborsIndex nearestNeighborsIndex = 1; - - /** - * Number of neighbors to use for classification. - */ - Int64Parameter numberOfNeighbors = 3; - - /** - * Type of labels supported by the model. Currently supports String or Int64 - * labels. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } - - /** - * Default value of class label (useful when prediction is called on an empty kNN classifier) - */ - oneof DefaultClassLabel { - string defaultStringLabel = 110; - int64 defaultInt64Label = 111; - } - - /** - * Weighting scheme to be used when computing the majority label of a - * new data point. - */ - oneof WeightingScheme { - UniformWeighting uniformWeighting = 200; - InverseDistanceWeighting inverseDistanceWeighting = 210; - } -} - -/** - * The "core" attributes of a Nearest Neighbors model. - */ -message NearestNeighborsIndex { - - /** - * Number of dimensions of the input data. - */ - int32 numberOfDimensions = 1; - - /** - * Vector of floating point data that makes up the model. Each data point must have 'numberOfDimensions' - * dimensions. - */ - repeated FloatVector floatSamples = 2; - - /** - * Backing data structure for the Nearest Neighbors Index. Currently supports - * a linear index or a kd-tree index. - */ - oneof IndexType { - LinearIndex linearIndex = 100; - SingleKdTreeIndex singleKdTreeIndex = 110; - } - - /** - * Distance function to be used to find neighbors. Currently only Squared Euclidean - * Distance is supported. - */ - oneof DistanceFunction { - SquaredEuclideanDistance squaredEuclideanDistance = 200; - } - -} - -/** - * Specifies a uniform weighting scheme (i.e. each neighbor receives equal - * voting power). - */ -message UniformWeighting { -} - - -/** - * Specifies a inverse-distance weighting scheme (i.e. closest neighbors receives higher - * voting power). 
A nearest neighbor with highest sum of (1 / distance) is picked. - */ -message InverseDistanceWeighting { -} - - -/** - * Specifies a flat index of data points to be searched by brute force. - */ -message LinearIndex { -} - - -/** - * Specifies a kd-tree backend for the nearest neighbors model. - */ -message SingleKdTreeIndex { - - /** - * Number of data points contained within a leaf node of the kd-tree. - */ - int32 leafSize = 1; - -} - - -/** - * Specifies the Squared Euclidean Distance function. - */ -message SquaredEuclideanDistance { -} - diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto deleted file mode 100644 index 44a77c6e7f5f1..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/NeuralNetwork.proto +++ /dev/null @@ -1,6531 +0,0 @@ -// Copyright (c) 2017-2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * A neural network is defined through a collection of layers - * and represents a directed acyclic graph (DAG). - * Each layer has a name, a layer type, - * a list of input names, a list of output names, - * and a collection of parameters specific to the layer type. - * - * The graph structure and connectivity of the neural network - * is inferred from the input and output names. - * A neural network starts with the layer - * whose input name is equal to the value specified in - * ``Model.description.input.name``, - * and ends with the layer - * whose output name is equal to the value specified in - * ``Model.description.output.name``. - * Layers must have unique input and output names, - * and a layer may not have input or output names that - * refer to layers that are not yet defined. - * - * For Core ML specification version <=3, - * all inputs are mapped to static rank 5 tensors, with axis notations - * [Sequence, Batch, Channel, Height, Width]. - * - * From specification version 4 onwards (iOS >= 13, macOS >= 10.15), more options are available - * (see enums ``NeuralNetworkMultiArrayShapeMapping``, ``NeuralNetworkImageShapeMapping``) - * to map inputs to generic N-Dimensional (or N rank) tensors, where N >= 1. - * - * Each layer type may have specific constraints on the ranks of its inputs and outputs. - * - * Some of the layers (such as softmax, reduce, etc) have parameters that have been described in - * terms of notational axis "Channel", "Height", "Width" or "Sequence". They can be re-interpreted easily in - * the general ND setting by using the following rule: - * "width" is same as axis = -1 (i.e. the last axis from the end) - * "height" is same as axis = -2 (i.e. the second last axis from the end) - * "channel" is same as axis = -3 (i.e. the third last axis from the end) - * "sequence" is same as axis = -5 (i.e. the fifth last axis from the end) - * - * Several layers are available in 3 different variations, with the names ending - * in identifiers: ``like``, ``static`` and ``dynamic``. For instance, ``FillLike``, - * ``FillStatic`` and ``FillDynamic``. The ``static`` variation generally will have - * a property corresponding to the shape of the output. For instance, if the - * output of the ``FillStatic`` layer is desired to be of shape (10, 4), the - * property ``targetShape`` will have to be set to [10, 4]. In the ``dynamic`` case, - * the shape is an input, hence it can be changed at runtime. 
For instance, for - * a ``FillDynamic`` layer, the input would have to be an array containing the - * values 10 and 4, if the desired output is of shape (10, 4). Whereas in the - * ``like`` case, the additional input's shape is used as the output shape, ignoring - * its values. For instance, for a ``FillLike`` layer, for an input with shape - * (10, 4), the output generated will also be of shape (10, 4), values of the - * input will be ignored. - */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; -import public "Parameters.proto"; - -package CoreML.Specification; - - -enum NeuralNetworkMultiArrayShapeMapping { - - /* - * Describes how the MultiArray shape for the inputs, - * provided in Features Types proto via model description, - * is mapped to construct tensors that are fed into the Neural Network layers. - */ - - /* - * Default legacy value. Only supported for Core ML Specification version <= 3. - * - * The default legacy shape mapping resolves all input shapes to a rank 5 equivalent - * with axis notation of [Seq, Batch, Channel, Height, Width]. - * - * When this enum value is selected, - * the repeated shape field in the message "ArrayFeatureType" in feature types proto, - * must be either length 1 or length 3. - * - * The following rule is used to map the values in the shape field to the actual tensor shape: - * rank 1 shape is mapped to shape [1,1,C,1,1] - * rank 3 shape is mapped to shape [1,1,C,H,W] - * At runtime, the first two dimensions (Seq or Batch) can be presented as well, with non-1 values. - * - * It is invalid to use this enum value if any of the layers added - * Specification version 4 (iOS >= 13, macOS >= 10.15) onwards are used in the network. - * Validator will raise an error in that case. - */ - RANK5_ARRAY_MAPPING = 0; - - /* - * The exact shape and rank (i.e. number of dimensions in the shape) of the input, - * as specified in the message "ArrayFeatureType", is passed through to the layers. - * Supported only for Specification version >= 4 (iOS >= 13, macOS >= 10.15). - */ - EXACT_ARRAY_MAPPING = 1; - -} - -enum NeuralNetworkImageShapeMapping { - - /* - * Describes how the shape of the input tensors is constructed from image inputs. - */ - - /* - * In this case, image input is mapped to a rank 5 tensor. - * For Color images, input tensor is shaped as [1,1,3,H,W]. - * For Gray images, input tensor is shaped as [1,1,1,H,W]. - */ - RANK5_IMAGE_MAPPING = 0; - - /* - * For Color images, input tensor is shaped as [1,3,H,W]. - * For Gray images, input tensor is shaped as [1,1,H,W]. - * Supported only for Specification version >= 4 (iOS >= 13, macOS >= 10.15). - */ - RANK4_IMAGE_MAPPING = 1; - -} - -/** - A neural network. - */ -message NeuralNetwork { - - repeated NeuralNetworkLayer layers = 1; - repeated NeuralNetworkPreprocessing preprocessing = 2; - - // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs - NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; - - // use this enum value to determine the input tensor shapes to the neural network, for image inputs - NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; - - - NetworkUpdateParameters updateParams = 10; - -} - -/// Preprocessing -/// ------------- - -/** - * A neural network preprocessor that - * performs a scalar multiplication of an image - * followed by addition of scalar biases to the channels. 
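The RANK5_ARRAY_MAPPING description above maps a declared MultiArray shape of length 1 or 3 onto the legacy rank-5 [Seq, Batch, Channel, Height, Width] layout. A small sketch of that rule, with a hypothetical helper name, to make the dimension padding concrete:

    def rank5_array_mapping(shape):
        # legacy rule: length-1 shape is [C], length-3 shape is [C, H, W]
        if len(shape) == 1:
            c, = shape
            return [1, 1, c, 1, 1]
        if len(shape) == 3:
            c, h, w = shape
            return [1, 1, c, h, w]
        raise ValueError("legacy mapping requires a shape of length 1 or 3")

    print(rank5_array_mapping([3, 224, 224]))  # [1, 1, 3, 224, 224]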
- * - * Input: X - * An image in BGR or RGB format with shape ``[3, H, W]`` - * or in grayscale format with shape ``[1, H, W]``. - * Output: Y - * An image with format and shape corresponding to the input. - * - * If the input image is in BGR format: - * - * .. code:: - * - * Y[0, :, :] = channelScale * X[0, :, :] + blueBias - * Y[1, :, :] = channelScale * X[1, :, :] + greenBias - * Y[2, :, :] = channelScale * X[2, :, :] + redBias - * - * If the input image is in RGB format: - * - * .. code:: - * - * Y[0, :, :] = channelScale * X[0, :, :] + redBias - * Y[1, :, :] = channelScale * X[1, :, :] + greenBias - * Y[2, :, :] = channelScale * X[2, :, :] + blueBias - * - * If the input image is in grayscale format: - * - * .. code:: - * - * Y[0, :, :] = channelScale * X[0, :, :] + grayBias - */ -message NeuralNetworkImageScaler { - - float channelScale = 10; ///Scalar to be multiplied. - float blueBias = 20; ///Scalar blue bias to be added. - float greenBias = 21; ///Scalar green bias to be added. - float redBias = 22; ///Scalar red bias to be added. - float grayBias = 30; ///Scalar bias to be added for grayscale images. - -} - -/** - * A neural network preprocessor that - * subtracts the provided mean image from the input image. - * The mean image is subtracted from the input named - * ``NeuralNetworkPreprocessing.featureName``. - */ -message NeuralNetworkMeanImage { - - /** - * Mean image stored as a flattened array of floats, - * representing shape [Channel,Height,Width]. - */ - repeated float meanImage = 1; - -} - -/// Preprocessing parameters for image inputs. -message NeuralNetworkPreprocessing { - - string featureName = 1; /// must be equal to the input name to which the preprocessing is applied - oneof preprocessor { - NeuralNetworkImageScaler scaler = 10; - NeuralNetworkMeanImage meanImage = 11; - } - -} - -/// Activation Functions -/// -------------------- - -/** - * A rectified linear unit (ReLU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \text{max}(0, x) - */ -message ActivationReLU { - -} - -/** - * A leaky rectified linear unit (ReLU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * x & \text{if } x \geq 0 \\ - * \alpha x & \text{if } x < 0 - * \end{cases} - */ -message ActivationLeakyReLU { - - float alpha = 1; //negative slope value for leakyReLU - -} - -/** - * A hyperbolic tangent activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \dfrac{1 - e^{-2x}}{1 + e^{-2x}} - */ -message ActivationTanh { - -} - -/** - * A scaled hyperbolic tangent activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \alpha \tanh(\beta x) - */ -message ActivationScaledTanh { - - float alpha = 1; - float beta = 2; - -} - -/** - * A sigmoid activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \dfrac{1}{1 + e^{-x}} - */ -message ActivationSigmoid { - -} - -/** - * A linear activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \alpha x + \beta - */ -message ActivationLinear { - - float alpha = 1; - float beta = 2; - -} - -/** - * A hard sigmoid activation function. - * - * This function has the following formula: - * - * .. 
math:: - * f(x) = \text{min}(\text{max}(\alpha x + \beta, 0), 1) - */ -message ActivationSigmoidHard { - - float alpha = 1; - float beta = 2; - -} - -/** - * A parameterized rectified linear unit (PReLU) activation function. - * Input must be at least rank 3. Axis = -3 is denoted by "C", or channels. - * "alpha" parameter can be a vector of length C. - * - * This function has the following formula: - * - * .. math:: - * f(x_i) = \begin{cases} - * x_i & \text{if } x_i \geq 0 \\ - * \alpha_i x_i & \text{if } x_i < 0 - * \end{cases} \;,\;i=1,...,C - */ -message ActivationPReLU { - - // parameter of length C or 1. - // If length is 1, same value is used for all channels - WeightParams alpha = 1; - -} - -/** - * An exponential linear unit (ELU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * x & \text{if } x \geq 0 \\ - * \alpha (e^x - 1) & \text{if } x < 0 - * \end{cases} - */ -message ActivationELU { - - float alpha = 1; - -} - -/** - * A thresholded rectified linear unit (ReLU) activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * x & \text{if } x \geq \alpha \\ - * 0 & \text{if } x < \alpha - * \end{cases} - */ -message ActivationThresholdedReLU { - - float alpha = 1; - -} - -/** - * A softsign activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \dfrac{x}{1 + |x|} - */ -message ActivationSoftsign { - -} - -/** - * A softplus activation function. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \text{log}(1 + e^x) - */ -message ActivationSoftplus { - -} - -/** - * A parametric softplus activation function. - * Input must be at least rank 3. axis = -3 is denoted by "C", or channels. - * "alpha"/"beta" parameter can be a vector of length C. - * - * This function has the following formula: - * - * .. math:: - * f(x_i) = \alpha_i \text{log}(1 + e^{\beta_i x_i}) \;,\;i=1,...,C - */ -message ActivationParametricSoftplus { - - // If length is 1, same value is used for all channels - WeightParams alpha = 1; //parameter of length C or 1 - WeightParams beta = 2; //parameter of length C or 1 - -} - -message ActivationParams { - - oneof NonlinearityType { - ActivationLinear linear = 5; - - ActivationReLU ReLU = 10; - ActivationLeakyReLU leakyReLU = 15; - ActivationThresholdedReLU thresholdedReLU = 20; - ActivationPReLU PReLU = 25; - - ActivationTanh tanh = 30; - ActivationScaledTanh scaledTanh = 31; - - ActivationSigmoid sigmoid = 40; - ActivationSigmoidHard sigmoidHard = 41; - - ActivationELU ELU = 50; - - ActivationSoftsign softsign = 60; - ActivationSoftplus softplus = 70; - ActivationParametricSoftplus parametricSoftplus = 71; - } - -} - -/** - * Representation of the intermediate tensors - */ -message Tensor { - - // Number of dimensions in the tensor shape - uint32 rank = 1; - // actual value of the tensor shape. - // must be of length "rank". Can contain -1s for unknown dimensions. - repeated int64 dimValue = 2; - -} - -/** - * A single neural network layer. - */ -message NeuralNetworkLayer { - - string name = 1; //descriptive name of the layer - repeated string input = 2; - repeated string output = 3; - - repeated Tensor inputTensor = 4; // must be the same length as the "input" field - repeated Tensor outputTensor = 5; // must be the same length as the "output" field - - // Must be set to true to mark the layer as updatable. 
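Several of the activation formulas above (hard sigmoid, ELU, softsign, softplus) are easy to sanity-check numerically. An illustrative Python sketch of those formulas; the function names are ad hoc, not part of any Core ML API:

    import math

    def sigmoid_hard(x, alpha, beta):
        return min(max(alpha * x + beta, 0.0), 1.0)   # f(x) = min(max(ax + b, 0), 1)

    def elu(x, alpha):
        return x if x >= 0 else alpha * (math.exp(x) - 1.0)

    def softsign(x):
        return x / (1.0 + abs(x))

    def softplus(x):
        return math.log(1.0 + math.exp(x))

    for x in (-2.0, 0.0, 2.0):
        print(x, sigmoid_hard(x, 0.2, 0.5), elu(x, 1.0), softsign(x), softplus(x))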
- // If true, the weightParams in the layer's properties must also be set to updatable - // If false, the value of the isUpdatable parameter within the layer's weights are ignored - bool isUpdatable = 10; - - oneof layer { - - // Start at 100 here - ConvolutionLayerParams convolution = 100; - - PoolingLayerParams pooling = 120; - - ActivationParams activation = 130; - - InnerProductLayerParams innerProduct = 140; - EmbeddingLayerParams embedding = 150; - - // Normalization-related Layers - BatchnormLayerParams batchnorm = 160; - MeanVarianceNormalizeLayerParams mvn = 165; - L2NormalizeLayerParams l2normalize = 170; - SoftmaxLayerParams softmax = 175; - LRNLayerParams lrn = 180; - - CropLayerParams crop = 190; - PaddingLayerParams padding = 200; - UpsampleLayerParams upsample = 210; - - ResizeBilinearLayerParams resizeBilinear = 211; - CropResizeLayerParams cropResize = 212; - - UnaryFunctionLayerParams unary = 220; - - // Element-wise Operations - AddLayerParams add = 230; - MultiplyLayerParams multiply = 231; - - AverageLayerParams average = 240; - ScaleLayerParams scale = 245; - - BiasLayerParams bias = 250; - MaxLayerParams max = 260; - MinLayerParams min = 261; - - DotProductLayerParams dot = 270; - ReduceLayerParams reduce = 280; - LoadConstantLayerParams loadConstant = 290; - - // Data Reorganization - ReshapeLayerParams reshape = 300; - FlattenLayerParams flatten = 301; - PermuteLayerParams permute = 310; - ConcatLayerParams concat = 320; - SplitLayerParams split = 330; - SequenceRepeatLayerParams sequenceRepeat = 340; - - ReorganizeDataLayerParams reorganizeData = 345; - SliceLayerParams slice = 350; - - // Recurrent Layers - SimpleRecurrentLayerParams simpleRecurrent = 400; - GRULayerParams gru = 410; - UniDirectionalLSTMLayerParams uniDirectionalLSTM = 420; - BiDirectionalLSTMLayerParams biDirectionalLSTM = 430; - - // Custom (user-implemented) Layer - CustomLayerParams custom = 500; - - // Following layers are available only after Core ML Specification - // version >= 4 (iOS >= 13, macOS >= 10.15) - - // Control Flow related Layers - CopyLayerParams copy = 600; - BranchLayerParams branch = 605; - - LoopLayerParams loop = 615; - LoopBreakLayerParams loopBreak = 620; - LoopContinueLayerParams loopContinue = 625; - - RangeStaticLayerParams rangeStatic = 635; - RangeDynamicLayerParams rangeDynamic = 640; - - // Element-wise Unary Layers - ClipLayerParams clip = 660; - CeilLayerParams ceil = 665; - FloorLayerParams floor = 670; - - SignLayerParams sign = 680; - RoundLayerParams round = 685; - - Exp2LayerParams exp2 = 700; - - SinLayerParams sin = 710; - CosLayerParams cos = 715; - TanLayerParams tan = 720; - - AsinLayerParams asin = 730; - AcosLayerParams acos = 735; - AtanLayerParams atan = 740; - - SinhLayerParams sinh = 750; - CoshLayerParams cosh = 755; - TanhLayerParams tanh = 760; - - AsinhLayerParams asinh = 770; - AcoshLayerParams acosh = 775; - AtanhLayerParams atanh = 780; - - ErfLayerParams erf = 790; - GeluLayerParams gelu = 795; - - // Element-wise Binary with Broadcasting Support - EqualLayerParams equal = 815; - NotEqualLayerParams notEqual = 820; - LessThanLayerParams lessThan = 825; - LessEqualLayerParams lessEqual = 827; - GreaterThanLayerParams greaterThan = 830; - GreaterEqualLayerParams greaterEqual = 832; - - LogicalOrLayerParams logicalOr = 840; - LogicalXorLayerParams logicalXor = 845; - LogicalNotLayerParams logicalNot = 850; - LogicalAndLayerParams logicalAnd = 855; - - ModBroadcastableLayerParams modBroadcastable = 865; - MinBroadcastableLayerParams 
minBroadcastable = 870; - MaxBroadcastableLayerParams maxBroadcastable = 875; - AddBroadcastableLayerParams addBroadcastable = 880; - PowBroadcastableLayerParams powBroadcastable = 885; - DivideBroadcastableLayerParams divideBroadcastable = 890; - FloorDivBroadcastableLayerParams floorDivBroadcastable = 895; - MultiplyBroadcastableLayerParams multiplyBroadcastable = 900; - SubtractBroadcastableLayerParams subtractBroadcastable = 905; - - // Tensor Manipulations - TileLayerParams tile = 920; - StackLayerParams stack = 925; - GatherLayerParams gather = 930; - ScatterLayerParams scatter = 935; - GatherNDLayerParams gatherND = 940; - ScatterNDLayerParams scatterND = 945; - SoftmaxNDLayerParams softmaxND = 950; - GatherAlongAxisLayerParams gatherAlongAxis = 952; - ScatterAlongAxisLayerParams scatterAlongAxis = 954; - - ReverseLayerParams reverse = 960; - ReverseSeqLayerParams reverseSeq = 965; - - SplitNDLayerParams splitND = 975; - ConcatNDLayerParams concatND = 980; - TransposeLayerParams transpose = 985; - - SliceStaticLayerParams sliceStatic = 995; - SliceDynamicLayerParams sliceDynamic = 1000; - SlidingWindowsLayerParams slidingWindows = 1005; - - TopKLayerParams topK = 1015; - ArgMinLayerParams argMin = 1020; - ArgMaxLayerParams argMax = 1025; - - EmbeddingNDLayerParams embeddingND = 1040; - BatchedMatMulLayerParams batchedMatmul = 1045; - - // Tensor Allocation / Reshape-related Operations - GetShapeLayerParams getShape = 1065; - LoadConstantNDLayerParams loadConstantND = 1070; - - FillLikeLayerParams fillLike = 1080; - FillStaticLayerParams fillStatic = 1085; - FillDynamicLayerParams fillDynamic = 1090; - - BroadcastToLikeLayerParams broadcastToLike = 1100; - BroadcastToStaticLayerParams broadcastToStatic = 1105; - BroadcastToDynamicLayerParams broadcastToDynamic = 1110; - - SqueezeLayerParams squeeze = 1120; - ExpandDimsLayerParams expandDims = 1125; - FlattenTo2DLayerParams flattenTo2D = 1130; - ReshapeLikeLayerParams reshapeLike = 1135; - ReshapeStaticLayerParams reshapeStatic = 1140; - ReshapeDynamicLayerParams reshapeDynamic = 1145; - RankPreservingReshapeLayerParams rankPreservingReshape = 1150; - - ConstantPaddingLayerParams constantPad = 1155; - - // Random Distributions - RandomNormalLikeLayerParams randomNormalLike = 1170; - RandomNormalStaticLayerParams randomNormalStatic = 1175; - RandomNormalDynamicLayerParams randomNormalDynamic = 1180; - - RandomUniformLikeLayerParams randomUniformLike = 1190; - RandomUniformStaticLayerParams randomUniformStatic = 1195; - RandomUniformDynamicLayerParams randomUniformDynamic = 1200; - - RandomBernoulliLikeLayerParams randomBernoulliLike = 1210; - RandomBernoulliStaticLayerParams randomBernoulliStatic = 1215; - RandomBernoulliDynamicLayerParams randomBernoulliDynamic = 1220; - - CategoricalDistributionLayerParams categoricalDistribution = 1230; - - // Reduction-related Layers: - ReduceL1LayerParams reduceL1 = 1250; - ReduceL2LayerParams reduceL2 = 1255; - ReduceMaxLayerParams reduceMax = 1260; - ReduceMinLayerParams reduceMin = 1265; - ReduceSumLayerParams reduceSum = 1270; - ReduceProdLayerParams reduceProd = 1275; - ReduceMeanLayerParams reduceMean = 1280; - ReduceLogSumLayerParams reduceLogSum = 1285; - ReduceSumSquareLayerParams reduceSumSquare = 1290; - ReduceLogSumExpLayerParams reduceLogSumExp = 1295; - - // Masking / Selection Layers - WhereNonZeroLayerParams whereNonZero = 1313; - MatrixBandPartLayerParams matrixBandPart = 1315; - LowerTriangularLayerParams lowerTriangular = 1320; - UpperTriangularLayerParams upperTriangular = 
1325; - WhereBroadcastableLayerParams whereBroadcastable = 1330; - - // Normalization Layers - LayerNormalizationLayerParams layerNormalization = 1350; - - NonMaximumSuppressionLayerParams NonMaximumSuppression = 1400; - - // Following layers are available only after Core ML Specification - // version >= 5 (iOS >= 14, macOS >= 11.0) - OneHotLayerParams oneHot = 1450; - CumSumLayerParams cumSum = 1455; - ClampedReLULayerParams clampedReLU = 1460; - ArgSortLayerParams argSort = 1461; - Pooling3DLayerParams pooling3d = 1465; - GlobalPooling3DLayerParams globalPooling3d = 1466; - SliceBySizeLayerParams sliceBySize = 1470; - Convolution3DLayerParams convolution3d = 1471; - - } - -} - -/** - * Branching Layer - * - * A layer that provides the functionality of branching or an If-Else block. - * - * Must have 1 input. There are no outputs as the execution is transferred to either the - * if or the else branch based on the value of the input. - * - * Input is the condition predicate. Must be a scalar (length 1 tensor). - * - */ -message BranchLayerParams { - - /** - * execute this graph if the absolute value of the input Tensor is greater than 1e-6 - * This must be present. - */ - NeuralNetwork ifBranch = 1; - /** - * execute this graph if the absolute value of the input Tensor is less than 1e-6 - * This is optional. - */ - NeuralNetwork elseBranch = 2; - -} - -/** - * Loop Layer - * - * A layer that provides the functionality of a "for" loop or a "while" loop. - * - * There are either no inputs or 1 input. When an input is present, it corresponds to the maximum loop count, - * in that case the value of the "maxLoopIterations" field is ignored. Input must be a scalar. - * (For description below, maxLoopIterations is assumed to be the value of the input, when its present) - * - * No outputs are produced. Blobs produced by the condition or the body network are visible in the scope of the overall network. - * - * "conditionNetwork" must produce a tensor with the name specified in the "conditionVar" field. - * - * There are 3 possible cases for determining the termination condition: - * - * Case 1: - * - * If there is no "conditionNetwork", in this case the layer corresponds to a pure for loop, which is run "maxLoopIterations" number of times. - * Equivalent pseudo-code: - * - * for loopIterator = 0 : maxLoopIterations - * bodyNetwork() - * - * - * Case 2: - * - * "conditionNetwork" is present, and "maxLoopIterations" is 0 and there is no input, - * in this case the layer corresponds to a while loop. Equivalent pseudo-code: - * - * conditionVar = conditionNetwork() - * while conditionVar: - * bodyNetwork() - * conditionVar = conditionNetwork() - * - * - * Case 3: - * - * "conditionNetwork" is provided, and "maxLoopIterations" is positive or there is an input, - * in this case the layer corresponds to a while loop with a joint condition. Equivalent pseudo-code: - * - * loopIterator = 0 - * conditionVar = conditionNetwork() - * while (conditionVar and loopIterator < maxLoopIterations): - * bodyNetwork() - * loopIterator = loopIterator + 1 - * conditionVar = conditionNetwork() - * - */ -message LoopLayerParams { - - /** - * maximum number of iterations. Ignored if input is present. - */ - uint64 maxLoopIterations = 1; - /** - * This field provides the name of the tensor which is produced by the conditionNetwork - * and whose value is checked to start/continue/terminate the loop. Value close to 0.0f is treated as False. - * This field is optional. 
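The three LoopLayerParams termination cases above reduce to one control-flow pattern. A plain-Python sketch under the assumption that bodyNetwork and conditionNetwork can be modeled as callables (run_loop is a hypothetical name, not a Core ML runtime API):

    def run_loop(body_network, condition_network=None, max_loop_iterations=0):
        if condition_network is None:
            # Case 1: plain "for" loop over maxLoopIterations
            for _ in range(max_loop_iterations):
                body_network()
            return
        loop_iterator = 0
        condition_var = condition_network()
        # Case 2 (max == 0, no input) is a pure "while"; case 3 adds the joint bound
        while condition_var and (max_loop_iterations == 0 or loop_iterator < max_loop_iterations):
            body_network()
            loop_iterator += 1
            condition_var = condition_network()

    counter = {"i": 0}
    run_loop(lambda: counter.__setitem__("i", counter["i"] + 1),
             condition_network=lambda: counter["i"] < 5,
             max_loop_iterations=3)
    print(counter["i"])  # 3: the joint condition stops at maxLoopIterations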
- * Must be a non empty string if and only if "conditionNetwork" is present. - */ - string conditionVar = 2; - /** - * Must generate a tensor with the name provided in the "conditionVar" field. - * This field is optional. - * Must be present if and only if "conditionVar" field is a non empty string. - */ - NeuralNetwork conditionNetwork = 3; - /** - * Body of the loop. - * This field must be present. - */ - NeuralNetwork bodyNetwork = 4; - -} - -/** - * Loop break Layer - * - * Terminate the loop that has this layer. - * If present, it should always reside in the "bodyNetwork" of the loop layer - * - * No inputs/outputs - * - */ -message LoopBreakLayerParams { - -} - -/** - * Loop Continue Layer - * - * Stop the current loop iteration and continue on the next iteration. - * If present, it should always reside in the "bodyNetwork" of the loop layer - * - * No inputs/outputs - * - */ -message LoopContinueLayerParams { - -} - -/** - * Copy Layer - * - * A layer that copies its input tensor to the output tensor. - * Must have 1 input and 1 output, with distinct names. - * This is the only layer that is allowed to re-generate an output that is already present in the neural network prior to this layer, - * in which case it will overwrite the output tensor. - * - */ -message CopyLayerParams { - -} - -/** - * GreaterThan Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise greater than operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 > x2 - * or - * y = x1 > alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message GreaterThanLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * GreaterEqual Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise greater equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 >= x2 - * or - * y = x1 >= alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message GreaterEqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * LessThan Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise less than operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 < x2 - * or - * y = x1 < alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message LessThanLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * LessEqual Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise less equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 <= x2 - * or - * y = x1 <= alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message LessEqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 2; - -} - -/** - * Equal Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 == x2 - * or - * y = x1 == alpha, if only one input is provided - * - * Broadcasting is supported. 
- * - */ -message EqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 1; - -} - -/** - * NotEqual Layer - * - * Either 1 or 2 inputs. - * Produces 1 output. - * Perform elementwise not equal operation. - * - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = x1 != x2 - * or - * y = x1 != alpha, if only one input is provided - * - * Broadcasting is supported. - * - */ -message NotEqualLayerParams { - - /** - * Compare to the scalar value provided here if there is 1 input - */ - float alpha = 1; - -} - -/** - * LogicalAnd Layer - * - * Must have 2 inputs, produces 1 output. - * Perform elementwise logical AND operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = AND(x1, x2) - * - * Broadcasting is supported. - * - */ -message LogicalAndLayerParams { - -} - -/** - * LogicalOr Layer - * - * Must have 2 inputs, produces 1 output. - * Perform elementwise logical OR operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = OR(x1, x2) - * - * Broadcasting is supported. - * - */ -message LogicalOrLayerParams { - -} - -/** - * LogicalXor Layer - * - * Must have 2 inputs, produces 1 output. - * Perform elementwise logical XOR operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = XOR(x1, x2) - * - * Broadcasting is supported. - * - */ -message LogicalXorLayerParams { - -} - -/** - * LogicalNot Layer - * - * Must have 1 input, produces 1 output. - * Perform elementwise logical NOT operation. - * - * Input is considered False if equal to 0.0f otherwise True. - * Output is 1.0f if the condition is true otherwise 0.0f. - * - * .. code:: - * - * y = NOT(x) - * - * - */ -message LogicalNotLayerParams { - -} - -/// Border Amounts -/// -------------- - -/** - * Specifies the amount of spatial border to be either padded or cropped. - * - * For padding: - * - * .. code:: - * - * H_out = borderAmounts[0].startEdgeSize + H_in + borderAmounts[0].endEdgeSize - * W_out = borderAmounts[1].startEdgeSize + W_in + borderAmounts[1].endEdgeSize - * - * topPaddingAmount == Height startEdgeSize - * bottomPaddingAmount == Height endEdgeSize - * leftPaddingAmount == Width startEdgeSize - * rightPaddingAmount == Width endEdgeSize - * - * For cropping: - * - * .. code:: - * - * H_out = (-borderAmounts[0].startEdgeSize) + H_in + (-borderAmounts[0].endEdgeSize) - * W_out = (-borderAmounts[1].startEdgeSize) + W_in + (-borderAmounts[1].endEdgeSize) - * - * topCropAmount == Height startEdgeSize - * bottomCropAmount == Height endEdgeSize - * leftCropAmount == Width startEdgeSize - * rightCropAmount == Width endEdgeSize - */ -message BorderAmounts { - - message EdgeSizes { - /** - * The amount to be padded or cropped from the beginning. - */ - uint64 startEdgeSize = 1; - - /** - * The amount to be padded or cropped from the end. - */ - uint64 endEdgeSize = 2; - } - - /** - * The border amounts. - * This must be length 2 in the order ``[H, W]``. - */ - repeated EdgeSizes borderAmounts = 10; - -} - -/** - * Specifies the type of padding to be used with Convolution/Deconvolution and Pooling layers. 
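The BorderAmounts message above uses the same (startEdgeSize, endEdgeSize) pairs for both padding (added to the spatial dims) and cropping (subtracted). A small sketch of that arithmetic, with a hypothetical helper name:

    def bordered_shape(h_in, w_in, border_amounts, crop=False):
        # border_amounts is [(startEdgeSize, endEdgeSize)] in [H, W] order
        sign = -1 if crop else 1
        (h_start, h_end), (w_start, w_end) = border_amounts
        return (h_in + sign * (h_start + h_end),
                w_in + sign * (w_start + w_end))

    print(bordered_shape(32, 32, [(2, 2), (1, 1)]))             # padding  -> (36, 34)
    print(bordered_shape(32, 32, [(2, 2), (1, 1)], crop=True))  # cropping -> (28, 30)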
- * After padding, input spatial shape: ``[H_in, W_in]``, gets modified to the - * output spatial shape ``[H_out, W_out]``. - * - * .. code:: - * - * topPaddingAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize - * bottomPaddingAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize - * leftPaddingAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize - * rightPaddingAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize - * - * With Convolution or Pooling: - * - * .. code:: - * - * H_out = int_division_round_down((H_in + topPaddingAmount + bottomPaddingAmount - KernelSize[0]),stride[0]) + 1 - * - * which is same as: - * - * .. code:: - * - * H_out = int_division_round_up((H_in + topPaddingAmount + bottomPaddingAmount - KernelSize[0] + 1),stride[0]) - * - * With Deconvolution: - * - * .. code:: - * - * H_out = (H_in-1) * stride[0] + kernelSize[0] - (topPaddingAmount + bottomPaddingAmount) - * - * - * The equivalent expressions hold true for ``W_out`` as well. - * - * - * By default, the values of ``paddingAmounts`` are set to ``0``, - * which results in a "true" valid padding. - * If non-zero values are provided for ``paddingAmounts``, - * "valid" convolution/pooling is performed within the spatially expanded input. - * - */ -message ValidPadding { - - BorderAmounts paddingAmounts = 1; - -} - -/** - * Specifies the type of padding to be used with Convolution/Deconvolution and pooling layers. - * After padding, input spatial shape: ``[H_in, W_in]``, gets modified to the - * output spatial shape ``[H_out, W_out]``. - * With Convolution or pooling: - * - * .. code:: - * - * H_out = int_division_round_up(H_in,stride[0]) - * W_out = int_division_round_up(W_in,stride[1]) - * - * This is achieved by using the following padding amounts: - * - * .. code:: - * - * totalPaddingHeight = max(0,(H_out-1) * stride[0] + KernelSize[0] - Hin) - * totalPaddingWidth = max(0,(W_out-1) * stride[1] + KernelSize[1] - Win) - * - * There are two modes of asymmetry: - * ``BOTTOM_RIGHT_HEAVY``, and ``TOP_LEFT_HEAVY``. - * - * If the mode is ``BOTTOM_RIGHT_HEAVY``: - * - * .. code:: - * - * topPaddingAmount = floor(totalPaddingHeight / 2) - * bottomPaddingAmount = totalPaddingHeight - topPaddingAmount - * leftPaddingAmount = floor(totalPaddingWidth / 2) - * rightPaddingAmount = totalPaddingWidth - leftPaddingAmount - * - * If the mode is ``TOP_LEFT_HEAVY``: - * - * .. code:: - * - * bottomPaddingAmount = floor(totalPaddingHeight / 2) - * topPaddingAmount = totalPaddingHeight - bottomPaddingAmount - * rightPaddingAmount = floor(totalPaddingWidth / 2) - * leftPaddingAmount = totalPaddingWidth - rightPaddingAmount - * - * - * With Deconvolution: - * - * .. code:: - * - * H_out = H_in * stride[0] - * W_out = W_in * stride[1] - */ -message SamePadding { - - enum SamePaddingMode { - - BOTTOM_RIGHT_HEAVY = 0; - TOP_LEFT_HEAVY = 1; - - } - SamePaddingMode asymmetryMode = 1; - -} - -/** - * Specifies how grid points are sampled from an interval. - * Without the loss of generality, assume the interval to be [0, X-1] from which N points are to be sampled. - * Here X may correspond to an input image's height or width. - * All the methods can be expressed in terms of numpy's linspace function, along with the constraint that grid points have to lie in the interval [0, X-1]. - * Note: numpy.linspace(start = start, end = end, num = N, endpoint = True) corresponds to sampling - * N points uniformly from the interval [start, end], endpoints included. 
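The SamePadding rules above fix the output size at ceil(input / stride) and derive the total padding from it, with the asymmetry mode deciding which side receives the extra element when the total is odd. A one-dimensional sketch of that computation (same_padding_1d is an illustrative name only):

    import math

    def same_padding_1d(in_size, kernel, stride, bottom_right_heavy=True):
        out_size = math.ceil(in_size / stride)
        total = max(0, (out_size - 1) * stride + kernel - in_size)
        if bottom_right_heavy:
            before = total // 2          # top / left
            after = total - before       # bottom / right gets the extra element
        else:
            after = total // 2
            before = total - after
        return out_size, before, after

    print(same_padding_1d(in_size=7, kernel=3, stride=2))  # (4, 1, 1)
    print(same_padding_1d(in_size=8, kernel=3, stride=2))  # (4, 0, 1)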
- * The methods vary in how the ``start`` and ``end`` values are computed. - */ -message SamplingMode { - - enum Method { - - /** - * start = 0, end = X-1 - * grid points = numpy.linspace(start, end) - */ - STRICT_ALIGN_ENDPOINTS_MODE = 0; - - /** - * if N == 1: start = end = (X-1)/2 - * otherwise, start = 0, end = X-1 - * grid points = numpy.linspace(start, end) - */ - ALIGN_ENDPOINTS_MODE = 1; - - /** - * start = 0, end = X - X/N - * grid points = min(X-1, numpy.linspace(start, end)) - * This is same as the mode used in the upsample layer in this specification, when used with bilinear interpolation. In that case N/X = upsample ratio. - */ - UPSAMPLE_MODE = 2; - - /** - * spacing = max(1, X-1)/N - * start = 0.5 * spacing - * end = start + (N-1) * spacing - * grid points = min(X-1, numpy.linspace(start, end)) - */ - ROI_ALIGN_MODE = 3; - - } - - Method samplingMethod = 1; - -} - -/** - * Specifies the convention used to specify four bounding box coordinates for an image of size (Height, Width). - * The (0,0) coordinate corresponds to the top-left corner of the image. - */ -message BoxCoordinatesMode { - - enum Coordinates { - - /** - * [h_start, w_start, h_end, w_end] - */ - CORNERS_HEIGHT_FIRST = 0; - - /** - * [w_start, h_start, w_end, h_end] - */ - CORNERS_WIDTH_FIRST = 1; - - /** - * [h_center, w_center, box_height, box_width] - */ - CENTER_SIZE_HEIGHT_FIRST = 2; - - /** - * [w_center, h_center, box_width, box_height] - */ - CENTER_SIZE_WIDTH_FIRST = 3; - - } - - Coordinates boxMode = 1; - -} - -/** - * Weights for layer parameters. - * Weights are stored as repeated floating point numbers - * using row-major ordering - * and can represent 1-, 2-, 3-, or 4-dimensional data. - */ -message WeightParams { - - /** - * Values specified in single / float / FP32 precision. - */ - repeated float floatValue = 1; - - /** - * Values in 16-bit half precision floating point. - */ - bytes float16Value = 2; - - /** - * Raw value specification for quantized lower precisions. - * - * This field is interpreted as uintN, where N is the number of bits in quantization. - * E.g. if n=8, the field is interpreted as an array of UINT8. - * Use this field for quantized parameters unless specifically noted to use - * int8RawValue. - */ - bytes rawValue = 30; - - /** - * Field to be used if int8DynamicQuantize is set in the parent layer. - * Cannot be set if rawValue is also set. - * The values in this field are interpreted as INT8. - * - * If this field is set, following conditions must hold true: - * * QuantizationType == LinearQuantizationParams, such that - * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" - */ - bytes int8RawValue = 31; - - /** - * Quantization related parameters. - */ - QuantizationParams quantization = 40; - - bool isUpdatable = 50; - -} - -/** - * Quantization parameters. - */ -message QuantizationParams { - - uint64 numberOfBits = 1; - oneof QuantizationType { - LinearQuantizationParams linearQuantization = 101; - LookUpTableQuantizationParams lookupTableQuantization = 102; - } - -} - -message LinearQuantizationParams { - - /** - * Stores scale and bias values corresponding to the quantized weights. - * Must be an array of 1 element, or an array of C elements, where C - * is number of output channels. For recurrent layers it is equal to - * the output vector size. 
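The SamplingMode methods above are all defined in terms of numpy.linspace clamped to [0, X-1]. A sketch that spells out the start/end values per method, assuming N is the number of grid points passed to linspace:

    import numpy as np

    def grid_points(X, N, method):
        if method == "STRICT_ALIGN_ENDPOINTS_MODE":
            return np.linspace(0, X - 1, N)
        if method == "ALIGN_ENDPOINTS_MODE":
            return np.full(N, (X - 1) / 2.0) if N == 1 else np.linspace(0, X - 1, N)
        if method == "UPSAMPLE_MODE":
            return np.minimum(X - 1, np.linspace(0, X - X / N, N))
        if method == "ROI_ALIGN_MODE":
            spacing = max(1, X - 1) / N
            start = 0.5 * spacing
            return np.minimum(X - 1, np.linspace(start, start + (N - 1) * spacing, N))
        raise ValueError(method)

    print(grid_points(8, 4, "UPSAMPLE_MODE"))  # [0. 2. 4. 6.]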
- * - * Relationship between quantized weights, unquantized weights, scale and bias: - * - * W_unquantized = W_quantized * scale + bias - * - */ - repeated float scale = 1; - repeated float bias = 2; - -} - -message LookUpTableQuantizationParams { - - /* Stores look-up table quantization values. Must be an array of - (2^numberOfBits) Elements. - */ - repeated float floatValue = 1; - -} - -/// Layers -/// ------ - -/** - * A layer that performs spatial convolution or deconvolution. - * - * .. code:: - * - * y = ConvolutionLayer(x) - * - * Requires 1 or 2 inputs and produces 1 output. - * - * Input - * First Input: - * A blob with rank greater than or equal to 4. - * Rank 4 blob represents [Batch, channels, height, width]. - * For ranks greater than 4, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * From Core ML specification version 4 onwards (iOS >= 13, macOS >= 10.15). - * convolution layer can have 2 inputs, in which case the second input is - * the blob representing the weights. This is allowed when "isDeconvolution" = False. - * The weight blob should have shape - * ``[outputChannels, kernelChannels, kernelHeight, kernelWidth]``, - * where kernelChannels == inputChannels / nGroups. - * - * Output - * Rank is same as the input. e.g.: for rank 4 input, output shape is [B, C_out, H_out, W_out] - * - * - * If ``dilationFactor`` is not 1, effective kernel size is - * modified as follows: - * - * .. code:: - * - * KernelSize[0] <-- (kernelSize[0]-1) * dilationFactor[0] + 1 - * KernelSize[1] <-- (kernelSize[1]-1) * dilationFactor[1] + 1 - * - * Type of padding can be ``valid`` or ``same``. Output spatial dimensions depend on the - * the type of padding. For details, refer to the descriptions of the messages "ValidPadding" - * and "SamePadding". Padded values are all zeros. - * - * For Deconvolution, ``ConvolutionPaddingType`` (``valid`` or ``same``) is ignored when ``outputShape`` is set. - * - * - */ -message ConvolutionLayerParams { - - /** - * The number of kernels. - * Same as ``C_out`` used in the layer description. - */ - uint64 outputChannels = 1; - - /** - * Channel dimension of the kernels. - * Must be equal to ``inputChannels / nGroups``, if isDeconvolution == False - * Must be equal to ``inputChannels``, if isDeconvolution == True - */ - uint64 kernelChannels = 2; - - /** - * Group convolution, i.e. weight reuse along channel axis. - * Input and kernels are divided into g groups - * and convolution / deconvolution is applied within the groups independently. - * If not set or 0, it is set to the default value 1. - */ - uint64 nGroups = 10; - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[3, 3]`` is used. - */ - repeated uint64 kernelSize = 20; - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 stride = 30; - - /** - * Must be length 2 in order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - * It is ignored if ``isDeconvolution == true``. - */ - repeated uint64 dilationFactor = 40; - - /** - * The type of padding. - */ - oneof ConvolutionPaddingType { - ValidPadding valid = 50; - SamePadding same = 51; - } - - /** - * Flag to specify whether it is a deconvolution layer. - */ - bool isDeconvolution = 60; - - /** - * Flag to specify whether a bias is to be added or not. - */ - bool hasBias = 70; - - /** - * Weights associated with this layer. 
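The LinearQuantizationParams relationship above, W_unquantized = W_quantized * scale + bias with scale and bias either a single value or one per output channel, can be sketched as follows (dequantize_linear is a hypothetical helper, not part of coremltools):

    def dequantize_linear(quantized_rows, scale, bias=None):
        # quantized_rows: one list of uintN values per output channel
        out = []
        for c, row in enumerate(quantized_rows):
            s = scale[c] if len(scale) > 1 else scale[0]
            b = 0.0 if not bias else (bias[c] if len(bias) > 1 else bias[0])
            out.append([q * s + b for q in row])
        return out

    print(dequantize_linear([[0, 127, 255]], scale=[0.02], bias=[-2.55]))
    # about [[-2.55, -0.01, 2.55]]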
- * If convolution (``isDeconvolution == false``), weights have the shape - * ``[outputChannels, kernelChannels, kernelHeight, kernelWidth]``, where kernelChannels == inputChannels / nGroups - * If deconvolution (``isDeconvolution == true``) weights have the shape - * ``[kernelChannels, outputChannels / nGroups, kernelHeight, kernelWidth]``, where kernelChannels == inputChannels - */ - WeightParams weights = 90; - WeightParams bias = 91; /// Must be of size [outputChannels]. - - /** - * The output shape, which has length 2 ``[H_out, W_out]``. - * This is used only for deconvolution (``isDeconvolution == true``). - * If not set, the deconvolution output shape is calculated - * based on ``ConvolutionPaddingType``. - */ - repeated uint64 outputShape = 100; - -} - -/** - * A layer that performs a 3-dimensional convolution. - * - * .. code:: - * - * y = Convolution3DLayer(x) - * - * Input - * A blob of rank 5. - * The input blob's shape should be ``[batch, channels, depth, height, width]``. - * - * Fields - * The bias field, if set, should have shape of ``[channelsOut]``. - * - * Output - * A blob of rank 5. - * The output blob's shape is ``[batch, channelsOut, depthOut, heightOut, widthOut]``. - * - * Type of padding can be ``custom``, ``valid``, or ``same``. Padded values are all zeros. - * Output spatial dimensions depend on the the type of padding. For details, refer to the - * descriptions of the ``PaddingType`` field of this ``Convolution3DLayerParams`` message. - * - * Example - * For example, given an input of size ``[1, 3, 3, 8, 8]``, a stride of 2 in each dimension, - * a kernel of 3 in each dimension, 2 output channels, and ``same`` padding, this layer will - * compute the total padding applied in the depth, height, and width dimensions to be 2, 1, and 1, - * respectively. The depth padding is even and will be applied equally to both sides of the depth - * dimension. Since the height and width padding values are odd, they'll be applied to the - * bottom/right of the height/width dimensions. Thus, the padding applied to the input will be - * ``[1, 1, 0, 1, 0, 1]`` (front, back, top, bottom, left, right). Finally, the output produced - * will have size ``[1, 2, 2, 4, 4]``. - * - */ -message Convolution3DLayerParams { - - /** - * The number of channels in the output (channelsOut). Must be a positive integer. - */ - int32 outputChannels = 1; - - /** - * The number of channels in the input (channels). Must be a positive integer. - */ - int32 inputChannels = 2; - - /** - * Group convolution, i.e., weight reuse along the channel axis. - * It must evenly divide both the number of input and output channels and be at most the number - * of input channels (a depthwise convolution). - * Input and kernels are divided into g groups and convolution is applied within the groups - * independently. - */ - int32 nGroups = 10; - - /* Depth of the convolution kernel. Must be a positive integer. - */ - int32 kernelDepth = 20; - - /* Height of the convolution kernel. Must be a positive integer. - */ - int32 kernelHeight = 21; - - /* Width of the convolution kernel. Must be a positive integer. - */ - int32 kernelWidth = 22; - - /* Stride along the depth direction. Must be a positive integer. - */ - int32 strideDepth = 31; - - /* Stride along the height direction. Must be a positive integer. - */ - int32 strideHeight = 32; - - /* Stride along the width direction. Must be a positive integer. - */ - int32 strideWidth = 33; - - /* Dilation along the depth direction. Must be a positive integer. 
- */ - int32 dilationDepth = 40; - - /* Dilation along the height direction. Must be a positive integer. - */ - int32 dilationHeight = 41; - - /* Dilation along the width direction. Must be a positive integer. - */ - int32 dilationWidth = 42; - - /** - * Flag to specify whether a bias is to be added or not. - * If false, then no bias is added. - */ - bool hasBias = 50; - - /** - * Weights associated with this layer. - * Weights have the shape - * if deconvolution == False - * ``[outputChannels, kernelChannels, kernelDepth, kernelHeight, kernelWidth]``, where - * kernelChannels == inputChannels / nGroups - * else if deconvolution == True - * ``[outputChannels / nGroups, kernelChannels, kernelDepth, kernelHeight, kernelWidth]``, where - */ - WeightParams weights = 60; - - /** - * Must be of size ``[outputChannels]``. - */ - WeightParams bias = 61; - - - /** - * The type of padding. - * All padding types pad the input shape with zeros. - * CUSTOM padding will add the custom padding values specified below to their respective - * dimensions, e.g., `customPaddingFront` number of zeros will be added to one side of the - * input's depth dimension and `customPaddingBack` number of zeros will be added to the other - * side of the input's depth dimension. - * VALID padding adds no padding to any dimension. In this case, the last convolution along - * each dimension will be dropped if the input dimension and the kernel size, stride, and - * dilation do not match. - * SAME padding adds enough padding to each dimension such that the output of the convolution - * has size ``Ceiling(inputShape / stride)``. Padding is added evenly to both sides of each - * dimension unless the total padding to add is odd, in which case it is added to the - * back/bottom/right side of the respective dimension. For example, if the total padding needed - * in the depth dimension is 3, 1 zero will be added to the front side of the depth dimension - * and 2 zeros will be added to the back side. - */ - enum PaddingType { - CUSTOM = 0; - VALID = 1; - SAME = 2; - } - PaddingType paddingType = 70; - - /* Padding before the input in the depth direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingFront = 80; - - /* Padding after the input in the depth direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingBack = 81; - - /* Padding before the input in the height direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingTop = 82; - - /* Padding after the input in the height direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingBottom = 83; - - /* Padding before the input in the width direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingLeft = 84; - - /* Padding after the input in the width direction. Must be zero or a positive integer. - * Used when the `PaddingType` is `CustomPadding`, otherwise ignored by other padding types. - */ - int32 customPaddingRight = 85; - - /* Flag to specify if this is Convolution Transpose or not. 
- */ - bool isDeconvolution = 86; - - /* - * The output shape, which has length 3 ``[D_out, H_out, W_out]``. - * This is used only for deconvolution (``isDeconvolution == true``). - * If not set, the deconvolution output shape is calculated - * based on ``PaddingType``. - */ - repeated uint64 outputShape = 87; - -} - -/** - * A layer that performs a matrix-vector or matrix-matrix product. - * This is equivalent to a fully-connected, or dense layer. - * The weight parameters correspond to a matrix of dimensions (inputChannels, outputChannels) i.e. (C_in, C_out) - * - * .. code:: - * - * y = InnerProductLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Input can have rank 1 to rank 5. This is how it is reshaped in to the matrix (for rank > 1): - * rank 1 (x1) : in this case, the layer corresponds to a matrix-vector product. x1 must be equal to C_in - * rank 2 (x1, x2): x2 must be equal to C_in - * rank 3 (x1, x2, x3) --> (x1 * x2, x3). x3 must be equal to C_in - * rank 4 (x1, x2, x3, x4) ---> (x1, x2 * x3 * x4). x2 * x3 * x4 must be equal to C_in - * rank 5 (x1, x2, x3, x4, x5) ---> (x1 * x2, x3 * x4 * x5). x3 * x4 * x5 must be equal to C_in - * - * Output - * Output rank is same as the input rank - * rank 1: (C_out) - * rank 2: (x1, C_out) - * rank 3: (x1, x2, C_out) - * rank 4: (x1, C_out, 1, 1) - * rank 5: (x1, x2, C_out, 1, 1) - * - */ -message InnerProductLayerParams { - - uint64 inputChannels = 1; /// Input size: C_in. - uint64 outputChannels = 2; /// Output size: C_out. - - bool hasBias = 10; /// Whether a bias is added or not. - - WeightParams weights = 20; /// Weight matrix [C_out, C_in]. - WeightParams bias = 21; /// Bias vector [C_out]. - - /** - * If set, this layer, at runtime, quantizes the floating point input blob to int8 before applying an - * inner product using INT8 weight matrix parameters, as provided in weights->int8RawValue. The - * result is then dequantized. - * Requires: - * * hasBias == false - * * QuantizationType == LinearQuantizationParams, such that - * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" - * * numberOfBits == 8 - * * weights->rawValue_size to be empty - */ - bool int8DynamicQuantize = 22; - -} - -/** - * A layer that performs a matrix lookup and optionally adds a bias. - * The weights matrix is stored with dimensions [outputChannels, inputDim]. - * - * .. code:: - * - * y = EmbeddingLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Input values must be in the range ``[0, inputDim - 1]``. - * - * Input must have rank equal to 4 or 5, such that the last 3 dimensions are all 1. - * rank 4: shape (x1, 1, 1, 1). x1 is effectively the batch/sequence length. - * rank 5: shape (x1, x2 , 1, 1, 1). x1 * x2 is effectively the combined batch/sequence length. - * - * Output - * Output rank is same as the input rank. Please see input description above. - * rank 4: shape (x1, outputChannels, 1, 1) - * rank 5: shape (x1, x2, outputChannels, 1, 1) - * - */ -message EmbeddingLayerParams { - - uint64 inputDim = 1; /// Size of the input dictionary. - uint64 outputChannels = 2; /// Size of the output vectors. - - bool hasBias = 10; /// Whether a bias is added or not. - - WeightParams weights = 20; /// 2-D weights of dimensions [outputChannels, inputDim]. - WeightParams bias = 21; /// Bias of size [outputChannels]. - -} - -/** - * A layer that performs a matrix lookup and optionally adds a bias. - * The weights matrix is stored with dimensions [embeddingSize, vocabSize]. 
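The InnerProductLayerParams comment above lists how inputs of rank 1 through 5 are viewed as a 2-D matrix whose trailing dimension must equal C_in. A sketch of that reshape rule (the helper name is illustrative):

    def inner_product_matrix_view(shape):
        # returns (rows, cols); cols must equal C_in of the [C_out, C_in] weight
        rank = len(shape)
        if rank == 1:
            return (1, shape[0])                              # matrix-vector product
        if rank == 2:
            return (shape[0], shape[1])
        if rank == 3:
            return (shape[0] * shape[1], shape[2])
        if rank == 4:
            return (shape[0], shape[1] * shape[2] * shape[3])
        if rank == 5:
            return (shape[0] * shape[1], shape[2] * shape[3] * shape[4])
        raise ValueError("rank must be 1..5")

    print(inner_product_matrix_view([2, 3, 4, 5]))  # (2, 60); 60 must equal C_in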
- * - * .. code:: - * - * y = EmbeddingNDLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Input values must be in the range ``[0, vocabSize - 1]``. - * Input must have rank at least 2. The last dimension must always be 1. - * rank 2: shape (x1, 1). x1 is the batch/sequence length. - * rank 3: shape (x1, x2, 1). x1 * x2 is effectively the combined batch/sequence length. - * rank 4: shape (x1, x2, x3, 1). x1 * x2 * x2 is effectively the combined batch/sequence length. - * rank 5: shape (x1, x2 , x3, x4, 1). x1 * x2 * x3 * x4 is effectively the combined batch/sequence length. - * - * Output - * Output rank is same as the input rank. Please see input description above. - * rank 2: shape (x1, embeddingSize) - * rank 3: shape (x1, x2, embeddingSize) - * rank 4: shape (x1, x2, x3, embeddingSize) - * rank 5: shape (x1, x2, x3, x4, embeddingSize) - * - */ -message EmbeddingNDLayerParams { - - uint64 vocabSize = 1; /// Size of the input dictionary. - uint64 embeddingSize = 2; /// Size of the output vectors. - bool hasBias = 3; /// Whether a bias is added or not. - WeightParams weights = 20; /// 2-D weights of dimensions [embeddingSize, vocabSize]. - WeightParams bias = 21; /// Bias of size [embeddingSize]. - -} - -/** - * A layer that performs batch normalization, - * which is performed along axis = -3, - * and repeated along the other axes, if present. - * - * .. code:: - * - * y = BatchnormLayer(x) - * - * Requires 1 input and produces 1 output. - * - * This operation is described by the following formula: - * - * .. math:: - * y_i = \gamma_i \dfrac{ (x_i - \mu_i)}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i \;,\;i=1,....,C - * - * Input - * A blob with rank greater than equal to 3. - * Example: Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * A blob with the same shape as the input. - */ -message BatchnormLayerParams { - - uint64 channels = 1; /// Size of the channel dimension in the input. - - /** - * If ``computeMeanVar == true``, - * the mean and variance are calculated from either - * the single input instance, if ``instanceNormalization == true``, - * or the whole batch, if ``instanceNormalization = false``. - * and the values provided in parameters "mean" and "variance" are ignored. - */ - bool computeMeanVar = 5; - bool instanceNormalization = 6; - - /** - * A small constant to avoid division by 0 while normalizing by variance. - * Defaults to ``1e-5`` if not set or set to ``0``. - */ - float epsilon = 10; - - WeightParams gamma = 15; /// Parameter of length [channels] - WeightParams beta = 16; /// Parameter of length [channels] - WeightParams mean = 17; /// Parameter of length [channels] - WeightParams variance = 18; /// Parameter of length [channels] - -} - -/** - * A spatial pooling layer. - * - * .. code:: - * - * y = PoolingLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 4. - * Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 4, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Rank is same as the input. 
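The batch-norm formula above normalizes each channel with its own mean, variance, gamma, and beta. A per-channel numeric sketch, assuming scalar parameters for a single channel:

    import math

    def batchnorm_channel(x, gamma, beta, mean, variance, epsilon=1e-5):
        # y_i = gamma * (x_i - mean) / sqrt(variance + epsilon) + beta
        inv_std = 1.0 / math.sqrt(variance + epsilon)
        return [gamma * (v - mean) * inv_std + beta for v in x]

    print(batchnorm_channel([1.0, 2.0, 3.0], gamma=1.0, beta=0.0, mean=2.0, variance=0.6667))
    # approximately [-1.22, 0.0, 1.22]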
e.g.: for rank 4 input, output shape is [B, C, H_out, W_out] - * - * Padding options are similar to ``ConvolutionLayerParams`` - * with the additional option of ``ValidCompletePadding`` (``includeLastPixel``), - * which ensures that the last application of the kernel - * always includes the last pixel of the input image, if there is padding. - * - * .. code:: - * - * H_out = ceil(float(H_in + 2 * paddingAmounts[0] - kernelSize[0])/float(Stride[0])) + 1 - * if (paddingAmounts[0] > 0 or paddingAmounts[1] > 0) - * if ((H_out - 1) * Stride >= H_in + paddingAmounts[0]) { - * H_out = H_out - 1 - * } - * } - * - * The equivalent expressions hold true for ``W_out`` as well. - * Only symmetric padding is supported with this option. - */ -message PoolingLayerParams { - - enum PoolingType { - - MAX = 0; - AVERAGE = 1; - L2 = 2; - - } - PoolingType type = 1; /// Type of pooling operation. - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[3, 3]`` is used. - */ - repeated uint64 kernelSize = 10; - - /** - * Must be length 2 in the order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 stride = 20; - - message ValidCompletePadding { - - /** - * Must be length 2 in order ``[H, W]``. - * If not set, value ``[0, 0]`` is used. - */ - repeated uint64 paddingAmounts = 10; - - } - - oneof PoolingPaddingType { - ValidPadding valid = 30; - SamePadding same = 31; - ValidCompletePadding includeLastPixel = 32; - } - - /** - * If true, padded values are excluded from the count (denominator) - * when computing average pooling. - */ - bool avgPoolExcludePadding = 50; - - /** - * If true, global pooling is performed. - * Kernel size is inferred from the input data spatial dimensions. - */ - bool globalPooling = 60; - -} - -/* - * A layer to pool three spatial dimensions - * - * Input - * A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * - * Output - * Rank is same as the input: A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * - * Requires 1 input and produces 1 output. - * - * For example, given an input of shape (1,1,2,3,3): - * +----+----+----+ - * / | 10 | 11 | 12 | - * / +----+----+----+ - * / | 13 | 14 | 15 | - * / +----+----+----+ - * / | 16 | 17 | 18 | - * / +----+----+----+ - * +----+----+----+ / - * | 1 | 2 | 3 | / - * +----+----+----+ / - * | 4 | 5 | 6 | / - * +----+----+----+ / - * | 7 | 8 | 9 | / - * +----+----+----+ - * - * And applying MAX pooling using: - * Kernel: 2x2x2 - * Stride: 1x1x1 - * Valid Padding - * We expect to get an output with shape: (1,1,1,2,2) and value: - * +----+----+ - * | 14 | 15 | - * +----+----+ - * | 17 | 18 | - * +----+----+ - */ -message Pooling3DLayerParams { - - enum PoolingType3D { - MAX = 0; - AVERAGE = 1; - } - - // Whether to use Max or Average - PoolingType3D type = 1; - - // Depth of the pooling region. - int32 kernelDepth = 2; - - // Height of the pooling region. - int32 kernelHeight = 3; - - // Width of the pooling region. - int32 kernelWidth = 4; - - // Stride along the depth direction - int32 strideDepth = 5; - - // Stride along the height direction - int32 strideHeight = 6; - - // Stride along the width direction - int32 strideWidth = 7; - - /** - * The type of padding. - * All padding types pad the input shape with zeros. 
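The ValidCompletePadding (includeLastPixel) output-size rule above is the ceiling-division formula plus a correction that drops a final window starting entirely inside the padding. A sketch of that rule in one dimension (hypothetical helper name):

    import math

    def include_last_pixel_out(h_in, kernel, stride, pad):
        h_out = math.ceil((h_in + 2 * pad - kernel) / stride) + 1
        # drop the last window if it would begin beyond the padded input
        if pad > 0 and (h_out - 1) * stride >= h_in + pad:
            h_out -= 1
        return h_out

    print(include_last_pixel_out(h_in=13, kernel=3, stride=2, pad=1))  # 7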
- * CUSTOM padding will add the custom padding values specified below to their respective - * dimensions, e.g., `customPaddingFront` number of zeros will be added to one side of the - * input's depth dimension and `customPaddingBack` number of zeros will be added to the other - * side of the input's depth dimension. - * VALID padding adds no padding to any dimension. In this case, the last pool along - * each dimension will be dropped if the input dimension and the kernel size, and stride do not match. - * SAME padding adds enough padding to each dimension such that the output - * has the same spatial dimensions as the input. Padding is added evenly to both - * sides of each dimension unless the total padding to add is odd, in which case the extra padding - * is added to the back/bottom/right side of the respective dimension. For example, if the the - * total horizontal padding is 3, then there will be 1 padding on the left, and 2 padding on the right. - */ - enum Pooling3DPaddingType { - CUSTOM = 0; - VALID = 1; - SAME = 2; - } - Pooling3DPaddingType paddingType = 15; - - // Padding before the input in the depth direction. - int32 customPaddingFront = 8; - - // Padding after the input in the depth direction. - int32 customPaddingBack = 9; - - // Padding before the input in the height direction. - int32 customPaddingTop = 10; - - // Padding after the input in the height direction. - int32 customPaddingBottom = 11; - - // Padding before the input in the width direction. - int32 customPaddingLeft = 12; - - // Padding after the input in the width direction. - int32 customPaddingRight = 13; - - // If true, exclude zeros from padding in Average pooling. Meaningless in Max Pooling. - bool countExcludePadding = 14; -} - -/* - * A layer to pool three spatial dimensions down to one value. - * This behaves like a special case of Pooling3DLayerParams in which - * the Kernel is the size of the input and there is no padding. - * - * Input - * A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * - * Output - * Rank is same as the input: A blob with rank equal to 5, representing [Batch, channels, depth, height, width]. - * Depth, height, and width of the output will always be 1. - * - * Requires 1 input and produces 1 output. - * - * For example, given an input of shape (1,1,2,3,3): - * +----+----+----+ - * / | 10 | 11 | 12 | - * / +----+----+----+ - * / | 13 | 14 | 15 | - * / +----+----+----+ - * / | 16 | 17 | 18 | - * / +----+----+----+ - * +----+----+----+ / - * | 1 | 2 | 3 | / - * +----+----+----+ / - * | 4 | 5 | 6 | / - * +----+----+----+ / - * | 7 | 8 | 9 | / - * +----+----+----+ - * - * And applying MAX global 3d pooling, we expect to get an output with shape: (1,1,1,1,1) and value: - * +----+ - * | 18 | - * +----+ - */ -message GlobalPooling3DLayerParams { - - enum GlobalPoolingType3D { - MAX = 0; - AVERAGE = 1; - } - - // Whether to use Max or Average - GlobalPoolingType3D type = 1; -} - -/** - * A layer that performs padding along spatial dimensions. - * - * .. code:: - * - * y = PaddingLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 2. - * e.g.: blob with shape ``[H_in, W_in]``. - * For ranks greater than 2, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch - * i.e. Padding is applied on last two dimensions. - * - * Output - * Same rank as the input. - * e.g.: blob with shape ``[H_out, W_out]``. - * - * Output dimensions are calculated as follows: - * - * .. 
code:: - * - * H_out = H_in + topPaddingAmount + bottomPaddingAmount - * W_out = W_in + leftPaddingAmount + rightPaddingAmount - * - * topPaddingAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize - * bottomPaddingAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize - * leftPaddingAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize - * rightPaddingAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize - * - * There are three types of padding: - * - * - ``PaddingConstant``, which fills a constant value at the border. - * - ``PaddingReflection``, which reflects the values at the border. - * - ``PaddingReplication``, which replicates the values at the border. - * - * Given the following input: - * - * .. code:: - * - * [1, 3, 4] : 1 2 3 4 - * 5 6 7 8 - * 9 10 11 12 - * - * Here is the output of applying the padding - * ``(top=2, left=2, bottom=0, right=0)`` - * with each of the supported types: - * - * - ``PaddingConstant`` (``value = 0``): - * .. code:: - * - * [1, 5, 6] : 0 0 0 0 0 0 - * 0 0 0 0 0 0 - * 0 0 1 2 3 4 - * 0 0 5 6 7 8 - * 0 0 9 10 11 12 - * - * - ``PaddingReflection``: - * .. code:: - * - * [1, 5, 6] : 11 10 9 10 11 12 - * 7 6 5 6 7 8 - * 3 2 1 2 3 4 - * 7 6 5 6 7 8 - * 11 10 9 10 11 12 - * - * - ``PaddingReplication``: - * .. code:: - * - * [1, 5, 6] : 1 1 1 2 3 4 - * 1 1 1 2 3 4 - * 1 1 1 2 3 4 - * 5 5 5 6 7 8 - * 9 9 9 10 11 12 - */ -message PaddingLayerParams { - - /** - * Fill a constant value in the padded region. - */ - message PaddingConstant { - float value = 1; - } - - /** - * Reflect the values at the border for padding. - */ - message PaddingReflection { - } - - /** - * Replicate the values at the border for padding. - */ - message PaddingReplication { - } - - oneof PaddingType { - PaddingConstant constant = 1; - PaddingReflection reflection = 2; - PaddingReplication replication = 3; - } - - BorderAmounts paddingAmounts = 10; /// Amounts to be padded to the input. - -} - -/** - * A layer that concatenates along the axis = -3 or -5. - * For general concatenation along any axis, see ConcatNDLayer. - * - * .. code:: - * - * y = ConcatLayer(x1,x2,....) - * - * Requires more than 1 input and produces 1 output. - * - * Input - * All input blobs must have same rank. - * If "sequenceConcat" = False, rank must be greater than equal to 3. In this case concatenation is along axis = -3 - * If "sequenceConcat" = True, rank must be greater than equal to 5. In this case concatenation is along axis = -5 - * - * Output - * Same rank as the input. - * - */ -message ConcatLayerParams { - - /** - * If true, concatenate along the axis = -5 instead of axis = -3. - */ - bool sequenceConcat = 100; - -} - -/** - * A layer that performs local response normalization (LRN). - * - * .. code:: - * - * y = LRNLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * Example: Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - * - * This layer is described by the following formula: - * - * .. math:: - * x_i \leftarrow \dfrac{x_i}{\left ( k + \dfrac{\alpha}{C} \sum_j x_j^2 \right )^\beta} - * - * where the summation is done over a ``(localSize, 1, 1)`` neighborhood --- - * that is, over a window "across" channels in 1x1 spatial neighborhoods. 
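- *
- * As a rough illustration (not part of the specification), the formula can
- * be sketched in Python/NumPy as follows, assuming the window is centered
- * on each channel and ``C`` in the formula denotes the number of channels:
- *
- * .. code::
- *
- *     import numpy as np
- *
- *     def lrn(x, local_size, alpha, beta, k):
- *         # x has shape [C, H, W]; each channel is normalized using a window
- *         # of `local_size` channels (1x1 spatially), clipped at the borders.
- *         C = x.shape[0]
- *         half = local_size // 2
- *         y = np.empty_like(x)
- *         for c in range(C):
- *             lo, hi = max(0, c - half), min(C, c + half + 1)
- *             sq_sum = np.sum(x[lo:hi] ** 2, axis=0)
- *             y[c] = x[c] / (k + (alpha / C) * sq_sum) ** beta
- *         return y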
- */ -message LRNLayerParams { - - float alpha = 1; - float beta = 2; - uint64 localSize = 3; /// Number of channels in the normalization window. - float k = 4; /// Defaults to 1 if not set or 0. Must be strictly positive. - -} - -/** - * Softmax Normalization Layer - * - * A layer that performs softmax normalization. - * Normalization is applied along axis = -3 or N-3 (where N is the rank of the input) - * For softmax layer that can operate on any axis, see SoftmaxNDLayer. - * - * - * .. code:: - * - * y = SoftmaxLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Must be a blob with rank >= 3. - * Output - * A blob with the same shape as the input. - * - * This layer is described by the following formula: - * - * .. math:: - * x_i \leftarrow \dfrac{e^{x_i}}{\sum_i{e^{x_i}}} - */ -message SoftmaxLayerParams { - -} - -/** - * A layer that uniformly splits across axis = -3 to produce a specified number of outputs. - * For general split operation along any axis, see SplitNDLayer. - * - * .. code:: - * - * (y1,y2,...yN) = SplitLayer(x), where N = nOutputs - * - * Requires 1 input and produces multiple outputs. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]`` - * Output - * ``nOutputs`` blobs each with same rank as the input. - * e.g.: For input that is of shape ``[C, H, W]``, output shapes will be ``[C/nOutputs, H, W]`` - */ -message SplitLayerParams { - - uint64 nOutputs = 1; /// The number of outputs. - -} - -/** - * A layer that performs elementwise addition. - * This layer has limited broadcasting support. For general broadcasting see AddBroadcastableLayer. - * - * .. code:: - * - * y = AddLayer(x1,x2,...) - * - * Requires 1 or more than 1 input and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] - * Output - * A blob with shape equal to the input blob. - * - * If only one input is provided, scalar addition is performed: - * - * .. math:: - * y = x + \alpha - * - */ -message AddLayerParams { - - /** - * Scalar to be added to the input. - * Only used if there is a single input. - */ - float alpha = 1; - -} - -/** - * A layer that performs elementwise multiplication. - * This layer has limited broadcasting support. For general broadcasting see MultiplyBroadcastableLayer. - * - * .. code:: - * - * y = MultiplyLayer(x1,x2,...) - * - * Requires 1 or more than 1 input and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] - * Output - * A blob with shape equal to the first input blob. - * - * If only one input is provided, scalar multiplication is performed: - * - * .. math:: - * y = \alpha x - * - */ -message MultiplyLayerParams { - - /** - * Scalar to be multiplied with the input. - * Only used if there is a single input. - */ - float alpha = 1; - -} - -/** - * A layer that applies a unary function. - * - * .. code:: - * - * y = UnaryFunctionLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with no rank constraints. - * Output - * A blob with the same shape as the input. - * - * The input is first modified by shifting and scaling: - * - * .. math:: - * x \leftarrow \text{scale} \cdot x + \text{shift} - */ -message UnaryFunctionLayerParams { - - /** - * A unary operator. 
- * - * The following functions are supported: - * - * ``SQRT`` - * .. math:: f(x) = \sqrt{x} - * - * ``RSQRT`` - * .. math:: f(x) = \dfrac{1}{\sqrt{x + \epsilon}} - * - * ``INVERSE`` - * .. math:: f(x) = \dfrac{1}{x + \epsilon} - * - * ``POWER`` - * .. math:: f(x) = x^\alpha - * - * ``EXP`` - * .. math:: f(x) = e^x - * - * ``LOG`` - * .. math:: f(x) = \log x - * - * ``ABS`` - * .. math:: f(x) = |x| - * - * ``THRESHOLD`` - * .. math:: f(x) = \text{max}(\alpha, x) - */ - enum Operation { - SQRT = 0; - RSQRT = 1; - INVERSE = 2; - POWER = 3; - EXP = 4; - LOG = 5; - ABS = 6; - THRESHOLD = 7; - } - Operation type = 1; /// The type of unary function. - - /** - * A constant used in ``POWER`` and ``THRESHOLD`` functions. - */ - float alpha = 2; - - /** - * A small constant to avoid division by 0 while normalizing variance. - * Defaults to ``1e-6`` if not set or set to ``0``. - */ - float epsilon = 3; - - /** - * Input is shifted by this amount - * before the unary function is applied. - * Defaults to ``0.0`` if not set. - */ - float shift = 4; - - /** - * Input is scaled by this amount - * before the unary function is applied. - * Defaults to ``1.0`` if not set or set to ``0``. - */ - float scale = 5; - -} - -/** - * A layer that scales up spatial dimensions. - * It supports two modes: nearest neighbour (default) and bilinear. - * - * .. code:: - * - * y = UpsampleLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Same rank as the input. - * e.g.: blob with shape ``[C, scalingFactor[0] * H, scalingFactor[1] * W]`` - */ -message UpsampleLayerParams { - - /** - * Scaling Factor. Mutually exclusive with fractionalScalingFactor. - * Must be length 2 in order ``[H, W]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 scalingFactor = 1; - - /** - * Fractional scaling factor. Mutually exclusive with scalingFactor. - * Must be length 2 in order ``[H, W]``. - * If not set, default value ``[1.0, 1.0]`` is used. - */ - repeated float fractionalScalingFactor = 7; - - /* - * Overall mode for interpolating new elements when upsampling. - * NN - Nearest Neighbors - simply pick the nearest true value for interpolated values. - * BILINEAR - Use bilinear interpolation. See LinearUpsamplingMode for behavior. - */ - enum InterpolationMode { - - NN = 0; /// Nearest Neighbour - BILINEAR = 1; /// Bilinear - - } - - InterpolationMode mode = 5; - - /** - * LinearUpsampleMode specifies the behavior for linear upsampling. Only valid when Interpolation Mode is BILINEAR. 
- * If input grid is [0, Xin-1] (corresponding to an input size of Xin), and if the output size is Xout, - * then the grid points are sampled in the following manner: - * DEFAULT: - * spacing = (Xin-Xin/Xout) / (Xout-1) - * grid_point[i] = min(Xin-1, max(0, i * spacing)), for i = 0,1,2,….,Xout-1 - * ALIGN_CORNERS_TRUE: - * spacing = (Xin-1) / (Xout-1) - * grid_point[i] = min(Xin-1, max(0, i * spacing)), for i = 0,1,2,….,Xout-1 - * ALIGN_CORNERS_FALSE: - * spacing = Xin / Xout - * grid_point[i] = min(Xin-1, max(0, i * spacing + 0.5 * spacing - 0.5)), for i = 0,1,2,….,Xout-1 - */ - enum LinearUpsampleMode { - - DEFAULT = 0; - ALIGN_CORNERS_TRUE = 1; - ALIGN_CORNERS_FALSE = 2; - - } - - LinearUpsampleMode linearUpsampleMode = 6; - -} - -/** -* A layer that resizes the input to a pre-specified spatial size using bilinear interpolation. -* -* .. code:: -* -* y = ResizeBilinearLayer(x) -* -* Requires 1 input and produces 1 output. -* -* Input -* A blob with rank at least 3. -* e.g.: blob with shape ``[C, H_in, W_in]``. -* For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. -* -* Output -* Same rank as the input. -* e.g.: blob with shape ``[C, H_out, W_out]``. -* -*/ -message ResizeBilinearLayerParams { - - /** - * Target Spatial Size. - * Must be length 2 in order ``[Height, Width]``, i.e. ``[H_out, W_out]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 targetSize = 1; - - /** - * Mode used to compute the grid on which the spatial output values are evaluated. - * Same mode is applied to both the height and width axes. - */ - SamplingMode mode = 2; - -} - -/** -* A layer that extracts cropped spatial patches or RoIs (regions of interest) from the input and resizes them to a pre-specified size using -* bilinear interpolation. -* Note that RoI Align layer can be implemented with this layer followed by a pooling layer. -* -* .. code:: -* -* y = CropResizeLayer(x) -* -* Requires 2 inputs and produces 1 output. -* -* Input -* There are two inputs. -* First input represents an image feature map. -* Second input represents the bounding box coordinates for N patches or RoIs (region of interest). -* -* First input is rank 5: [1, Batch, C, H_in, W_in]. -* Second input is rank 5. Its shape can be either [N, 1, 4, 1, 1] or [N, 1, 5, 1, 1]. -* -* N: number of patches/RoIs to be extracted -* -* If RoI shape = ``[N, 1, 4, 1, 1]`` -* The axis=-3 corresponds to the four coordinates specifying the bounding box. -* All the N RoIs are extracted from all the batches of the input. -* -* If RoI shape = ``[N, 1, 5, 1, 1]`` -* The first element of the axis=-3 specifies the input batch id from which to extract the RoI and -* must be in the interval ``[0, Batch - 1]``. That is, n-th RoI is extracted from the RoI[n,0,0,0,0]-th -* input batch id. The last four elements of the axis=-3 specify the bounding box coordinates. -* -* Output -* A blob with rank 5. -* - Shape is [N, Batch, C, H_out, W_out] if input RoI shape is [N, 1, 4, 1, 1] -* - Shape is [N, 1, C, H_out, W_out] if input RoI shape is [N, 1, 5, 1, 1] -* -*/ -message CropResizeLayerParams { - - /** - * Target Spatial Size. - * Must be length 2 in order ``[Height, Width]``, i.e. ``[H_out, W_out]``. - * If not set, default value ``[1, 1]`` is used. - */ - repeated uint64 targetSize = 1; - - /** - * If true the bounding box coordinates must be in the interval [0, 1]. - * They are scaled by (H_in - 1), (W_in - 1), i.e. based on the input spatial dimensions. 
- * If false the bounding box coordinates must be in the interval - * [0, H_in -1] and [0, W_in - 1], respectively for height and width dimensions. - */ - bool normalizedCoordinates = 2; - - /** - * Mode used to compute the grid on which the spatial output values are evaluated. - * Same mode is applied to both the height and width axes. - */ - SamplingMode mode = 3; - - /** - * Representation used to express the bounding box coordinates. - * It determines how the values of the second input are interpreted. - */ - BoxCoordinatesMode boxIndicesMode = 4; - - /** - * Additional spatial scale that multiplies the bounding box coordinates. - * Generally used while implementing the RoI Align layer, - * which uses unnormalized RoI coordinates along with a spatial scale less than or equal to 1. - */ - float spatialScale = 5; - -} - -/** - * A layer that performs elementwise addition of a bias, - * which is broadcasted to match the input shape. - * - * .. code:: - * - * y = BiasLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - */ -message BiasLayerParams { - - /** - * The shape of the bias. - * Must be one of the following: - * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. - */ - repeated uint64 shape = 1; - - /** - * The bias values. - * The size must be equal to the product of the ``shape`` dimensions. - */ - WeightParams bias = 2; - -} - -/** - * A layer that performs elmentwise multiplication by a scale factor - * and optionally adds a bias; - * both the scale and bias are broadcasted to match the input shape. - * - * .. code:: - * - * y = ScaleLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - */ -message ScaleLayerParams { - - /** - * The shape of the scale. - * Must be one of the following: - * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. - */ - repeated uint64 shapeScale = 1; - - /** - * The scale values. - * The size must be equal to the product of the ``shape`` dimensions. - */ - WeightParams scale = 2; /// Scale values. Size must be equal to the product of dimensions specified in shapeScale. - - bool hasBias = 3; /// If true, a bias is added after scaling. - - /** - * The shape of the bias. - * Must be one of the following: - * ``[1]``, ``[C]``, ``[1, H, W]`` or ``[C, H, W]``. - */ - repeated uint64 shapeBias = 4; - - /** - * The bias values. - * The size must be equal to the product of the ``shape`` dimensions. - */ - WeightParams bias = 5; - -} - -/** - * A layer that loads data as a parameter and provides it as an output. - * The output is rank 5. For general rank, see LoadConstantNDLayer. - * - * .. code:: - * - * y = LoadConstantLayer() - * - * Requires no input and produces 1 output. - * - * Output: - * A blob with rank 5 and shape ``[1, 1, C, H, W]`` - */ -message LoadConstantLayerParams { - - /** - * The shape of the constant to be loaded, - * which must be``[C, H, W]``, that is length 3. - */ - repeated uint64 shape = 1; - - /** - * The data values, - * of size ``C * H * W``. 
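- *
- * Illustration only: given the declared ``shape = [C, H, W]`` and the flat
- * list of values, the constant can be reconstructed as below (assuming the
- * values are stored in row-major order):
- *
- * .. code::
- *
- *     import numpy as np
- *
- *     def load_constant(flat_values, C, H, W):
- *         # reshape the flat value list into the [C, H, W] constant
- *         return np.asarray(flat_values, dtype=np.float32).reshape(C, H, W)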
- */ - WeightParams data = 2; - -} - -/** - * A layer that performs L2 normalization, i.e. divides by the - * the square root of the sum of squares of all elements of input. - * - * .. code:: - * - * y = L2NormalizeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * A blob with the same shape as the input. - * - * This layer is described by the following formula: - * - * .. math:: - * x_i \leftarrow \dfrac{x_i}{\sqrt{\sum{x_i^2} + \epsilon}} - */ -message L2NormalizeLayerParams { - - /** - * A small constant to avoid division by 0 while normalizing variance. - * Defaults to ``1e-6`` if not set or set to ``0``. - */ - float epsilon = 1; - -} - -/// Data Reorganization Layers -/// -------------------------- - -/** - * A layer that flattens the input. - * - * .. code:: - * - * y = FlattenLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * e.g.: Rank 4 blob represents [Batch, C, H, W] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * Same rank as the input, such that last two dimensions are both 1. - * e.g.: For rank 4 input, output shape is ``[Batch, C * H * W, 1, 1]`` - * - * There are two X orders: ``CHANNEL_FIRST`` and ``CHANNEL_LAST``. - * ``CHANNEL_FIRST`` does not require data to be rearranged, - * because row major ordering is used by internal storage. - * ``CHANNEL_LAST`` requires data to be rearranged. - */ -message FlattenLayerParams { - - enum FlattenOrder { - - CHANNEL_FIRST = 0; - CHANNEL_LAST = 1; - - } - FlattenOrder mode = 1; - -} - -/** - * A layer that recasts the input into a new shape. - * - * .. code:: - * - * y = ReshapeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank 5. - * e.g.: ``[1, 1, C, H, W]`` or ``[Seq, 1, C, H, W]``. - * Output - * A blob with rank 5. - * e.g.: ``[1, 1, C_out, H_out, W_out]`` or ``[Seq_out, 1, C_out, H_out, W_out]``. - * - * There are two reshape orders: ``CHANNEL_FIRST`` and ``CHANNEL_LAST``. - * ``CHANNEL_FIRST`` is equivalent to - * flattening the input to ``[Seq, 1, C * H * W, 1, 1]`` in channel first order - * and then reshaping it to the target shape; - * no data rearrangement is required. - * ``CHANNEL_LAST`` is equivalent to - * flattening the input to ``[Seq, 1, H * W * C, 1, 1]`` in channel last order, - * reshaping it to ``[Seq_out, 1, H_out, W_out, C_out]`` (it is now in "H_out-major"" order), - * and then permuting it to ``[C_out, H_out, W_out]``; - * both the flattening and permuting requires the data to be rearranged. - */ -message ReshapeLayerParams { - - /** - * The shape of the output. - * Must be of length 3 or 4. - * If set to 3, ``targetShape`` is interpreted as - * ``[1, 1, C_out, H_out, W_out]``, and sequence length of the input is preserved. - * If set to 4, ``targetShape`` is interpreted as - * ``[Seq_out, 1, C_out, H_out, W_out]``, - * where ``Seq_out`` is the new sequence length. - */ - repeated int64 targetShape = 1; - - enum ReshapeOrder { - - CHANNEL_FIRST = 0; - CHANNEL_LAST = 1; - - } - ReshapeOrder mode = 2; - -} - -/** - * A layer that rearranges the dimensions and data of an input. - * For generic transpose/permute operation see TransposeLayer. - * - * .. 
code:: - * - * y = PermuteLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * Must be a rank 5 blob. - * e.g.: shape ``[Seq, B, C, H, W]``. - * Output - * Rank 5 blob. Transposed version of the input, such that dimensions at axis=1 or axis=-4 is unchanged. - * - * - * Examples: - * - * Assume input shape is [Seq, B, C, H, W] - * - * - If ``axis`` is set to ``[0, 3, 1, 2]``, - * then the output has shape ``[Seq, B, W, C, H]`` - * - * - If ``axis`` is set to ``[3, 1, 2, 0]``, - * then the output has shape ``[W, B, C, H, Seq]`` - * - * - If ``axis`` is set to ``[0, 3, 2, 1]``, - * then the output has shape ``[Seq, B, W, H, C]`` - * - * - If ``axis`` is not set, or is set to ``[0, 1, 2, 3]``, - * the output is the same as the input. - */ -message PermuteLayerParams { - - /** - * The order in which to permute the dimensions. - * Must have length 4 and a permutation of ``[0, 1, 2, 3]``. - */ - repeated uint64 axis = 1; - -} - -/** - * A layer that reorganizes data in the input in specific ways. - * - * .. code:: - * - * y = ReorganizeDataLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 3. - * e.g.: blob with shape ``[C, H, W]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * Output - * Same rank as the input. - * e.g.: blob with shape ``[C_out, H_out, W_out]``. - * - * mode == SPACE_TO_DEPTH - * ``[C_out, H_out, W_out]`` : ``[C * blockSize * blockSize, H/blockSize, W/blockSize]``. - * blockSize must divide H and W. - * Data is moved from the spatial dimensions to the channel dimension. Input is spatially divided into - * non-overlapping blocks of size blockSize X blockSize and data from each block is moved into the - * channel dimension. - * - * mode == DEPTH_TO_SPACE - * ``[C_out, H_out, W_out]`` : ``[C/(blockSize * blockSize), H * blockSize, W * blockSize]``. - * Square of blockSize must divide C. - * Reverse of SPACE_TO_DEPTH. Data is moved from the channel dimension to the spatial dimensions. - * - * mode == PIXEL_SHUFFLE - * ``[C_out, H_out, W_out]`` : ``[C/(blockSize * blockSize), H * blockSize, W * blockSize]``. - * Square of blockSize must divide C. - * Similar to DEPTH_TO_SPACE, but using the pixel-shuffle semantics for channel order in the output space. - * In both modes, elements along the channel dimension are collapsed into - * blocks in the spatial dimensions. The difference is in the arrangement of - * the input-channels' data in the output space. See below example for more - * detail. - * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) - * - * - * Examples: - * - * Assume input is the following [C = 8, H = 1, W = 2] tensor: - * - * .. code:: - * - * [[[1 2]] [[3 4]] [[5 6]] [[7 8]] [[9 10]] [[11 12]] [[13 14]] [[15 16]]] - * - * If block_size == 2 and mode == DEPTH_TO_SPACE, output will be the following - * [C = 2, H = 2, W = 4] tensor: - * - * .. code:: - * - * [[[ 1 5 2 6] - * [ 9 13 10 14]] - * - * [[ 3 7 4 8] - * [11 15 12 16]]] - * - * For mode == SPACE_TO_DEPTH, the behavior is the same as mode == - * DEPTH_TO_SPACE, but with the input and output swapped. - * - * If block_size == 2 and mode == PIXEL_SHUFFLE, output will be the following - * [C = 2, H = 2, W = 4] tensor: - * - * .. 
code:: - * - * [[[ 1 3 2 4] - * [ 5 7 6 8]] - * - * [[ 9 11 10 12] - * [13 15 14 16]]] - * - */ -message ReorganizeDataLayerParams { - - enum ReorganizationType { - - SPACE_TO_DEPTH = 0; - DEPTH_TO_SPACE = 1; - PIXEL_SHUFFLE = 2; - - } - ReorganizationType mode = 1; - uint64 blockSize = 2; /// must be greater than 1 - -} - -/** - * A layer that slices the input data along axis = -1 or -2 or -3. - * For general slice along any axis, please see SliceStaticLayer/SliceDynamicLayer. - * - * .. code:: - * - * y = SliceLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob that can, in general, have any rank. However, depending on the value of "axis" , - * there may be additional rank constraints. - * Output - * A blob with the same rank as the input. - * - * Sliced section is taken from the interval ``[startIndex, endIndex)``, i.e. - * startIndex is inclusive while endIndex is exclusive. - * stride must be positive and represents the step size for slicing. - * Negative indexing is supported for startIndex and endIndex. - * -1 denotes N-1, -2 denotes N-2 and so on, where N is the length of the dimension to be sliced. - * - */ -message SliceLayerParams { - - int64 startIndex = 1; /// start of the sliced section. Inclusive. - int64 endIndex = 2; /// end of sliced section. Exclusive. - uint64 stride = 3; /// The step size. Must be positive. - - enum SliceAxis { - - CHANNEL_AXIS = 0; - HEIGHT_AXIS = 1; - WIDTH_AXIS = 2; - - } - // The following mapping is used for interpreting this parameter: - // CHANNEL_AXIS => axis = -3, input must have rank at least 3. - // HEIGHT_AXIS => axis = -2, input must have rank at least 2. - // WIDTH_AXIS => axis = -1 - SliceAxis axis = 4; - -} - -/** - * A layer that reduces the input using a specified operation. - * - * .. code:: - * - * y = ReduceLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob that can, in general, have any rank. However, depending on the value of "axis" , - * there may be additional rank constraints. - * Output - * A blob with the same rank as the input, which has 1s on the dimensions specified in the parameter "axis" - * - * Values supported for axis are [-1], [-2], [-3], [-2,-1], [-3,-2,-1] - * and the equivalent positive values (depending on the rank of the input) - * For mode == 'ArgMax', axis must be [-1] or [-2] or [-3]. - */ -message ReduceLayerParams { - - /* - * The following reduction operations are supported - * and are applied on the specified axis of the input array: - * - * ``SUM`` - * Sum of all elements - * - * .. math:: \sum{x_i} - * - * ``AVG`` - * Sum of all elements divided by the number of elements - * - * .. math:: \dfrac{\sum^n{x_i}}{n} - * - * ``PROD`` - * Product of all elements - * - * .. math:: \prod{x_i} - * - * ``LOGSUM`` - * Sum of the natural logarithm of all elements - * - * .. math:: \sum{\ln{(x_i + \epsilon)}} - * - * ``SUMSQUARE`` - * Sum of squares of all elements - * - * .. math:: \sum{x^2} - * - * ``L1`` - * L1 normalization of all elements - * - * .. math:: ||x||_1 = \sum{|x_i|} - * - * ``L2`` - * L2 normalization of all elements - * - * .. math:: ||x||_2 = \sqrt{\sum{x_i^2}} - * - * ``MAX`` - * Maximum of all elements - * - * .. math:: \text{max}(x_i) - * - * ``MIN`` - * Minumum of all elements - * - * .. math:: \text{min}(x_i) - * - * ``ARGMAX`` - * Argument of the maximum of all elements - * - * .. 
math:: \text{argmax}(x_i) - * - */ - enum ReduceOperation { - - SUM = 0; - AVG = 1; - PROD = 2; - LOGSUM = 3; - SUMSQUARE = 4; - L1 = 5; - L2 = 6; - MAX = 7; - MIN = 8; - ARGMAX = 9; /// only supported with axis = C, H or W. - - } - ReduceOperation mode = 1; /// Specifies function used to reduce. - - /** - * Used if mode is ``LOGSUM``. - * Defaults to ``1e-6`` if not set or is set to ``0``. - */ - float epsilon = 2; - - enum ReduceAxis { - - CHW = 0; - HW = 1; - C = 2; - H = 3; - W = 4; - - } - - // The following mapping is used for interpreting this parameter: - // CHW = axis [-3, -2, -1], input must have rank at least 3. - // HW = axis [-2, -1], input must have rank at least 2. - // C = axis [-3] - // H = axis [-2] - // W = axis [-1] - ReduceAxis axis = 3; - -} - -/** - * A layer that crops the spatial dimensions of an input. - * If two inputs are provided, the shape of the second input is used as the reference shape. - * - * .. code:: - * - * y = CropLayer(x1) or y = CropLayer(x1,x2) - * - * Requires 1 or 2 inputs and produces 1 output. - * - * Input - * 1 or 2 tensors, each with rank at least 3, both inputs must have equal rank. - * Example: - * - 1 input case: A blob with shape ``[C, H_in, W_in]``. - * - 2 input case: 1st blob with shape ``[C, H_in, W_in]``, 2nd blob with shape ``[C, H_out, W_out]``. - * - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Same rank as the inputs. - * e.g.: A blob with shape ``[C, H_out, W_out]``. - * - * If one input is used, output is computed as follows: - * - * .. code:: - * - * y = x1[:, topCropAmount:H_in - bottomCropAmount, leftCropAmount:W_in - rightCropAmount] - * - * topCropAmount == Height startEdgeSize == borderAmounts[0].startEdgeSize - * bottomCropAmount == Height endEdgeSize == borderAmounts[0].endEdgeSize - * leftCropAmount == Width startEdgeSize == borderAmounts[1].startEdgeSize - * rightCropAmount == Width endEdgeSize == borderAmounts[1].endEdgeSize - * - * H_out = H_in - topCropAmount - bottomCropAmount - * W_out = W_in - leftCropAmount - rightCropAmount - * - * If two inputs are used, output is computed as follows: - * - * .. code:: - * - * y = x1[:, offset[0]:offset[0] + H_out, offset[1]:offset[1] + W_out] - */ -message CropLayerParams { - - /** - * The amounts to be cropped from the input. - * Used only if a single input is provided. - */ - BorderAmounts cropAmounts = 1; - - /** - * The offset amounts. - * Used only if two inputs are provided. - * Must be of length 2, in order ``[H, W]``. - */ - repeated uint64 offset = 5; - -} - -/** - * A layer that computes the elementwise average of the inputs. - * This layer has limited broadcasting support. For general broadcasting see AddBroadcastableLayer. - * - * .. code:: - * - * y = AverageLayer(x1,x2,...) - * - * Requires multiple inputs and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, 1, 1, 1], [B, C, 1, 1], [B, 1, H, W], [B, C, H, W] - * Output - * A blob with the same shape as each input. - */ -message AverageLayerParams { - -} - -/** - * A layer that computes the elementwise maximum over the inputs. - * - * .. code:: - * - * y = MaxLayer(x1,x2,...) - * - * Requires multiple inputs and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. 
For example: - * [B, C, 1, 1], [B, C, H, W] - * Output - * A blob with the same shape as each input. - */ -message MaxLayerParams { - -} - -/** - * A layer that computes the elementwise minimum over the inputs. - * - * .. code:: - * - * y = MinLayer(x1,x2,...) - * - * Requires multiple inputs and produces 1 output. - * - * Input - * In general, there are no rank constraints. - * However, only certain set of shapes are broadcastable. For example: - * [B, C, 1, 1], [B, C, H, W] - * Output - * A blob with the same shape as each input. - */ -message MinLayerParams { - -} - -/** - * A layer that computes the dot product of two vectors. - * - * .. code:: - * - * y = DotProductLayer(x1,x2) - * - * Requires 2 inputs and produces 1 output. - * - * Input - * Two blobs with rank at least 3, such that the last two dimensions must be 1. - * e.g.: blobs with shape ``[B, C, 1, 1]``. - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * Same rank as the input. - * e.g. for rank 4 inputs, output shape: [B, 1, 1, 1] - */ -message DotProductLayerParams { - - /** - * If true, inputs are normalized first, - * thereby computing the cosine similarity. - */ - bool cosineSimilarity = 1; - -} - -/** - * A layer that performs mean variance normalization, along axis = -3. - * - * .. code:: - * - * y = MeanVarianceNormalizeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank greater than equal to 3. - * Example: Rank 4 blob represents [Batch, channels, height, width] - * For ranks greater than 3, the leading dimensions, starting from 0 to -4 (inclusive), are all treated as batch. - * - * Output - * A blob with the same shape as the input. - * - * If ``acrossChannels == true`` - * normalization is performed on flattened input, i.e. the input is reshaped to (Batch,C), where "Batch" contains - * all dimensions from 0 to -4 (inclusive), and C contains dimensions -1, -2, -3. - * - * If ``acrossChannels == false`` - * normalization is performed within a channel, - * across spatial dimensions (i.e. last two dimensions). - */ -message MeanVarianceNormalizeLayerParams { - - /** - * If true, mean and variance are computed across channels. - */ - bool acrossChannels = 1; - - /** - * If false, only mean is subtracted. - */ - bool normalizeVariance = 2; - - /** - * A small constant to avoid division by 0 while normalizing variance. - * Defaults to ``1e-6`` if not set or set to ``0``. - */ - float epsilon = 3; - -} - -/** - * A layer that repeats a sequence or the dimension sitting at axis = -5 - * - * .. code:: - * - * y = SequenceRepeatLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A blob with rank at least 5. - * e.g: shape ``[Seq, B, C, H, W]`` - * Output - * A blob with the same rank as the input. - * e.g.: for input shape ``[Seq, B, C, H, W]``, output shape is ``[nRepetitions * Seq, B, C, H, W]``. - */ -message SequenceRepeatLayerParams { - - /** - * Number of repetitions. - * Defaults to ``1`` if not set or set to ``0``. - */ - uint64 nRepetitions = 1; - -} - -/// Recurrent Layers -/// ---------------- - -/* - * The following activations are supported with recurrent layers: - * - Linear - * - Sigmoid - * - Tanh - * - ReLU - * - Scaled Hyperbolic Tangent: alpha * tanh(beta * x), currently only supported for alpha = 1.7159, beta = 2/3 - * - Hard Sigmoid: min(max(alpha * x + beta, 0), 1), currently only supported for alpha = 0.2, beta = 0.5 - */ - -/** - * A simple recurrent layer. 
- * - * .. code:: - * - * y_t = SimpleRecurrentLayer(x_t, y_{t-1}) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - * This layer is described by the following equation: - * - * .. math:: - * \boldsymbol{y_t} = f(\mathrm{clip}(W \boldsymbol{x_t} + \ - * R \boldsymbol{y_{t-1}} + b)) - * - * - ``W`` is a 2-dimensional weight matrix - * (``[outputVectorSize, inputVectorSize]``, row-major) - * - ``R`` is a 2-dimensional recursion matrix - * (``[outputVectorSize, outputVectorSize]``, row-major) - * - ``b`` is a 1-dimensional bias vector (``[outputVectorSize]``) - * - ``f()`` is an activation - * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` - */ -message SimpleRecurrentLayerParams { - - uint64 inputVectorSize = 1; /// The size of the input vectors. - uint64 outputVectorSize = 2; /// The size of the output vectors. - - /** - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - ActivationParams activation = 10; /// The activation function. - - /** - If false output is just the result after final state update. - If true, output is a sequence, containing outputs at all time steps. - */ - bool sequenceOutput = 15; - - bool hasBiasVector = 20; /// If false, no bias is added. - - WeightParams weightMatrix = 30; /// Weight matrix W. - WeightParams recursionMatrix = 31; /// Recursion Weight matrix R. - WeightParams biasVector = 32; /// Bias vector b. - - bool reverseInput = 100; - // If true, then the node processes the input sequence from right to left - -} - -/** - * Gated-Recurrent Unit (GRU) Layer - * - * .. code:: - * - * y_t = GRULayer(x_t, y_{t-1}) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - * This layer is described by the following equations: - * - * Update Gate - * .. math:: - * \boldsymbol{z_t} = \ - * f(\mathrm{clip}(W_z \boldsymbol{x_t} + \ - * R_z \boldsymbol{y_{t-1}} + b_z) - * - * Reset Gate - * .. math:: - * \boldsymbol{r_t} = \ - * f(\mathrm{clip}(W_r \boldsymbol{x_t} + \ - * R_r \boldsymbol{y_{t-1}} + b_r)) - * - * Cell Memory State - * .. math:: - * \boldsymbol{c_t} = \ - * \boldsymbol{y_{t-1}} \odot \boldsymbol{r_t} - * - * Output Gate - * .. math:: - * \boldsymbol{o_t} = \ - * g(\mathrm{clip}(W_o \boldsymbol{x_t} + \ - * R_o \boldsymbol{c_t} + b_o)) - * - * Output - * .. 
math:: - * \boldsymbol{y_t} = \ - * (1 - \boldsymbol{z_t}) \odot \boldsymbol{o_t} + \ - * \boldsymbol{z_t} \odot \boldsymbol{y_{t-1}} - * - * - ``W_z``, ``W_r``, ``W_o`` are 2-dimensional input weight matrices - * (``[outputVectorSize, inputVectorSize]``, row-major) - * - ``R_z``, ``R_r``, ``R_o`` are 2-dimensional recursion matrices - * (``[outputVectorSize, outputVectorSize]``, row-major) - * - ``b_z``, ``b_r``, ``b_o`` are 1-dimensional bias vectors - * (``[outputVectorSize]``) - * - ``f()``, ``g()`` are activations - * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` - * - ``⊙`` denotes the elementwise product of matrices - */ -message GRULayerParams { - - uint64 inputVectorSize = 1; /// Size of the input vectors. - uint64 outputVectorSize = 2; /// Size of the output vectors. - - /** - * 2 element array representing activations [f(), g()] in that order. - * Typical values used = [sigmoid, tanh]. - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - repeated ActivationParams activations = 10; - - /** - * If false output is just the result after final state update. - * If true, output is a sequence, containing outputs at all time steps. - */ - bool sequenceOutput = 15; - - /** - * If false, no biases (``b_z``, ``b_r``, ``b_o``) are added. - */ - bool hasBiasVectors = 20; - - WeightParams updateGateWeightMatrix = 30; /// Weight Matrix W_z. - WeightParams resetGateWeightMatrix = 31; /// Weight Matrix W_r. - WeightParams outputGateWeightMatrix = 32; /// Weight Matrix W_o. - - WeightParams updateGateRecursionMatrix = 50; /// Recursion Weight Matrix R_z. - WeightParams resetGateRecursionMatrix = 51; /// Recursion Weight Matrix R_r. - WeightParams outputGateRecursionMatrix = 52; /// Recursion Weight Matrix R_o. - - WeightParams updateGateBiasVector = 70; /// Bias vector b_z. - WeightParams resetGateBiasVector = 71; /// Bias vector b_r. - WeightParams outputGateBiasVector = 72; /// Bias vector b_o. - - /// If true, then the node processes the input sequence from right to left - bool reverseInput = 100; - -} - -/** - * Long short-term memory (LSTM) parameters. - * - * This is described by the following equations: - * - * Input Gate - * .. math:: - * \boldsymbol{i_t} = \ - * f(\mathrm{clip}(W_i \boldsymbol{x_t} + \ - * R_i \boldsymbol{y_{t-1}} + \ - * p_i \odot c_{t-1} + b_i)) - * - * Forget Gate - * .. math:: - * \boldsymbol{f_t} = \ - * f(\mathrm{clip}(W_f \boldsymbol{x_t} + \ - * R_f \boldsymbol{y_{t-1}} + \ - * p_f \odot c_{t-1} + b_f)) - * - * Block Input - * .. math:: - * \boldsymbol{z_t} = \ - * g(\mathrm{clip}(W_z \boldsymbol{x_t} + \ - * R_z \boldsymbol{y_{t-1}} + b_z)) - * - * Cell Memory State - * .. math:: - * \boldsymbol{c_t} = \ - * \boldsymbol{c_{t-1}} \odot \boldsymbol{f_t} + \ - * \boldsymbol{i_t} \odot \boldsymbol{z_t} - * - * Output Gate - * .. math:: - * \boldsymbol{o_t} = \ - * f(\mathrm{clip}(W_o \boldsymbol{x_t} + \ - * R_o \boldsymbol{y_{t-1}} + \ - * p_o \odot c_t + b_o)) - * - * Output - * .. 
math:: - * \boldsymbol{y_t} = \ - * h(\boldsymbol{c_t}) \odot \boldsymbol{o_t} - * - * - ``W_i``, ``W_f``, ``W_z``, ``W_o`` are 2-dimensional input weight matrices - * (``[outputVectorSize, inputVectorSize]``, row-major) - * - ``R_i``, ``R_f``, ``R_z``, ``R_o`` are 2-dimensional recursion matrices - * (``[outputVectorSize, outputVectorSize]``, row-major) - * - ``b_i``, ``b_f``, ``b_z``, ``b_o`` are 1-dimensional bias vectors - * (``[outputVectorSize]``) - * - ``p_``, ``p_f``, ``p_o`` are 1-dimensional peephole vectors - * (``[outputVectorSize]``) - * - ``f()``, ``g()``, ``h()`` are activations - * - ``clip()`` is a function that constrains values between ``[-50.0, 50.0]`` - * - ``⊙`` denotes the elementwise product of matrices - */ -message LSTMParams { - - /** - * If true, output is a sequence, containing outputs at all time steps. - * If false, output is just the result after final state update. - */ - bool sequenceOutput = 10; - - /** - * If false, no biases (``b_i``, ``b_f``, ``b_z``, ``b_o``) are added. - */ - bool hasBiasVectors = 20; - - /** - * If true, a vector of ``1`` values is added to ``b_f``. - */ - bool forgetBias = 30; - - /** - * If true, peephole vectors are included. - */ - bool hasPeepholeVectors = 40; - - /** - * If the coupled Input and Forget flag is on, the behaviour of - * ``c_t`` is changed to the following (i.e. forget gate is not used): - * - * .. math:: - * \boldsymbol{c_t} = \ - * \boldsymbol{c_{t-1}} \odot (1 - \boldsymbol{i_t}) + \ - * \boldsymbol{i_t} \odot \boldsymbol{z_t} - * - */ - bool coupledInputAndForgetGate = 50; - - /** - * Places a limit on the maximum and minimum values of ``c_t``. - * c_t = min(c_t, cellClipThreshold) - * c_t = max(c_t, -cellClipThreshold) - * If 0, it is set to its default value = 50.0. - */ - float cellClipThreshold = 60; - -} - -/** - * Weights for long short-term memory (LSTM) layers - */ -message LSTMWeightParams { - - WeightParams inputGateWeightMatrix = 1; /// Weight Matrix W_i. - WeightParams forgetGateWeightMatrix = 2; /// Weight Matrix W_f. - WeightParams blockInputWeightMatrix = 3; /// Weight Matrix W_z. - WeightParams outputGateWeightMatrix = 4; /// Weight Matrix W_o. - - WeightParams inputGateRecursionMatrix = 20; /// Recursion Weight Matrix R_i. - WeightParams forgetGateRecursionMatrix = 21; /// Recursion Weight Matrix R_f. - WeightParams blockInputRecursionMatrix = 22; /// Recursion Weight Matrix R_z. - WeightParams outputGateRecursionMatrix = 23; /// Recursion Weight Matrix R_o. - - //biases: - WeightParams inputGateBiasVector = 40; /// Bias vector b_i. - WeightParams forgetGateBiasVector = 41; /// Bias vector b_f. - WeightParams blockInputBiasVector = 42; /// Bias vector b_z. - WeightParams outputGateBiasVector = 43; /// Bias vector b_o. - - //peepholes: - WeightParams inputGatePeepholeVector = 60; /// Peephole vector p_i. - WeightParams forgetGatePeepholeVector = 61; /// Peephole vector p_f. - WeightParams outputGatePeepholeVector = 62; /// Peephole vector p_o. - -} - -/** - * A unidirectional long short-term memory (LSTM) layer. - * - * .. code:: - * - * (y_t, c_t) = UniDirectionalLSTMLayer(x_t, y_{t-1}, c_{t-1}) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. 
- * - * - Output Shape: ``[1, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - */ -message UniDirectionalLSTMLayerParams { - - uint64 inputVectorSize = 1; /// Size of the input vectors. - uint64 outputVectorSize = 2; /// Size of the output vectors. - - /** - * 3 element array representing activations [f(),g(),h()] in that order. - * Typical values used = [sigmoid, tanh, tanh]. - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - repeated ActivationParams activations = 10; - - LSTMParams params = 15; - - LSTMWeightParams weightParams = 20; /// Weights, biases and peepholes. - - /// If true, then the node processes the input sequence from right to left - bool reverseInput = 100; - -} - -/** - * Bidirectional long short-term memory (LSTM) layer - * - * .. code:: - * - * (y_t, c_t, y_t_reverse, c_t_reverse) = BiDirectionalLSTMLayer(x_t, y_{t-1}, c_{t-1}, y_{t-1}_reverse, c_{t-1}_reverse) - * - * Input - * A blob of rank 5, with shape `[Seq, Batch, inputVectorSize, 1, 1]``. - * This represents a sequence of vectors of size ``inputVectorSize``. - * Output - * Same rank as the input. - * Represents a vector of size ``2 * outputVectorSize``. It is either the final output or a sequence of outputs at all time steps. - * - * - Output Shape: ``[1, Batch, 2 * outputVectorSize, 1, 1]`` , if ``sequenceOutput == false`` - * - Output Shape: ``[Seq, Batch, 2 * outputVectorSize, 1, 1]`` , if ``sequenceOutput == true`` - * - * - * The first LSTM operates on the input sequence in the forward direction. - * The second LSTM operates on the input sequence in the reverse direction. - * - * Example: given the input sequence ``[x_1, x_2, x_3]``, - * where ``x_i`` are vectors at time index ``i``: - * - * The forward LSTM output is ``[yf_1, yf_2, yf_3]``, - * - * where ``yf_i`` are vectors of size ``outputVectorSize``: - * - * - ``yf_1`` is the output at the end of sequence {``x_1``} - * - ``yf_2`` is the output at the end of sequence {``x_1``, ``x_2``} - * - ``yf_3`` is the output at the end of sequence {``x_1``, ``x_2``, ``x_3``} - * - * The backward LSTM output: ``[yb_1, yb_2, yb_3]``, - * - * where ``yb_i`` are vectors of size ``outputVectorSize``: - * - * - ``yb_1`` is the output at the end of sequence {``x_3``} - * - ``yb_2`` is the output at the end of sequence {``x_3``, ``x_2``} - * - ``yb_3`` is the output at the end of sequence {``x_3``, ``x_2``, ``x_1``} - * - * Output of the bi-dir layer: - * - * - if ``sequenceOutput = True`` : { ``[yf_1, yb_3]``, ``[yf_2, yb_2]``, ``[yf_3, yb_1]`` } - * - if ``sequenceOutput = False`` : { ``[yf_3, yb_3]`` } - */ -message BiDirectionalLSTMLayerParams { - - /** - * Size of the input vectors. - */ - uint64 inputVectorSize = 1; - /** - * Size of the outputs vectors. - * It is same for both forward and backward LSTMs. - */ - uint64 outputVectorSize = 2; - - /** - * 3 element array representing activations [f(),g(),h()] in that order. - * Typical values used = [sigmoid, tanh, tanh]. - * Activations supported are Linear, Sigmoid, Tanh, ReLU, Scaled Tanh (alpha = 1.71, beta = 2/3), Hard sigmoid (alpha = 0.2, beta = 0.5) - */ - repeated ActivationParams activationsForwardLSTM = 10; - /** - * Currently, backward LSTM activations - * must be same as the ones for the forward LSTM. 
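- *
- * Illustration only: the forward/backward output pairing described in the
- * message-level comment above can be sketched in Python/NumPy as below
- * (``yf``/``yb`` are hypothetical arrays of per-step forward and backward
- * outputs, each of shape ``[Seq, outputVectorSize]``):
- *
- * .. code::
- *
- *     import numpy as np
- *
- *     def bidir_output(yf, yb, sequence_output):
- *         if sequence_output:
- *             # pair yf_i with the backward output computed from the suffix
- *             # starting at the same time step, i.e. yb in reverse order
- *             return np.concatenate([yf, yb[::-1]], axis=-1)  # [Seq, 2 * outputVectorSize]
- *         return np.concatenate([yf[-1], yb[-1]], axis=-1)    # [2 * outputVectorSize]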
- */ - repeated ActivationParams activationsBackwardLSTM = 11; - - /** - * Common parameters shared by the forward and backward LSTMs. - */ - LSTMParams params = 15; - - /** - * Weights and biases. - * Must be a length 2 message, - * for the forward and backward LSTM respectively. - */ - repeated LSTMWeightParams weightParams = 20; - -} - -message CustomLayerParams { - - message CustomLayerParamValue { - oneof value { - double doubleValue = 10; - string stringValue = 20; - int32 intValue = 30; - int64 longValue = 40; - bool boolValue = 50; - } - } - - string className = 10; // The name of the class (conforming to MLCustomLayer) corresponding to this layer - repeated WeightParams weights = 20; // Any weights -- these are serialized in binary format and memmapped at runtime - map parameters = 30; // these may be handled as strings, so this should not be large - string description = 40; // An (optional) description of the layer provided by the model creator. This information is displayed when viewing the model, but does not affect the model's execution on device. - -} - -/** - * A layer that rearranges the dimensions and data of an input. - * - * .. code:: - * - * y = TransposeLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * A N-Dimensional tensor. - * Output - * A N-Dimensional tensor of the same rank but with dimensions and data permuted according to axes. - * Shape: ``[InputShape[axis[0]], InputShape[axis[1]], ... , InputShape[axis[N-1]]]`` - * - * Examples: - * - * - If ``axes`` is set to ``[3, 1, 2, 0]`` and the input shape is ``[6,7,8,9]``, - * then the output has shape ``[9,7,8,6]`` - */ - -message TransposeLayerParams { - - /** - * Length of "axes" should match the rank of input & output tensor - * "axes" should be a permutation of "[0,1,2,...,N-1]" where N is the rank. - */ - repeated uint64 axes = 1; // - -} - -/** - * A layer that computes the matrix multiplication of two tensors with numpy-like broadcasting - * where the matrices reside in the last two indices of the tensor. - * - * .. code:: - * - * y = BatchedMatMul(a,b) - * - * Requires 1 or 2 inputs and produces 1 output. - * - * The first tensor, "a", must be provided as an input. The second tensor can either be an input or provided as a weight matrix parameter. - * - * Input - * - a: First N-Dimensional tensor - * - b: Second N-Dimensional tensor (either a rank-N input or a matrix, i.e. N=2, provided as a layer parameter) - * - * Output - * A tensor containing the matrix product of two tensors. - * When there are two inputs: rank is max(2, rank(a), rank(b)) - * When there is one input: rank is same as that of the input. - * - * This operation behaves as following: - * - * When there are two inputs: - * - If N >= 2 for both tensors, it is treated as a batch of matrices residing in the last two indices. - * All the indices, except for the last two, are broadcasted using conventional rules. - * - If the first tensor is 1-D, it is converted to a 2-D tensor by prepending a 1 to its shape. Eg. (D) -> (1,D) - * - If the second tensor is 1-D, it is converted to a 2-D tensor by appending a 1 to its shape. Eg. (D) -> (D,1) - * - * When there is one input: - * - The weight matrix corresponds to a matrix, of shape (X1, X2). Values of X1, X2 must be provided as layer parameters. - * - The input, "a", is reshaped into a matrix by combining all the leading dimensions, except the last, into a batch dimension. eg: - * - if "a" is rank 1 (X1,) --> (1, X1). 
Output shape will be (X2,) - * - if "a" is rank 2 (B1, X1) --> no need to reshape. Output shape will be (B1, X2) - * - if "a" is rank 3 (B1, B2, X1) --> (B1 * B2, X1). Output shape will be (B1, B2, X2) - * - etc - */ -message BatchedMatMulLayerParams { - - /** - * If transposeA is true, it transposes the left matrix on the fly before matrix multiplication. - * (is ignored when there is one input) - */ - bool transposeA = 1; - /** - * If transposeB is true, it transposes the right matrix on the fly before matrix multiplication. - * (is ignored when there is one input) - */ - bool transposeB = 2; - - /* - * Following parameters are ignored when there are two inputs. - */ - - uint64 weightMatrixFirstDimension = 5; /// X1: same as the last dimension of the input tensor - uint64 weightMatrixSecondDimension = 6; /// X2: same as the last dimension of the output tensor - - bool hasBias = 7; /// Whether a bias is added or not. Supported only when there is one input. - - /* - * Weight matrix representing shape [X1, X2]. - * Values are however stored in column major order, - * in the "repeated float" or "bytes" fields of the message "WeightParams" - */ - WeightParams weights = 8; - WeightParams bias = 9; /// Bias vector [X2]. Supported only when there is one input. - - /** - * If set, this layer, at runtime, quantizes the floating point input blob to int8 before applying the - * matrix multiplication using the INT8 weight parameters provided in weights->int8RawValue. The - * result is then dequantized. - * Requires: - * * number of inputs to be 1 - * * hasBias == false - * * QuantizationType == LinearQuantizationParams, such that - * * size of the "scale" field is 1 and "bias" field is empty in "LinearQuantizationParams" - * * numberOfBits == 8 - * * weights->rawValue_size to be empty - */ - bool int8DynamicQuantize = 10; - -} - -/** - * A layer that concatenates a list of tensors along a specified axis. - * - * .. code:: - * - * y = ConcatNDLayer(x1,x2,....) - * - * Requires at least 2 input and produces 1 output. - * - * Input - * The rank of the input tensors must match and all dimensions also must match, except for the dimension 'axis'. - * - * - * Output - * Same rank as the input. The dimension along "axis", is the sum of the dimensions of the inputs. - * - * example: - * - * in1 : shape (3, 2), value = [[1, 2], [3, 4], [5, 6]] - * in2 : shape (3, 2), value = [[7, 8], [9, 10], [11, 12]] - * axis = 0 - * - * if interleave = False (default) - * output : shape (6, 2) - * output[0:3, :] = in1 - * output[3:6, :] = in2 - * value = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] - * - * if interleave = True - * output : shape (6, 2) - * output[0::2, :] = in1 - * output[1::2, :] = in2 - * value = [[1, 2], [7, 8], [3, 4], [9, 10], [5, 6], [11, 12]] - * - */ -message ConcatNDLayerParams { - - /** - * Dimension along which to concatenate. Supports negative values of the parameter 'axis'. - */ - int64 axis = 1; - - /** - * (Only available in Core ML Specification >= 5 (iOS >= 14, macOS >= 11.0) - * Interleave option. If True, concatenation is done via interleaving the inputs. - * This requires all inputs to have the exact same shape. - */ - bool interleave = 2; - - -} - -/** - * A layer that performs softmax normalization along a specified axis. - * - * .. code:: - * - * y = SoftmaxNDLayer(x) - * - * Requires 1 input and produces 1 output. - * - * Output shape is same as the input. - */ -message SoftmaxNDLayerParams { - - /** - * Dimension on which the softmax would be performed. 
Supports negative values of the parameter 'axis'. - */ - int64 axis = 1; - -} - -/** - * A layer that reverses specific dimensions of the input tensor. - * It is similar in functionality to the numpy.flip method. - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - */ -message ReverseLayerParams { - - /** - * Reverses each dimension of the input tensor for which corresponding reverseDim is set to True. - * Requires len(reverseDim) == rank(inputTensor) - */ - repeated bool reverseDim = 1; - -} - -/** - * A layer that reverses variable length slices. - * - * Requires 2 inputs and produces 1 output. - * - * 2 inputs, in order are denoted by "data", "seq_lengths". - * "seq_lenghts" must be a rank 1 tensor, i.e. seq_lengths.shape = (B,) - * which contains the lengths of the amount of sequence to be reversed, for each element of the batch. - * Dimension "batchAxis" in "data" must be equal to B, i.e, - * data.shape[batchAxis] = B. - * - * According to the batch axis, input "data" is first divided into a batch of B inputs, - * each of which is flipped along the dimension "sequenceAxis", by the amount specified in - * "seq_lengths", the second input. - * - * e.g.: - * - * data [shape = (2,4)]: - * [0 1 2 3] - * [4 5 6 7] - * seq_lengths [shape = (2,)]: - * [3, 0] - * batchAxis = 0 - * sequenceAxis = 1 - * - * output [shape = (2,4)]: - * [2 1 0 3] - * [4 5 6 7] - * - * - * data [shape = (2,3,2)]: - * [0 1] - * [2 3] - * [4 5] (slice = 0) - * [6 7] - * [8 9] - * [10 11] (slice = 1) - * seq_lengths [shape = (2,)]: - * [2, 3] - * batchAxis = 0 - * sequenceAxis = 1 - * - * output [shape = (2,3,2)]: - * [2 3] - * [0 1] - * [4 5] (slice = 0) - * [10 11] - * [8 9] - * [6 7] (slice = 1) - * - * Output shape is same as the input. - */ -message ReverseSeqLayerParams { - - int64 batchAxis = 1; // batch axis has to be strictly less than seq_axis - int64 sequenceAxis = 2; - -} - -/** - * A layer that loads data as a parameter and provides it as an output. - * - * .. code:: - * - * y = LoadConstantNDLayer() - * - * Requires no input and produces 1 output. - * - * Output: A tensor with shape as provided in the parameter "shape" - */ -message LoadConstantNDLayerParams { - - /** - * The shape of the constant to be loaded. - */ - repeated uint64 shape = 1; - WeightParams data = 2; - -} - -/** - * A layer that generates an output tensor with a constant value. - * Input is only used to determine the shape of the output. - * This layer is used to allocate a tensor with a dynamic shape (that of the input) and constant value. - * - * Requires 1 input and produces 1 output. - * - * .. code:: - * - * y = FillLikeLayer(x) - * - * Input - * A N-Dimensional tensor, whose values are ignored. Only the shape is used to - * infer the shape of the output. - * - * Output - * A N-Dimensional tensor with the same shape as the input tensor. - * - */ -message FillLikeLayerParams { - - float value = 1; - -} - -/** - * A layer that generates an output tensor with a constant value. - * This layer is used to allocate a tensor with a static shape and constant value. - * - * Requires no input and produces 1 output. - * - * .. code:: - * - * y = FillStaticLayer(x) - * - * Output - * A N-Dimensional tensor of shape "targetShape". - * - */ -message FillStaticLayerParams { - - float value = 1; - repeated uint64 targetShape = 2; - -} - -/** - * A layer that generates an output tensor with a constant value. 
- * This layer is used to allocate a tensor with a dynamic shape (as specified by the input) and constant value. - * - * Requires 1 input and produces 1 output. - * - * .. code:: - * - * y = FillDynamicLayer(x) - * - * Input - * A rank 1 tensor specifying the shape of the output - * - * Output - * An N-Dimensional tensor with the shape specified by the values in the input tensor. - * - */ -message FillDynamicLayerParams { - - float value = 1; - -} - -/** - * A layer that returns the elements either from tensor x or tensor y, - * depending on the value in the condition tensor. - * It is similar in functionality to the numpy.where method with 3 inputs. - * - * Requires 3 inputs and produces 1 output. - * Inputs, in order, are the condition tensor, x and y. - * - * for each vector index (i,...,j): - * output[i,...,j] = x[i,...,j] if condition[i,...,j] = True - * y[i,...,j] if condition[i,...,j] = False - * - * All the 3 inputs are first broadcasted to a common shape. - * (the shapes must be broadcastable) - * - * output.rank = max(input[0].rank, input[1].rank, input[2].rank) - * - */ -message WhereBroadcastableLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric sine function. - * - * - * .. code:: - * - * y = SinLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message SinLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric cosine function. - * - * - * .. code:: - * - * y = CosLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message CosLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric tangent function. - * - * - * .. code:: - * - * y = TanLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message TanLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric arcsine function. - * - * - * .. code:: - * - * y = AsinLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AsinLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric arccosine function. - * - * - * .. code:: - * - * y = AcosLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AcosLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric arctangent function. - * - * - * .. code:: - * - * y = AtanLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AtanLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic sine function. - * - * - * .. code:: - * - * y = SinhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message SinhLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic cosine function. - * - * - * .. code:: - * - * y = CoshLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message CoshLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic tangent function. - * - * - * .. code:: - * - * y = TanhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message TanhLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic arcsine function. - * - * - * .. 
code:: - * - * y = AsinhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AsinhLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic arccosine function. - * - * - * .. code:: - * - * y = AcoshLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AcoshLayerParams { - -} - -/** - * A layer that computes elementwise trigonometric hyperbolic arctangent function. - * - * - * .. code:: - * - * y = AtanhLayer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message AtanhLayerParams { - -} -/** - * A layer that raises each element in first tensor to the power of - * corresponding element in the second tensor. - * Supports conventional numpy-like broadcasting. - * - * .. code:: - * - * y = PowBroadcastableLayer(x) - * - * Requires 2 inputs and produces 1 output. - * - * Input - * - First N-Dimensional tensor - * - Second N-Dimensional tensor - * - * Output - * An N-Dimensional tensor with the broadcast shape. - * - */ -message PowBroadcastableLayerParams { - -} - -/** - * A layer that computes the exponential of all elements in the input tensor, with the base 2. - * - * - * .. code:: - * - * y = Exp2Layer(x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message Exp2LayerParams { - -} - -/** - * A layer that returns a tensor containing the indices of all non-zero - * elements of input tensor. - * It is similar in functionality to the numpy.where method with 1 input. - * - * Requires 1 input and produces 1 output. - * Output is of rank 2, of shape (N,R), - * where N is the number of non-zero elements in the input and R is the rank of the input. - * - * Output contains indices represented in the multi-index form - * - * e.g.: - * input {shape = (4,)}: - * [0 1 0 2] - * output {shape = (2,1)}: - * [1] - * [3] - * - * - * input {shape = (3, 3)}: - * [1 2 1] - * [0 2 2] - * [2 1 0] - * output {shape = (7,1)}: - * [0. 0.] - * [0. 1.] - * [0. 2.] - * [1. 1.] - * [1. 2.] - * [2. 0.] - * [2. 1.] - * - */ -message WhereNonZeroLayerParams { - -} - -/** - * A layer that copies a tensor setting everything outside a central band in - * each inner-most matrix to zero. - * - * Requires 1 input and produces 1 output. - * - * Parameters for matrix_band_part layer - * band(m, n) = (num_lower < 0 || (m-n) <= num_lower) && (num_upper < 0 || (n-m) <= num_upper). - * output[i, j, k, ..., m, n] = band(m, n) * input[i, j, k, ..., m, n] - * - * - * Output shape is same as the input shape. - * Rank of the input must be at least 2. - * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. - */ -message MatrixBandPartLayerParams { - - int64 numLower = 1; - int64 numUpper = 2; - -} - -/** - * A layer that copies a tensor setting everything outside upper triangular to zero. - * - * Requires 1 input and produces 1 output. - * - * Output shape is same as the input shape. - * Rank of the input must be at least 2. - * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. - */ -message UpperTriangularLayerParams { - - int64 k = 1; // Diagonal below which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above - -} - -/** - * A layer that copies a tensor setting everything outside lower triangular to zero. 
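As an illustrative numpy sketch of the band(m, n) formula above (the `matrix_band_part` helper name is ours), with the upper- and lower-triangular layers falling out as special cases of the same mask:

```python
import numpy as np

def matrix_band_part(x, num_lower, num_upper):
    # Keep elements within `num_lower` sub-diagonals and `num_upper`
    # super-diagonals of each inner-most matrix; a negative bound keeps the
    # whole corresponding triangle, per band(m, n) described above.
    x = np.asarray(x)
    rows, cols = x.shape[-2], x.shape[-1]
    m = np.arange(rows)[:, None]
    n = np.arange(cols)[None, :]
    keep = ((num_lower < 0) | ((m - n) <= num_lower)) & \
           ((num_upper < 0) | ((n - m) <= num_upper))
    return x * keep

a = np.arange(16).reshape(4, 4)
print(matrix_band_part(a, 0, -1))  # upper triangular (k = 0)
print(matrix_band_part(a, -1, 0))  # lower triangular (k = 0)
```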
- * - * Requires 1 input and produces 1 output. - * - * Output shape is same as the input shape. - * Rank of the input must be at least 2. - * For rank higher than 2, the last 2 dimensions are treated as the matrix, while the rest are treated as batch. - */ -message LowerTriangularLayerParams { - - int64 k = 1; // Diagonal above which to zero elements. k = 0 (the default) is the main diagonal, k < 0 is below it and k > 0 is above - -} - -/** - * - * A layer that broadcasts a tensor to a new shape. - * - * Requires 2 inputs and produces 1 output. - * - * First input is broadcast to produce the output, while the second input is only - * used to determine the shape of the output. Values of second input are not used. - * - * Output is a tensor with the same shape as the second input. - * - */ -message BroadcastToLikeLayerParams { - -} - -/** - * - * A layer that broadcasts a tensor to a new shape. - * - * Requires 1 input and produces 1 output. - * - * Output tensor is the broadcasted version of the input and has shape as specified in the - * parameter "targetShape". - */ -message BroadcastToStaticLayerParams { - - repeated uint64 targetShape = 1; - -} - -/** - * - * A layer that broadcasts a tensor to a new shape. - * - * Requires 2 inputs and produces 1 output. - * - * First input is the one that is broadcasted to produce the output. - * Second input is a rank 1 tensor specifying the shape of the output. - * Output tensor has shape as specified by the values in the 2nd input tensor. - */ -message BroadcastToDynamicLayerParams { - -} - -/** - * A layer that performs element-wise addition operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message AddBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise maximum operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message MaxBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise minimum operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message MinBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise modular operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message ModBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise floor division operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message FloorDivBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise subtract operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message SubtractBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise multiply operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message MultiplyBroadcastableLayerParams { - -} - -/** - * A layer that performs element-wise division operation with broadcast support. - * - * Requires 2 inputs and produces 1 output. - */ -message DivideBroadcastableLayerParams { - -} - -/** - * Gather layer that gathers elements from the first input, along a specified axis, - * at indices specified in the second input. - * It is similar in functionality to the numpy.take method. - * - * Requires 2 inputs and produces 1 output. - * - * Given two inputs, 'data' and 'indices', gather the slices of 'data' - * and store into output. - * e.g. 
- * for i in [0, length(indices) - 1] - * output[i] = data[indices[i]] (1-D case, axis=0) - * - * if axis = 0: - * for each vector index (i,...,j) - * output[i,...,j,:,..,:] = data[indices[i,...,j],:,..,:] - * - * output.rank = (data.rank - 1) + indices.rank - * - * Negative indices and negative axis are supported. - * - * e.g: - * - * data shape = (2, 3) - * indices shape = (6, 8) - * axis = 0 - * output shape = (6, 8) + (3,) = (6, 8, 3) - * - * data shape = (2, 3, 5) - * indices shape = (6, 8) - * axis = 1 - * output shape = (2,) + (6, 8) + (5,) = (2, 6, 8, 5) - * - */ -message GatherLayerParams { - - int64 axis = 1; - -} - -/* - * Scatter accumulation mode. - */ -enum ScatterMode { - - SCATTER_UPDATE = 0; - SCATTER_ADD = 1; /// add - SCATTER_SUB = 2; /// subtract - SCATTER_MUL = 3; /// multiply - SCATTER_DIV = 4; /// divide - SCATTER_MAX = 5; /// maximum - SCATTER_MIN = 6; /// minimum - -} - -/* - * A layer that scatters data into a new tensor according to indices from the input. - * This is the inverse operation of Gather. - * - * Requires 3 inputs and produces 1 output. - * - * Output is initialized with the first input. - * Then updated with the values in the third input, at indices specified by the second input. - * - * An example when axis=0: - * Given three inputs, in order, "container", "indices", "updates", where - * - * - "container" is a rank R+1 tensor of shape [D_0, D_1, ..., D_R], which - * contains D_0 number of tensors, each with shape [D_1, ..., D_R]. - * - * - "indices" is a rank 1 tensor with shape [N], where N is the number of updates. - * The values in this tensor must be in the range [0, D_0 - 1]. (negative indexing is supported) - * - * - "updates" is a rank R+1 tensor with shape [N, D_1, ..., D_R], which represents - * a total number of N tensors, each of shape [D_1, ..., D_R]. - * - * The effect of this operation is as follows: - * - * output = container; - * For each i in 0, ..., N - 1 - * output[indices[i], :, ..., :] = updates[i, :, ..., :] // if mode == "SCATTER_UPDATE" - * - * or - * For each i in 0, ..., N - 1 - * output[indices[i], :, ..., :] += updates[i, :, ..., :] // if mode == "SCATTER_ADD" - * - * etc - * - * When "indices" is a tensor of rank greater than 1, the equation becomes (for axis=0): - * For each vector index (i,...,j) - * output[indices[i,...,j],...] -= updates[i,...,j,...] // if mode == "SCATTER_SUB" - * - * - * The output has the same shape as the first input. - * "indices" input must have rank less than or equal to the "updates" input and its shape - * must be a subset of the the shape of the "updates" input. - * - * e.g: - * - * container shape = (4, 3) - * indices shape = (5, 2, 3) - * updates shape = (4, 5, 2, 3) - * axis = 1 - * output shape = (4, 3) - * - * container shape = (4, 4, 3) - * indices shape = (6,) - * updates shape = (4, 6, 3) - * axis = -2 - * output shape = (4, 4, 3) - * - * container shape = (5,) - * indices shape = (5, 7, 5, 6) - * updates shape = (5, 7, 5, 6) - * axis = -1 - * output shape = (5,) - */ - -message ScatterLayerParams { - - int64 axis = 1; - ScatterMode mode = 2; /// mode of accumulation. - -} - -/** - * A layer that gathers elements from the first input, 'params', at the multi-indices specified - * by the second input, 'indices'. - * - * Requires 2 inputs and produces 1 output. - * - * 'params' = input[0], 'indices' = input[1] - * - * 'indices' is a rank K+1 tensor of shape [I_0, I_1, .., I_(K-1), I_K] which is viewed as a collection of - * indices of (I_0 * I_1 * ... 
* I_(K-1)) points in the I_K dimensional space. For instance, the multi-index of the first point - * is indices[0,0,...,0,:]. - * - * Here is how the output is constructed: - * - * for i = 0,1,...,(I_0-1) - * ... - * for j = 0,1,....,(I_(K-1)-1) - * output[i,....,j,:,:,..,:] = params[indices[i,...,j,:], :,:,..,:] - * - * Hence, output shape is [I_0, I_1,...,I(K-1)] + params.shape[I_K:] - * - * output.rank = indices.rank - 1 + params.rank - indices.shape[-1] - * - * e.g: - * - * input[0] shape = (4, 2, 3, 4) - * input[1] shape = (6, 2) - * output shape = (6,) + (3, 4) = (6, 3, 4) - * - * input[0] shape = (3, 3, 3, 4, 7) - * input[1] shape = (3, 5) - * output shape = (3,) + () = (3,) - * - * input[0] shape = (5, 3, 2, 5) - * input[1] shape = (2, 7, 3, 2) - * output shape = (2, 7, 3) + (2, 5) = (2, 7, 3, 2, 5) - * - */ -message GatherNDLayerParams { - -} - -/* - * A layer that scatters data into a new tensor according to multi-indices from the input. - * This is the inverse operation of GatherND. - * - * Requires 3 inputs and produces 1 output. - * 3 inputs, in order are denoted as "container", "indices", "updates". - * - * 'indices' is a rank K+1 tensor of shape [I_0, I_1, .., I_(K-1), I_K] which is viewed as a collection of - * indices of (I_0 * I_1 * ... * I_(K-1)) points in the I_K dimensional space. For instance, the multi-index of the first point - * is indices[0,0,...,0,:]. - * - * container.rank >= I_K - * updates.rank = K + (container.rank - I_K) - * shape of 'updates' = [I_0, I_1,...,I(K-1)] + container.shape[I_K:] - * - * output = container - * For each vector index (i,...,j) s.t. 0<=i shape: (3,) - * reps = N/A [Ignored] - * output shape = (2, 8, 12) - * - */ -message TileLayerParams { - - repeated uint64 reps = 1; - -} - -/** - * A layer that returns the shape of an input tensor. - * - * Requires 1 input and produces 1 output. - * - * Input: a tensor. - * Output: a vector of length R, where R is the rank of the input tensor - * Output is always a rank 1 tensor. - */ -message GetShapeLayerParams { - -} - -/** - * A layer that computes the Gauss error function, - * which is defined as: - * - * .. math:: - * f(x) = \dfrac{1}{\sqrt{\pi}}\int_{-x}^{x}{e^{-t^2}dt} - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - */ -message ErfLayerParams { - -} - -/** - * A layer that evaluates the Gaussian Error Linear Unit (GELU) activation. - * Following equations are used to compute the activation based on the value of the "mode" parameter: - * - * mode == 'EXACT': - * .. math:: - * f(x) = 0.5x\left ( 1+\rm{erf}\left ( \frac{x}{\sqrt{2}} \right ) \right ) - * - * mode == 'TANH_APPROXIMATION': - * .. math:: - * f(x) = 0.5x\left ( 1+\rm{tanh}\left ( \sqrt{2/\pi}\left ( x + 0.044715x^3 \right ) \right ) \right ) - * - * mode == 'SIGMOID_APPROXIMATION': - * .. math:: - * f(x) = x*\rm{sigmoid}(1.702x) - * - * Requires 1 input and produces 1 output. - * Output shape is same as the input. - * - */ -message GeluLayerParams { - - enum GeluMode { - - EXACT = 0; - TANH_APPROXIMATION = 1; - SIGMOID_APPROXIMATION = 2; - - } - - GeluMode mode = 1; /// mode of GELU operation. - -} - -/** - * RangeStatic layer that returns a tensor that contains evenly spaced values. - * It is similar in functionality to the numpy.arange method. - * - * Requires no input and produces 1 output. - * Output is a rank 1 tensor. 
- */ -message RangeStaticLayerParams { - - float endValue = 1; - float startValue = 2; - float stepSizeValue = 3; - -} - -/** - * A layer that returns a tensor that contains evenly spaced values. - * Its functionality is similar to the numpy.arange method. - * - * Requires at least 1 input, up to a maximum of 3 inputs. - * Produces 1 output, which is a rank 1 tensor. - * - * Each input must be a scalar, or rank 1 and shape (1,). - * - * The first input represents the "endValue". - * The second input, if present, corresponds to "startValue". In this case the value of the "startValue" parameter is ignored. - * The third input, if present, corresponds to "stepSizeValue". In this case the value of the "stepSizeValue" parameter is ignored. - * - */ -message RangeDynamicLayerParams { - - float startValue = 2; - float stepSizeValue = 3; - -} - -/** - * A layer that returns a tensor containing all windows of size ``windowSize`` - * separated by ``step`` along the dimension ``axis``. - * - * .. code:: - * - * y = SlidingWindows(x) - * - * Requires 1 input and produces 1 output. - * - * Input - * An N-Dimensional tensor. - * - * Output - * An (N+1)-Dimensional tensor. - * - * This operation behaves as following: - * - if axis = 0 & input is rank 1 (L,). Output shape will be (M, W). - * - if axis = 1 & input is rank 3 (B1, L, C1). Output shape will be (B1, M, W, C1) - * - if axis = 2 & input is rank 5 (B1, B2, L, C1, C2) --> (B1 * B2, L, C1 * C2) --> (B1 * B2, M, W, C1 * C2). Output shape will be (B1, B2, M, W, C1, C2) - * - etc. - * where - * - L, C, B refer to input length, feature dimension length & batch size respectively - * - W is the window size. - * - M is the number of windows/slices calculated as M = (L - W) / step + 1 - */ -message SlidingWindowsLayerParams { - - int64 axis = 1; - uint64 windowSize = 2; - uint64 step = 3; - -} - -/** - * A layer that applies layer normalization over the input tensor. - * - * Requires 1 input and produces 1 output. - * - * output = gamma * (input - computed_mean) / (sqrt(computed_variance + eps)) + beta - * - * Parameters - * normalizedShape: subset of the input shape, along with layer norm is performed, rest of the input shape is treated as the batch dimension. The mean and variance are computed for the input, over the last few dimensions as specified by the normalizedShape parameter. - * gamma: must have shape = "normalizedShape" - * beta: must have shape = "normalizedShape" - * eps: small constant to avoid division by 0 - * - * Output shape is same as the input. - * - * e.g.: - * input shape = (10,5) - * normalized shape = (5,) or (10,5) - * - * input shape = (10,5,6,7) - * normalized shape = (7,) or (6,7) or (5,6,7) or (10,5,6,7) - */ -message LayerNormalizationLayerParams { - - repeated int64 normalizedShape = 1; - float eps = 2; - WeightParams gamma = 3; - WeightParams beta = 4; - -} - -/** - * Non maximum suppression (NMS) layer. - * Applies the non maximum suppression algorithm to input bounding box coordinates. - * The effect of this layer is similar to the functionality of the "NonMaximumSuppression" - * model type (for details please see NonMaximumSuppression.proto) with a couple of differences. - * One, this is a layer in a neural network model, whereas that is a different model type. Second, - * this layer supports a batch of bounding boxes. - * - * The NMS layer requires at least 2 inputs, and up to a maximum of 5 inputs. It produces 4 outputs. 
- * Following is the description of inputs and outputs: - * - * input 1, shape (B,N,4): coordinates of N boxes, for a batch size B. - * input 2, shape (B,N,C): class scores for each box. C can be 1 when there is only 1 score per box, i.e., no class specific score. - * - * input 3, optional, shape (1,): IoU threshold. When present, it overwrites the value provided in layer parameter "iouThreshold". - * input 4, optional, shape (1,): Score threshold. When present, it overwrites the value provided in layer parameter "scoreThreshold". - * input 5, optional, shape (1,): Maximum number of boxes. When present, it overwrites the value provided in layer parameter "maxBoxes". - * - * output 1, shape (B,maxBoxes,4): box coordinates, corresponding to the surviving boxes. - * output 2, shape (B,maxBoxes,C): box scores, corresponding to the surviving boxes. - * output 3, shape (B,maxBoxes): indices of the surviving boxes. Hence it will have values in the range [0,N-1], except for padding. - * output 4, shape (B,): number of boxes selected after the NMS algorithm, for each batch. - * - * When surviving boxes are less than "maxBoxes", the first 3 outputs are padded. - * For the first two outputs, the padding is done using values 0, whereas for the third output the - * padding value used is -1, since the output values represent indices. - * - * If no box survives, that is, all the scores are below the "scoreThreshold", - * then for that batch, number of boxes (value of the fourth output) will be 1. The first 3 outputs will - * correspond to the box with the highest score. This is to avoid generating an "empty" output. - * - * The four values that describe the box dimensions are (in order): - * - * - x (center location of the box along the horizontal axis) - * - y (center location of the box along the vertical axis) - * - width (size of box along the horizontal axis) - * - height (size of box on along the vertical axis) - * - * In each batch, - * the N scores for N boxes, used for suppression, are generated by taking the max of the matrix (N,C) - * along the columns. - * If "perClassSuppression" flag is false, suppression happens across all classes. - * If "perClassSuppression" flag is true, each box is assigned to the class with the highest - * score and then the suppression happens separately for boxes within the same class. - * - * Note that the 4th output can be used to dynamically slice the first 3 outputs, in case - * the padded outputs are not required. - * - */ -message NonMaximumSuppressionLayerParams { - /** - * The intersection over union (IoU) threshold over which boxes are suppressed. - */ - float iouThreshold = 1; - - /** - * Before IoU suppression is performed, boxes with class scores below this threshold are rejected. - */ - float scoreThreshold = 2; - - /** - * The maximum number of boxes to be given out as output. - * If the number of surviving boxes are less, output is padded up to this number. - */ - uint64 maxBoxes = 3; - - /** - * If true, suppression is performed independently within boxes of each class. - */ - bool perClassSuppression = 4; -} - -/** - * A layer that performs element-wise clamped ReLU operation. - * - * Requires 1 input and produces 1 output. - * - * This function has the following formula: - * - * .. math:: - * f(x) = \begin{cases} - * \text{min}(\text{beta},x) \;\; \text{if} \;\; x \geq 0\\ - * \text{min}(\text{beta} ,\text{alpha}\cdot x) \;\; \text{if} \;\; x<0 - * \end{cases} - * - * Output shape is same as the input. 
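A small numpy sketch of the clamped ReLU formula above (illustrative only; the helper name and the sample alpha/beta values are ours):

```python
import numpy as np

def clamped_relu(x, alpha, beta):
    # f(x) = min(beta, x) for x >= 0 and min(beta, alpha * x) for x < 0,
    # i.e. a leaky ReLU whose output is capped at beta.
    x = np.asarray(x, dtype=np.float64)
    return np.minimum(beta, np.where(x >= 0, x, alpha * x))

print(clamped_relu([-2.0, -0.5, 0.0, 3.0, 10.0], alpha=0.1, beta=6.0))
# -> [-0.2, -0.05, 0.0, 3.0, 6.0]
```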
- * - * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) - */ -message ClampedReLULayerParams { - - float alpha = 1; - float beta = 2; - -} - -/** -* A layer that returns the indices that would sort the input tensor, along a specified axis. -* -* Requires 1 input and produces 1 output. -* -* Output has the same rank and shape as the input. -* -* Value of "axis" must be positive and less than the rank of the input. -* -* e.g.: -* -* input shape = (5,) -* axis = 0 -* input values = [3.1, 5.4, 32.9, 3.2, 77.0] -* output shape = (5,) -* output values = [0, 3, 1, 2, 4], descending = False -* output values = [4, 2, 1, 3, 0], descending = True -* -* input shape = (2,3) -* axis = 1 -* input values = [[3, 5, 32], [3, 77, 6]] -* output shape = (2,3) -* output values = [[0, 1, 2], [0, 2, 1]], descending = False -* output values = [[2, 1, 0], [1, 2, 0]], descending = True -* -*/ -message ArgSortLayerParams { - - int64 axis = 1; /// must be between [0, input_rank - 1] - bool descending = 2; - -} - -/** - * A layer that does slice operation by providing size to be extracted - * from the given input tensor. - * - * Requires 2 inputs and produces 1 output. - * Rank of the output is same as the rank of the first input. - * - * The 1st input represents the tensor to be sliced. - * The 2nd input represents the beginning index to be sliced from. - * - * Example: - * Input 1: x (x.shape = (2, 3, 4)) - * Input 2: begin - * size: 2 - * axis: 1 - * - * Output: x[:, begin:begin+2, :] - * - */ -message SliceBySizeLayerParams { - - int64 size = 2; - int64 axis = 3; - -} - - -/// Neural Network Specializations -/// ------------------------------ - -/** - * A neural network specialized as a classifier. - */ -message NeuralNetworkClassifier { - - repeated NeuralNetworkLayer layers = 1; - repeated NeuralNetworkPreprocessing preprocessing = 2; - - // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs - NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; - - // use this enum value to determine the input tensor shapes to the neural network, for image inputs - NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; - - NetworkUpdateParameters updateParams = 10; - - // The set of labels for every possible class. - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } - - // The name of the output blob containing the probability of each class. - // In other words, the score vector. Must be a 1-D tensor with the same - // number and order of elements as ClassLabels. - string labelProbabilityLayerName = 200; -} - - -/** - * A layer that computes the one hot representation of the input. - * - * Requires 1 or 2 inputs and produces 1 output. - * Rank of the output is one more than the first input. - * If the second input is present, it is used to determine the value of "oneHotVectorSize" and the parameter "oneHotVectorSize" is ignored. - * - * Input values correspond to indices and should typically be in the range [0,"oneHotVectorSize" -1]. If it is outside this range, a vector of all "offValue" will be chosen. - * - * Typically one hot vectors contain 0s everywhere, except 1 at the index that the input corresponds to. - * However, instead of 0, any float value could be generated by using the "offValue" parameter. - * Similarly, instead of 1, any other value can be used by employing the "onValue" parameter. 
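A minimal 1-D numpy sketch of this one-hot behavior (illustrative only; the helper name is ours and the output axis handling is omitted). The sample values reproduce the onValue/offValue example given just below:

```python
import numpy as np

def one_hot_1d(indices, one_hot_vector_size, on_value=1.0, off_value=0.0):
    # Out-of-range indices yield a row of all off_value, per the text above.
    indices = np.asarray(indices, dtype=np.int64)
    out = np.full((len(indices), one_hot_vector_size), off_value, dtype=np.float64)
    valid = (indices >= 0) & (indices < one_hot_vector_size)
    out[np.arange(len(indices))[valid], indices[valid]] = on_value
    return out

print(one_hot_1d([2, 0], 4, on_value=5, off_value=-1))
# [[-1. -1.  5. -1.]
#  [ 5. -1. -1. -1.]]
```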
- * - * e.g.: - * input shape: (10,), "oneHotVectorSize" : 32, axis=-1, then output shape will be (10,32) - * input shape: (10,23), "oneHotVectorSize" : 32, axis=1, then output shape will be (10,32,23) - * input shape: (10,), "oneHotVectorSize" : 32, axis=0, then output shape will be (32,10) - * - * input shape: (2,), "oneHotVectorSize" : 4, axis=-1, then output shape will be (2,4) - * say input values = [2, 0], and "onValue" = 5, and "offValue" = -1, then output will be: - * [-1, -1, 5, -1 - * 5, -1, -1, -1] - * - * say input values = [2, -1], and "onValue" = 5, and "offValue" = -1, then output will be: - * [-1, -1, 5, -1 - * -1, -1, -1, -1] - * - * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) - */ - -message OneHotLayerParams { - - uint64 oneHotVectorSize = 1; /// size of the one hot vector - int64 axis = 2; /// negative indexing is supported. It refers to the axis in the output tensor. - float onValue = 3; - float offValue = 4; -} - - -/** - * A layer that computes the cumsum values of the input along a given axis. - * - * Requires 1 or 2 inputs and produces 1 output. - * - * Output shape and rank is same as the first input. - * If the second input is present, it is used to determine the value of "axis" and the parameter "axis" is ignored. - * - * e.g.: - * Input shape = (3,), values it has: [4, 6, 7] - * - * Then output values will be: - * - * if "excludeFinalSum" = False and "reverse" = False: - * output values : [4, 10, 17] - * - * if "excludeFinalSum" = True and "reverse" = False: - * output values : [0, 4, 10] - * - * if "excludeFinalSum" = False and "reverse" = True: - * output values : [17, 13, 7] - * - * if "excludeFinalSum" = True and "reverse" = True: - * output values : [13, 7, 0] - * - * - * Available (iOS >= 14, macOS >= 11.0, watchOS >= 7) - */ - - -message CumSumLayerParams { - - int64 axis = 1; /// negative indexing is supported - - /// if true, the first element of the output is 0, and the last element contains the sum of the input up to the penultimate value - /// if false, the first element of the output is same as the input and the last element is the sum of all the input values - /// (this behavior is reversed when "reverse" flag is True) - bool excludeFinalSum = 2; - - bool reverse = 3; /// if true, cumsum is performed in the opposite direction -} - - -/** - * A neural network specialized as a regressor. - */ -message NeuralNetworkRegressor { - - repeated NeuralNetworkLayer layers = 1; - repeated NeuralNetworkPreprocessing preprocessing = 2; - - // use this enum value to determine the input tensor shapes to the neural network, for multiarray inputs - NeuralNetworkMultiArrayShapeMapping arrayInputShapeMapping = 5; - - // use this enum value to determine the input tensor shapes to the neural network, for image inputs - NeuralNetworkImageShapeMapping imageInputShapeMapping = 6; - - NetworkUpdateParameters updateParams = 10; - -} - -/// --------------------------------------------------------- -/// On-device Training related messages -/// --------------------------------------------------------- - -/** - * Details on how the network will be updated - */ -message NetworkUpdateParameters { - - repeated LossLayer lossLayers = 1; - Optimizer optimizer = 2; - Int64Parameter epochs = 3; - - /** - * Describes whether to shuffle the batch of data between epochs. - */ - BoolParameter shuffle = 10; - - /** - * The seed to be used in an associated random number generator. 
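Returning to the CumSumLayerParams semantics described earlier in this hunk, the four excludeFinalSum/reverse combinations can be reproduced with a short 1-D numpy sketch (illustrative only; the helper name is ours):

```python
import numpy as np

def cumsum_1d(x, exclude_final_sum=False, reverse=False):
    # 1-D sketch of the CumSum layer: optionally reversed, optionally
    # exclusive (drop the total sum and prepend a zero).
    x = np.asarray(x, dtype=np.float64)
    if reverse:
        x = x[::-1]
    y = np.cumsum(x)
    if exclude_final_sum:
        y = np.concatenate(([0.0], y[:-1]))
    return y[::-1] if reverse else y

x = [4, 6, 7]
print(cumsum_1d(x))                                        # [ 4. 10. 17.]
print(cumsum_1d(x, exclude_final_sum=True))                # [ 0.  4. 10.]
print(cumsum_1d(x, reverse=True))                          # [17. 13.  7.]
print(cumsum_1d(x, exclude_final_sum=True, reverse=True))  # [13.  7.  0.]
```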
- */ - Int64Parameter seed = 20; -} - -/** - * Loss layer - categorical cross entropy and mean squared error are the only supported loss functions currently - */ -message LossLayer { - - string name = 1; - oneof LossLayerType { - - CategoricalCrossEntropyLossLayer categoricalCrossEntropyLossLayer = 10; - MeanSquaredErrorLossLayer meanSquaredErrorLossLayer = 11; - - } - -} - -/** - * Categorical cross entropy loss layer - * Categorical cross entropy is used for single label categorization (only one category is applicable for each data point). - * - * The input is a vector of length N representing the distribution over N categories. It must be the output of a softmax. - * - * The target is a single value representing the true category or class label. If the target is the predictedFeatureName of a neural network classifier it will be inverse mapped to the corresponding categorical index for you. - * - * math: - * Loss_{CCE}(input, target) = -\sum_{i=1}^{N} (target == i) log( input[i] ) = - log (input[target]) - */ -message CategoricalCrossEntropyLossLayer { - - string input = 1; - string target = 2; - -} - -/** - * Mean squared error loss layer, - * specifying input and target - */ -message MeanSquaredErrorLossLayer { - - string input = 1; - string target = 2; - -} - -/** - * Optimizer - stochastic gradient descent and adam are the only supported optimizers currently - */ -message Optimizer { - - oneof OptimizerType { - - SGDOptimizer sgdOptimizer = 10; - AdamOptimizer adamOptimizer = 11; - - } - -} - -/** - * Stochastic gradient descent optimizer, - * specifying configurable learning rate, mini batch size, and momentum - */ -message SGDOptimizer { - - DoubleParameter learningRate = 1; - Int64Parameter miniBatchSize = 2; - DoubleParameter momentum = 3; - -} - -/** - * Adam optimizer, - * specifying configurable learning rate, mini batch size, betas, and eps - */ -message AdamOptimizer { - - DoubleParameter learningRate = 1; - Int64Parameter miniBatchSize = 2; - DoubleParameter beta1 = 3; - DoubleParameter beta2 = 4; - DoubleParameter eps = 5; - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto b/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto deleted file mode 100644 index c98949a0c2e21..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/NonMaximumSuppression.proto +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/* -* Non-maximum suppression of axis-aligned bounding boxes. -* -* This is used primarily for object detectors that tend to produce multiple -* boxes around a single object. This is a byproduct of the detector's -* robustness to spatial translation. If there are two or more bounding boxes -* that are very similar to one another, the algorithm should return only a -* single representative. -* -* Similarity between two bounding boxes is measured by intersection-over-union -* (IOU), the fraction between the area of intersection and area of the union. 
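As an illustrative cross-check of the IOU measure (not part of the spec; the helper name is ours and it assumes the (x, y, width, height) center/size box encoding described later in this file):

```python
def iou(box_a, box_b):
    # Boxes are (x_center, y_center, width, height), axis-aligned.
    # IOU = intersection area / union area, in [0.0, 1.0].
    ax0, ay0 = box_a[0] - box_a[2] / 2.0, box_a[1] - box_a[3] / 2.0
    ax1, ay1 = box_a[0] + box_a[2] / 2.0, box_a[1] + box_a[3] / 2.0
    bx0, by0 = box_b[0] - box_b[2] / 2.0, box_b[1] - box_b[3] / 2.0
    bx1, by1 = box_b[0] + box_b[2] / 2.0, box_b[1] + box_b[3] / 2.0
    iw = max(0.0, min(ax1, bx1) - max(ax0, bx0))
    ih = max(0.0, min(ay1, by1) - max(ay0, by0))
    inter = iw * ih
    union = box_a[2] * box_a[3] + box_b[2] * box_b[3] - inter
    return inter / union if union > 0.0 else 0.0

print(iou((0.5, 0.5, 1.0, 1.0), (1.0, 1.0, 1.0, 1.0)))  # 0.25 / 1.75 ~= 0.143
```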
-* Here is an example where the areas can be calculated by hand by counting glyphs:: -* -* +-------+ +-------+ -* | | | | -* | +------+ +--+ | +---+ -* | | | | | | | | -* +-------+ | +--+ +----+ | -* | | | | -* +------+ +------+ -* Intersection Union -* IOU: 0.16 = 12 / 73 -* -* All IOU scores are fractions betwen 0.0 (fully disjoint) and 1.0 (perfect -* overlap). The standard algorithm (PickTop) is defined as follows: -* -* 1. Sort boxes by descending order of confidence -* 2. Take the top one and mark it as keep -* 3. Suppress (mark it as discard) all boxes within a fixed IOU radius of the -* keep box -* 4. Go to 2 and repeat on the subset of boxes not already kept or discarded -* 5. When all boxes are processed, output only the ones marked as keep -* -* Before the algorithm, boxes that fall below the confidence threshold are -* discarded. -*/ -message NonMaximumSuppression { - // Suppression methods: - /* - * Pick the bounding box of the top confidence, suppress all within a radius. - */ - message PickTop { - /* - * Suppression is only done among predictions with the same label - * (argmax of the confidence). - */ - bool perClass = 1; - } - - /* - * Choose which underlying suppression method to use - */ - oneof SuppressionMethod { - PickTop pickTop = 1; - } - - /* - * Optional class label mapping. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } - - /* - * This defines the radius of suppression. A box is considered to be within - * the radius of another box if their IOU score is less than this value. - */ - double iouThreshold = 110; - - /* - * Remove bounding boxes below this threshold. The algorithm run-time is - * proportional to the square of the number of incoming bounding boxes - * (O(N^2)). This threshold is a way to reduce N to make the algorithm - * faster. The confidence threshold can be any non-negative value. Negative - * confidences are not allowed, since if the output shape is specified to be - * larger than boxes after suppression, the unused boxes are filled with - * zero confidence. If the prediction is handled by Core Vision, it is also - * important that confidences are defined with the following semantics: - * - * 1. Confidences should be between 0 and 1 - * 2. The sum of the confidences for a prediction should not exceed 1, but is - * allowed to be less than 1 - * 3. The sum of the confidences will be interpreted as the confidence of - * any object (e.g. if the confidences for two classes are 0.2 and 0.4, - it means there is a 60% (0.2 + 0.4) confidence that an object is - present) - */ - double confidenceThreshold = 111; - - /* - * Set the name of the confidence input. - * - * The input should be a multi-array of type double and shape N x C. N is - * the number of boxes and C the number of classes. Each row describes the - * confidences of each object category being present at that particular - * location. Confidences should be nonnegative, where 0.0 means the highest - * certainty the object is not present. - * - * Specifying shape is optional. - */ - string confidenceInputFeatureName = 200; - - /* - * Set the name of the coordinates input. - * - * The input should be a multi-array of type double and shape N x 4. The - * rows correspond to the rows of the confidence matrix. 
The four values - * describe (in order): - * - * - x (center location of the box along the horizontal axis) - * - y (center location of the box along the vertical axis) - * - width (size of box along the horizontal axis) - * - height (size of box on along the vertical axis) - * - * Specifying shape is optional. - */ - string coordinatesInputFeatureName = 201; - - /* - * The iouThreshold can be optionally overridden by specifying this string - * and providing a corresponding input of type double. This allows changing - * the value of the parameter during run-time. - * - * The input should be a scalar double between 0.0 and 1.0. Setting it to 1.0 - * means there will be no suppression based on IOU. - */ - string iouThresholdInputFeatureName = 202; - - /* - * The confidenceThreshold can be optionally overridden by specifying this - * string and providing a corresponding input. This allows changing the - * value of the parameter during run-time, which can aid setting it just - * right for a particular use case. - * - * The input should be a scalar double with nonnegative value. - */ - string confidenceThresholdInputFeatureName = 203; - - /* - * Set the name of the confidence output. The output will be the same type - * and shape as the corresponding input. The only difference is that the - * number of rows may have been reduced. - * - * Specifying shape is optional. One reason to specify shape is to limit - * the number of output boxes. This can be done is several ways: - * - * Fixed shape: - * The output can be pinned to a fixed set of boxes. If this number is larger - * than the number of boxes that would have been returned, the output is padded - * with zeros for both confidence and coordinates. Specifying a fixed shape - * can be done by setting either shape (deprecated) or allowedShapes set to - * fixedsize. - * - * Min/max: - * It is also possible to set both a minimum and a maximum. The same zero-padding - * as for fixed shape is applied when necessary. Setting min/max is done by defining - * two allowedShapes, where the first dimension uses a rangeofsizes defining lowerbound - * and upperbound. - */ - string confidenceOutputFeatureName = 210; - - /* - * Set the name of the coordinates output. The output will be the same type - * and shape as the corresponding input. The only difference is that the - * number of rows may have been reduced. - * - * Specifying shape is optional. See confidence output for a more detailed - * description. Note that to achieve either fixed shape output or a - * constraint range of boxes, only one of confidence or coordinates need to - * set a shape. Both shapes are allowed to be defined, but in such case they - * have to be consistent along dimension 0. - */ - string coordinatesOutputFeatureName = 211; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto deleted file mode 100644 index 627f7e2e3afd7..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Normalizer.proto +++ /dev/null @@ -1,38 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A normalization preprocessor. 
- */ -message Normalizer { - /** - * There are three normalization modes, - * which have the corresponding formulas: - * - * Max - * .. math:: - * max(x_i) - * - * L1 - * .. math:: - * z = ||x||_1 = \sum_{i=1}^{n} |x_i| - * - * L2 - * .. math:: - * z = ||x||_2 = \sqrt{\sum_{i=1}^{n} x_i^2} - */ - enum NormType { - LMax = 0; - L1 = 1; - L2 = 2; - } - - NormType normType = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto b/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto deleted file mode 100644 index f47cf28166222..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/OneHotEncoder.proto +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * Transforms a categorical feature into an array. The array will be all - * zeros expect a single entry of one. - * - * Each categorical value will map to an index, this mapping is given by - * either the ``stringCategories`` parameter or the ``int64Categories`` - * parameter. - */ -message OneHotEncoder { - enum HandleUnknown { - ErrorOnUnknown = 0; - IgnoreUnknown = 1; // Output will be all zeros for unknown values. - } - - /** - * Mapping to be used for the encoding. The position of the category in - * the below vector determines where the single one entry will be in the - * output. - */ - oneof CategoryType { - StringVector stringCategories = 1; - Int64Vector int64Categories = 2; - } - - // Output can be a dictionary with only one entry, instead of an array. - bool outputSparse = 10; - - HandleUnknown handleUnknown = 11; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto deleted file mode 100644 index ed1ebe525181f..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Parameters.proto +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * Int64 parameter, - * consisting of a default int64 value, and allowed range or set of values - * value is unbounded if AllowedValues is not set. - */ -message Int64Parameter { - int64 defaultValue = 1; - oneof AllowedValues { - Int64Range range = 10; - Int64Set set = 11; - } -} - -/** - * Double parameter, - * consisting of a default double value, and allowed range of values - * value is unbounded if AllowedValues is not set. 
- */ -message DoubleParameter { - double defaultValue = 1; - oneof AllowedValues { - DoubleRange range = 10; - } -} - -/** - * String parameter, - * A default string value must be provided - */ -message StringParameter { - string defaultValue = 1; -} - -/** - * String parameter, - * A default bool value must be provided - */ -message BoolParameter { - bool defaultValue = 1; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/README.md b/onnxruntime/core/providers/coreml/mlmodel_format/README.md deleted file mode 100644 index e5eba65f982ad..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/README.md +++ /dev/null @@ -1,16 +0,0 @@ -# Core ML Model Format Specification -This directory contains the protobuf message definitions that comprise the Core ML model document (``.mlmodel``) format. - -The top-level message is ``Model``, which is defined in ``Model.proto``. -Other message types describe data structures, feature types, feature engineering model types, and predictive model types. - -# Update the Core ML Model Format Specification -Please do not modify protobuf message definitions, they are copied directly from [Core ML Tools](https://github.com/apple/coremltools) repository. - -To update the Core ML Model Format Schema schema files to a more recent version: -1. Delete all the protobuf message definitions (`.proto`) from this directory. -2. Copy the new version of protobuf message definitions (`.proto`) from the `mlmodel/format/` directory of preferred coremltools release branch. - -# Core ML Model Format Schema version history -## [coremltools 4.0](https://github.com/apple/coremltools/releases/tag/4.0) -[Core ML Model Format Specification](https://github.com/apple/coremltools/tree/4.0/mlmodel/format) diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto b/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto deleted file mode 100644 index 932a4ec216682..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/SVM.proto +++ /dev/null @@ -1,195 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/// Kernel Definitions -/// ------------------ - -/** - * A linear kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \boldsymbol{x}^T \boldsymbol{x'} - */ -message LinearKernel { -} - -/** - * A Gaussian radial basis function (RBF) kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \ - * \exp(-\gamma || \boldsymbol{x} - \boldsymbol{x'} ||^2 ) - * - */ -message RBFKernel { - double gamma = 1; -} - -/** - * A polynomial kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \ - * (\gamma \boldsymbol{x}^T \boldsymbol{x'} + c)^{degree} - */ -message PolyKernel { - int32 degree = 1; - double c = 2; - double gamma = 3; -} - -/** - * A sigmoid kernel. - * - * This function has the following formula: - * - * .. math:: - * K(\boldsymbol{x}, \boldsymbol{x'}) = \ - * \tanh(\gamma \boldsymbol{x}^T \boldsymbol{x'} + c) - */ -message SigmoidKernel { - double gamma = 1; - double c = 2; -} - -/** - * A kernel. 
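For reference, the four kernel formulas defined above map to numpy as follows (illustrative only; the function names are ours):

```python
import numpy as np

def linear_kernel(x, y):
    return float(np.dot(x, y))

def rbf_kernel(x, y, gamma):
    d = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    return float(np.exp(-gamma * np.dot(d, d)))

def poly_kernel(x, y, degree, c, gamma):
    return float((gamma * np.dot(x, y) + c) ** degree)

def sigmoid_kernel(x, y, gamma, c):
    return float(np.tanh(gamma * np.dot(x, y) + c))

x, y = np.array([1.0, 2.0]), np.array([3.0, 0.5])
print(linear_kernel(x, y))
print(rbf_kernel(x, y, gamma=0.5))
print(poly_kernel(x, y, degree=3, c=1.0, gamma=0.25))
print(sigmoid_kernel(x, y, gamma=0.1, c=0.0))
```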
- */ -message Kernel { - oneof kernel { - LinearKernel linearKernel = 1; - RBFKernel rbfKernel = 2; - PolyKernel polyKernel = 3; - SigmoidKernel sigmoidKernel = 4; - } -} - - -/// Support Vector Definitions -/// -------------------------- - -/** - * A sparse node. - */ -message SparseNode { - int32 index = 1; // 1-based indexes, like libsvm - double value = 2; -} - -/** - * A sparse vector. - */ -message SparseVector { - repeated SparseNode nodes = 1; -} - -/** - * One or more sparse support vectors. - */ -message SparseSupportVectors { - repeated SparseVector vectors = 1; -} - -/** - * A dense vector. - */ -message DenseVector { - repeated double values = 1; -} - -/** - * One or more dense support vectors. - */ -message DenseSupportVectors { - repeated DenseVector vectors = 1; -} - -/** - * One or more coefficients. - */ -message Coefficients { - repeated double alpha = 1; -} - -/** - * A support vector regressor. - */ -message SupportVectorRegressor { - Kernel kernel = 1; - - // Support vectors, either sparse or dense format - oneof supportVectors { - SparseSupportVectors sparseSupportVectors = 2; - DenseSupportVectors denseSupportVectors = 3; - } - - // Coefficients, one for each support vector - Coefficients coefficients = 4; - - double rho = 5; -} - -/** - * A support vector classifier - */ -message SupportVectorClassifier { - Kernel kernel = 1; - - /** - * The number of support vectors for each class. - */ - repeated int32 numberOfSupportVectorsPerClass = 2; - - /** - * The support vectors, in either sparse or dense format. - */ - oneof supportVectors { - SparseSupportVectors sparseSupportVectors = 3; - DenseSupportVectors denseSupportVectors = 4; - } - - /** - * The coefficients, essentially a two dimensional array of - * size: (numberOfClasses-1) by (total number of support vectors) - */ - repeated Coefficients coefficients = 5; - - /** - * Constants for decision function, - * with K*(K-1) / 2 elements, - * where K is the number of classes. - */ - repeated double rho = 6; - - /** - * Pairwise probability information for A vs B classifier. - * Total of K*(K-1)/2 elements where K is the number of classes. - * These fields are optional, - * and only required if you want probabilities or multi class predictions. - */ - repeated double probA = 7; - repeated double probB = 8; - - /** - * Class label mapping. - */ - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto b/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto deleted file mode 100644 index f0e13d54be2e8..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/Scaler.proto +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification; - -/** - * A scaling operation. - * - * This function has the following formula: - * - * .. math:: - * f(x) = scaleValue \cdot (x + shiftValue) - * - * If the ``scaleValue`` is not given, the default value 1 is used. - * If the ``shiftValue`` is not given, the default value 0 is used. - * - * If ``scaleValue`` and ``shiftValue`` are each a single value - * and the input is an array, then the scale and shift are applied - * to each element of the array. 
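A minimal numpy sketch of the scaling formula above (illustrative only; the helper name is ours, and scalar or per-element scale/shift values both broadcast over the input):

```python
import numpy as np

def scale(x, shift_value=0.0, scale_value=1.0):
    # f(x) = scale_value * (x + shift_value), applied elementwise;
    # the defaults match the behavior described above (scale 1, shift 0).
    x = np.asarray(x, dtype=np.float64)
    return np.asarray(scale_value) * (x + np.asarray(shift_value))

print(scale([1.0, 2.0, 3.0], shift_value=-2.0, scale_value=0.5))  # [-0.5  0.   0.5]
```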
- * - * If the input is an integer, then it is converted to a double to - * perform the scaling operation. If the output type is an integer, - * then it is cast to an integer. If that cast is lossy, then an - * error is generated. - */ -message Scaler { - repeated double shiftValue = 1; - repeated double scaleValue = 2; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto b/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto deleted file mode 100644 index 05bb744a9af94..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/SoundAnalysisPreprocessing.proto +++ /dev/null @@ -1,60 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes audio signal samples as input and outputs an array of -* preprocessed samples according to the specified preprocessing types -*/ -message SoundAnalysisPreprocessing { - - // Specific preprocessing types for sound analysis - - /* Vggish preprocesses input audio samples and makes them ready to - be fed to Vggish feature extractor. - c.f. https://arxiv.org/pdf/1609.09430.pdf - - The preprocessing takes input a single channel (monophonic) audio samples - 975 miliseconds long, sampled at 16KHz, i.e., 15600 samples 1D multiarray - and produces preprocessed samples in multiarray of shape [1, 96, 64] - - (1) Splits the input audio samples into overlapping frames, where each - frame is 25 milliseconds long and hops forward by 10 milliseconds. - Any partial frames at the end are dropped. - - (2) Hann window: apply a periodic Hann with a window_length of - 25 milliseconds, which translates to 400 samples in 16KHz sampling rate - - w(n) = 0.5 - 0.5 * cos(2*pi*n/window_length_sample), - where 0 <= n <= window_lenth_samples - 1 and window_lenth_samples = 400 - - Then, the Hann window is applied to each frame as below - - windowed_frame(n) = frame(n) * w(n) - where 0 <= n <= window_lenth_samples - 1 and window_lenth_samples = 400 - - (3) Power spectrum: calculate short-time Fourier transfor magnitude, with - an FFT length of 512 - - (4) Log Mel filter bank: calculates a log magnitude mel-frequency - spectrogram minimum frequency of 125Hz and maximum frequency of 7500Hz, - number of mel bins is 64, log_offset is 0.01, number of spectrum bins - is 64. - */ - - message Vggish { - // no specific parameter - } - - // Vision feature print type - oneof SoundAnalysisPreprocessingType { - Vggish vggish = 20; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto b/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto deleted file mode 100644 index bf6d3c7f7f3e5..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/TextClassifier.proto +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes a single input string and outputs a -* label for the input. 
-*/ -message TextClassifier { - - /* - * Stores the revision number for the model, revision 1 is available on - * iOS, tvOS 12.0+, macOS 10.14+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Stores the byte representation of learned model parameters - */ - bytes modelParameterData = 100; - - /* - * Stores the set of output class labels - */ - oneof ClassLabels { - StringVector stringClassLabels = 200; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto b/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto deleted file mode 100644 index defebee98852c..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/TreeEnsemble.proto +++ /dev/null @@ -1,161 +0,0 @@ -// Copyright (c) 2017, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -/** - * Each tree is a collection of nodes, - * each of which is identified by a unique identifier. - * - * Each node is either a branch or a leaf node. - * A branch node evaluates a value according to a behavior; - * if true, the node identified by ``true_child_node_id`` is evaluated next, - * if false, the node identified by ``false_child_node_id`` is evaluated next. - * A leaf node adds the evaluation value to the base prediction value - * to get the final prediction. - * - * A tree must have exactly one root node, - * which has no parent node. - * A tree must not terminate on a branch node. - * All leaf nodes must be accessible - * by evaluating one or more branch nodes in sequence, - * starting from the root node. - */ - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification; - -/** - * A tree ensemble post-evaluation transform. - */ -enum TreeEnsemblePostEvaluationTransform { - NoTransform = 0; - Classification_SoftMax = 1; - Regression_Logistic = 2; - Classification_SoftMaxWithZeroClassReference = 3; -} - -/** - * Tree ensemble parameters. - */ -message TreeEnsembleParameters { - message TreeNode { - uint64 treeId = 1; - uint64 nodeId = 2; - - enum TreeNodeBehavior { - BranchOnValueLessThanEqual = 0; - BranchOnValueLessThan = 1; - BranchOnValueGreaterThanEqual = 2; - BranchOnValueGreaterThan = 3; - BranchOnValueEqual = 4; - BranchOnValueNotEqual = 5; - LeafNode = 6; - } - - /** - * The branch mode parameters. - * - * If branch is false, - * then the parameters in this section must be filled in - * to determine how the branching functions. - */ - TreeNodeBehavior nodeBehavior = 3; - - /** - * If the node behavior mode is a branch mode, - * then these values must be filled in. - */ - uint64 branchFeatureIndex = 10; - double branchFeatureValue = 11; - uint64 trueChildNodeId = 12; - uint64 falseChildNodeId = 13; - bool missingValueTracksTrueChild = 14; - - /** - * The leaf mode. - * - * If ``nodeBehavior`` == ``LeafNode``, - * then the evaluationValue is added to the base prediction value - * in order to get the final prediction. - * To support multiclass classification - * as well as regression and binary classification, - * the evaluation value is encoded here as a sparse vector, - * with evaluationIndex being the index of the base vector - * that evaluation value is added to.
- * In the single class case, - * it is expected that evaluationIndex is exactly 0. - */ - message EvaluationInfo { - uint64 evaluationIndex = 1; - double evaluationValue = 2; - } - - repeated EvaluationInfo evaluationInfo = 20; - - /** - * The relative hit rate of a node for optimization purposes. - * - * This value has no effect on the accuracy of the result; - * it allows the tree to optimize for frequent branches. - * The value is relative, - * compared to the hit rates of other branch nodes. - * - * You typically use a proportion of training samples - * that reached this node - * or some similar metric to derive this value. - */ - double relativeHitRate = 30; - } - - repeated TreeNode nodes = 1; - - /** - * The number of prediction dimensions or classes in the model. - * - * All instances of ``evaluationIndex`` in a leaf node - * must be less than this value, - * and the number of values in the ``basePredictionValue`` field - * must be equal to this value. - * - * For regression, - * this is the dimension of the prediction. - * For classification, - * this is the number of classes. - */ - uint64 numPredictionDimensions = 2; - - /** - * The base prediction value. - * - * The number of values in this must match - * the default values of the tree model. - */ - repeated double basePredictionValue = 3; -} - -/** - * A tree ensemble classifier. - */ -message TreeEnsembleClassifier { - TreeEnsembleParameters treeEnsemble = 1; - TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2; - - // Required class label mapping - oneof ClassLabels { - StringVector stringClassLabels = 100; - Int64Vector int64ClassLabels = 101; - } -} - -/** - * A tree ensemble regressor. - */ -message TreeEnsembleRegressor { - TreeEnsembleParameters treeEnsemble = 1; - TreeEnsemblePostEvaluationTransform postEvaluationTransform = 2; -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto b/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto deleted file mode 100644 index cd13d290e421e..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/VisionFeaturePrint.proto +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes an input image and outputs array(s) of features -* according to the specified feature types -*/ -message VisionFeaturePrint { - - // Specific vision feature print types - - // Scene extracts features useful for identifying contents of natural images - // in both indoor and outdoor environments - message Scene { - enum SceneVersion { - SCENE_VERSION_INVALID = 0; - // VERSION_1 is available on iOS,tvOS 12.0+, macOS 10.14+ - // It uses a 299x299 input image and yields a 2048 float feature vector - SCENE_VERSION_1 = 1; - } - - SceneVersion version = 1; - } - - // Objects extracts features useful for identifying and localizing - // objects in natural images - message Objects { - enum ObjectsVersion { - OBJECTS_VERSION_INVALID = 0; - // VERSION_1 is available on iOS,tvOS 14.0+, macOS 11.0+ - // It uses a 299x299 input image and yields two multiarray - // features: one at high resolution of shape (288, 35, 35) - // the other at low resolution of shape (768, 17, 17) - OBJECTS_VERSION_1 = 1; - } - - ObjectsVersion version = 1; - - /* - * Stores the names of the output features according to the - * order of them being computed from the neural network, i.e., - * the first element in the output is the earliest being - * computed, while the last is the latest being computed. In - * general, the order reflects the resolution of the feature. - * The earlier it is computed, the higher the feature resolution. - */ - repeated string output = 100; - } - - // Vision feature print type - oneof VisionFeaturePrintType { - Scene scene = 20; - Objects objects = 21; - } - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto b/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto deleted file mode 100644 index ec11a67ca5294..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/WordEmbedding.proto +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2019, Apple Inc. All rights reserved. -// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which maps a set of strings into a finite-dimensional real vector space. -*/ -message WordEmbedding { - - /* - * Stores the revision number for the model, revision 2 is available on - * iOS, tvOS 13.0+, macOS 10.15+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Stores efficient representation of emebedding as encoded by the Natural Language Framework - */ - bytes modelParameterData = 100; - -} diff --git a/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto b/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto deleted file mode 100644 index 8523e05df2c0b..0000000000000 --- a/onnxruntime/core/providers/coreml/mlmodel_format/WordTagger.proto +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright (c) 2018, Apple Inc. All rights reserved. 
-// -// Use of this source code is governed by a BSD-3-clause license that can be -// found in LICENSE.txt or at https://opensource.org/licenses/BSD-3-Clause - -syntax = "proto3"; -option optimize_for = LITE_RUNTIME; - -import public "DataStructures.proto"; - -package CoreML.Specification.CoreMLModels; - -/** -* A model which takes a single input string and outputs a -* sequence of tokens, tags for tokens, along with their -* locations and lengths, in the original string. -*/ -message WordTagger { - - /* - * Stores the resivion number for the model, revision 1 is available on - * iOS, tvOS 12.0+, macoOS 10.14+ - */ - uint32 revision = 1; - - /* - * Stores the language of the model, as specified in BCP-47 format, - * e.g. "en-US". See https://tools.ietf.org/html/bcp47 - */ - string language = 10; - - /* - * Stores the name of tokens output. The output will be - * a sequence of strings that contains the tokens in the - * input string - */ - string tokensOutputFeatureName = 20; - - /* - * Stores the name of token tags output. The output will be - * a sequence of strings that contains the tags for each - * token in the input string - */ - string tokenTagsOutputFeatureName = 21; - - /* - * Stores the name of token locations output. The output will be - * a sequence of integers that contains the locations (indices) - * for each token in the input string, location starts from 0 - */ - string tokenLocationsOutputFeatureName = 22; - - /* - * Stores the name of token lengths output. The output will be - * a sequence of integers that contains the lengths for each - * token in the input string - */ - string tokenLengthsOutputFeatureName = 23; - - /* - * Stores the byte representation of learned model parameters - */ - bytes modelParameterData = 100; - - /* - * Stores the set of output tags - */ - oneof Tags { - StringVector stringTags = 200; - } - - - -} - diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 55f6561b7a44a..95e34cd863915 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.132 + version: 1.0.133 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.132 + version: 1.0.133 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. 
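For context on the tree-evaluation semantics documented in the deleted TreeEnsemble.proto above: a single tree is walked from the root, each branch node compares one input feature against its threshold according to its behavior, and the reached leaf adds its sparse evaluation values to the base prediction vector. The sketch below is illustrative Python only and is not part of any file in this patch series; field names mirror the proto, nodes are assumed to be a dict keyed by nodeId for one tree, and handling of missingValueTracksTrueChild is omitted.

```python
import operator

# Map each TreeNodeBehavior (except LeafNode) to its comparison.
_BRANCH_OPS = {
    "BranchOnValueLessThanEqual": operator.le,
    "BranchOnValueLessThan": operator.lt,
    "BranchOnValueGreaterThanEqual": operator.ge,
    "BranchOnValueGreaterThan": operator.gt,
    "BranchOnValueEqual": operator.eq,
    "BranchOnValueNotEqual": operator.ne,
}


def evaluate_tree(nodes, features, root_id, base_prediction):
    """Evaluate one tree of a TreeEnsembleParameters message.

    nodes: dict mapping nodeId -> node dict with the TreeNode fields above.
    features: sequence of input feature values.
    base_prediction: list with numPredictionDimensions entries.
    """
    prediction = list(base_prediction)
    node = nodes[root_id]
    while node["nodeBehavior"] != "LeafNode":
        compare = _BRANCH_OPS[node["nodeBehavior"]]
        take_true = compare(features[node["branchFeatureIndex"]], node["branchFeatureValue"])
        child_id = node["trueChildNodeId"] if take_true else node["falseChildNodeId"]
        node = nodes[child_id]
    # Leaf node: add the sparse evaluation values onto the base prediction.
    for info in node["evaluationInfo"]:
        prediction[info["evaluationIndex"]] += info["evaluationValue"]
    return prediction
```

An ensemble accumulates the leaf contributions of all trees into the same base prediction vector before the post-evaluation transform (e.g. softmax or logistic) is applied.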
From 0cba56e0a010fe8c87a7be91bb84d36508edbf6d Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Sun, 4 Feb 2024 16:37:36 +0800 Subject: [PATCH 035/207] [ROCm] Fix CI pipeline by fixing pytest version (#19407) Fix pytest version to 7.4.4, higher version will cause error `from onnxruntime.capi import onnxruntime_validation ModuleNotFoundError: No module named 'onnxruntime.capi'` --- tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 4db9df80ed187..4767c74afd28f 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -125,7 +125,8 @@ RUN pip install \ pytorch_lightning==1.6.0 \ pytest-xdist \ pytest-rerunfailures \ - ml_dtypes==0.3.0 + ml_dtypes==0.3.0 \ + pytest==7.4.4 # Install migraphx RUN apt update && apt install -y migraphx From 435e19953ea54115124fd637a67a87681a7fc8eb Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Mon, 5 Feb 2024 07:26:24 +0800 Subject: [PATCH 036/207] Fix llama.covert_onnx to make it runnable in CI (#19372) ### Description 1. make parity_check use local model to avoid using hf token 2. del the model didn't work because it tried to del the object define out of the function scope. So it caused out of memory in A10. 3. In fact, 16G GPU memory (one T4) is enough. But the conversion process always be killed in T4 and it works on A10/24G. Standard_NC4as_T4_v3 has 28G CPU memory Standard_NV36ads_A10_v5 has 440G memory. It looks that the model conversion needs very huge memory. ### Motivation and Context Last time, I came across some issues in convert_to_onnx.py so I use the onnx model in https://github.com/microsoft/Llama-2-Onnx for testing. Now, these issues could be fixed. So I use onnx model generated by this repo and the CI can cover the model conversion. --- .../models/llama/convert_to_onnx.py | 17 +++-- .../transformers/models/llama/llama_parity.py | 62 ++++++++++++++----- .../models/llama/requirements-cuda.txt | 4 +- .../models/llama/requirements.txt | 4 +- .../azure-pipelines/bigmodels-ci-pipeline.yml | 53 +++++++--------- 5 files changed, 84 insertions(+), 56 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 71f52faa2c1e6..c9ff384a4c856 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -781,6 +781,13 @@ def get_args(): action="store_true", help="Avoid exporting model, only apply quantizations and optimizations to existing model exported from optimum.", ) + + parser.add_argument( + "--small_gpu", + action="store_true", + help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB.", + ) + parser.set_defaults(optimize_optimum=False) args = parser.parse_args() @@ -788,9 +795,7 @@ def get_args(): def main(): - if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__: - # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false - # in that scenario. It can be removed when torch v2.2.0 is released in stable. 
+ if version.parse(torch.__version__) < version.parse("2.2.0"): logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.") return @@ -1021,7 +1026,11 @@ def main(): args.precision, "--cache_dir", args.cache_dir, + "--torch_model_directory", + args.input, ] + if args.small_gpu: + parity_cmd.append("--small_gpu") if "with_past" in filename: parity_cmd.append("--use_past_kv") if "merged" in filename: @@ -1030,7 +1039,7 @@ def main(): parity_cmd.append("--use_gqa") try: - logger.debug(f"check parity with cmd: {parity_cmd}") + logger.info(f"check parity with cmd: {parity_cmd}") parity_check(parity_cmd) except Exception as e: logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index 25d7519769604..f41a90208c51b 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -17,7 +17,7 @@ get_sample_with_past_kv_inputs, ) from llama_torch import setup_torch_model -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig import onnxruntime as ort @@ -67,20 +67,39 @@ def get_inputs(args: argparse.Namespace, config: AutoConfig): def verify_parity( - args: argparse.Namespace, config: AutoConfig, pt_model: AutoModelForCausalLM, kv_cache_ortvalues: dict + args: argparse.Namespace, + location: str, + use_auth_token: bool, + kv_cache_ortvalues: dict, + pytorch_model: None | torch.nn.Module = None, + config: None | AutoConfig = None, ): + # If it's running in a machine which GPU memory < 36GB, it should unload the llama in GPU in time and free the GPU memory for ORT. + py_model = pytorch_model + if py_model is None: + config, py_model = setup_torch_model( + args, + location, + use_auth_token, + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), + device=args.device, + ) + inputs = get_inputs(args, config) # Run inference with PyTorch if args.execution_provider != "cpu": torch.cuda.synchronize() start_time = time.time() - pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy() + pt_outputs = py_model(**inputs).logits.detach().cpu().numpy() if args.execution_provider != "cpu": torch.cuda.synchronize() end_time = time.time() logger.info(f"PyTorch took {end_time - start_time} s") - del pt_model + + if args.small_gpu and py_model is not None: + del py_model + torch.cuda.empty_cache() # Run inference with ORT past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) @@ -222,6 +241,13 @@ def get_args(argv: list[str]): help="model cache dir to override default HF cache dir to avoid overflood the /home dir", ) + # The argument is used for CI mainly, because the CI machine has 24G GPU memory at most. + parser.add_argument( + "--small_gpu", + action="store_true", + help="Load the llama in GPU every time for parity_check if it's running in a machine which GPU memory < 36GB. 
", + ) + args = parser.parse_args() if argv == [] else parser.parse_args(argv) # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models @@ -247,25 +273,29 @@ def main(argv: list[str] = []): # noqa: B006 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory - config, llama = setup_torch_model( - args, - location, - use_auth_token, - torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), - device=args.device, - ) - kv_cache_ortvalues = {} if not args.merged: - verify_parity(args, config, llama, kv_cache_ortvalues) + verify_parity(args, location, use_auth_token, kv_cache_ortvalues) else: - # Verify prompt generation in merged model (decoder_model.onnx) + config = llama = None + if not args.small_gpu: + config, llama = setup_torch_model( + args, + location, + use_auth_token, + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), + device=args.device, + ) + + # Verify prompt processing in merged model (decoder_model.onnx) args.use_past_kv = False - kv_cache_ortvalues = verify_parity(args, config, llama, kv_cache_ortvalues) + kv_cache_ortvalues = verify_parity( + args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config + ) # Verify token generation in merged model (decoder_with_past_model.onnx) args.use_past_kv = True - verify_parity(args, config, llama, kv_cache_ortvalues) + verify_parity(args, location, use_auth_token, kv_cache_ortvalues, pytorch_model=llama, config=config) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index b634bcc50f6e4..acd9c23aa42d0 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -1,4 +1,4 @@ -r requirements.txt -# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. +# Please manually install torch>=2.2.0 with CUDA enabled for the CUDA version installed in your system. 
# Instructions can be found here: https://pytorch.org/get-started/locally/ -onnxruntime-gpu>=1.16.2 \ No newline at end of file +onnxruntime-gpu>=1.16.2 diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index b72c972e7a16a..8b57279295e35 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,6 +1,6 @@ optimum>=1.14.1 transformers>=4.33.2 -torch>=2.2.0.dev20230920 +torch>=2.2.0 onnx>=1.14.0 datasets>=2.8.0 -protobuf==3.20.2 \ No newline at end of file +protobuf==3.20.2 diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 0de2ac44215c4..65866fc9827a5 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -268,7 +268,7 @@ stages: skipComponentGovernanceDetection: true workspace: clean: all - pool: onnxruntime-Linux-GPU-T4 + pool: Onnxruntime-Linux-A10-24G steps: - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' @@ -278,10 +278,6 @@ stages: clean: true submodules: none - - checkout: LLaMa2Onnx - clean: true - submodules: none - - template: templates/flex-downloadPipelineArtifact.yml parameters: StepName: 'Download Onnxruntime Artifact' @@ -290,47 +286,40 @@ stages: SpecificArtifact: ${{ parameters.specificArtifact }} BuildId: ${{ parameters.BuildId }} - - task: DownloadPackage@1 - displayName: 'Download Llama2 model' - inputs: - packageType: upack - feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' - version: 1.0.0 - definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce' - downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16 - - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 - Context: onnxruntime/tools/ci_build/github/linux/docker/ - ScriptName: onnxruntime/tools/ci_build/get_docker_image.py + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker/ + ScriptName: tools/ci_build/get_docker_image.py DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" Repository: onnxruntimeubi8packagestest UpdateDepsTxt: false + - task: DownloadPackage@1 + displayName: 'Download Meta Llama2 model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: '6fe0c4ed-9d0e-4d66-94cc-fb6a111d02a5' + downloadPath: $(Agent.TempDirectory)/meta_llama2_7b_hf + - script: | - docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \ + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ - -v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \ + -v $(Agent.TempDirectory)/meta_llama2_7b_hf:/meta-llama2 \ onnxruntimeubi8packagestest \ bash -c " set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ python3 -m pip install --upgrade pip ; \ + pushd models/llama ; \ + python3 -m pip install -r requirements-cuda.txt ; \ + popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ - python3 -m pip install sentencepiece ; \ - pushd /workspace ; 
\ - python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \ - --embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the lightest element?' > /workspace/answer.txt ; \ + python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\ popd ; \ " - displayName: 'Run Llama2 demo' + displayName: 'Run Llama2 to Onnx F16 and parity Test' workingDirectory: $(Build.SourcesDirectory) - - - script: | - set -ex - real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt) - trim_actual=$(tr -dc '[[:print:]]' <<< "$real") - expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)." - [ "$expected" == "$trim_actual" ] && exit 0 || exit 1 - displayName: 'Check result' From e6d3518db9e6a19dc8088d5b9d5589a2b3a395d8 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Mon, 5 Feb 2024 09:41:57 -0800 Subject: [PATCH 037/207] Bump gradle/gradle-build-action from 2 to 3 (#19297) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [gradle/gradle-build-action](https://github.com/gradle/gradle-build-action) from 2 to 3.
Release notes

Sourced from gradle/gradle-build-action's releases.

v3.0.0-rc.1

First release candidate of gradle/gradle-build-action@v3.0.0. This release candidate will be the first release available under the v3 version tag.

[!IMPORTANT] As of v3 this action has been superseded by gradle/actions/setup-gradle. Any workflow that uses gradle/gradle-build-action@v3 will transparently delegate to gradle/actions/setup-gradle@v3.

Users are encouraged to update their workflows, replacing:

uses: gradle/gradle-build-action@v3

with

uses: gradle/actions/setup-gradle@v3

See the setup-gradle documentation for up-to-date documentation for gradle/actions/setup-gradle.

Changes from gradle-build-action@v2

This release brings some useful and much requested features, including:

  • save and restore the Gradle configuration-cache data
  • add the Job summary content as a PR comment
  • easily publish Build Scans® to the free Gradle Build Scan service
  • compatibility with Node 20

The only major breaking change from gradle-build-action@v2.12.0 is the update to require a Node 20 runtime environment. Aside from that change, this release should generally serve as a drop-in replacement for gradle-build-action@v2.

Changelog

  • [NEW] - Run with NodeJs 20.x (gradle/gradle-build-action#946)
  • [NEW] - Support for save & restore of configuration-cache data (gradle/gradle-build-action#966)
  • [NEW] - Support for automatically adding a PR comment with Job Summary content (gradle/gradle-build-action#1020)
  • [NEW] - Make it easy to publish a Build Scan® to https://scans.gradle.com (gradle/gradle-build-action#1044)
  • [NEW] - Added dependency-graph-continue-on-failure input, which can be set to false to force the Job to fail when dependency graph submission fails (gradle/gradle-build-action#1036). Failure modes include:
  • [NEW] - Add dependency-graph: clear option to clear any dependency-graph previously submitted by the job
  • [FIX] Allow cache entries to be reused by jobs with the same ID in different workflows (gradle/gradle-build-action#1017)
    • Workflow name remains part of the cache key, but cache entries generated by the same job id in a different workflow may be restored
  • [FIX] Register pre-installed JDKs in Maven toolchains.xml file (gradle/gradle-build-action#1024)
    • This allows pre-installed JDKs to be auto-detected by Gradle Toolchain support on Windows
  • [FIX] - Update the Gradle Enterprise injection configuration for product rename to Develocity (gradle/gradle-build-action#995)
  • [FIX] - Avoid submitting an empty dependency graph when state is loaded from configuration-cache
  • [DEPRECATION] - Deprecation of the arguments parameter (gradle/gradle-build-action#996)
  • [BREAKING CHANGE] - Remove the gradle-executable input parameter. Use a separate workflow Step to execute Gradle from a custom location.

... (truncated)

Commits
  • 4a8703f Delegate to 'setup-gradle@v3.0.0-rc.1'
  • 4a39eed Mention setup-gradle in README
  • 272883a Remove all action sources: these have been migrated to 'gradle/actions'
  • 2a8bfcf Delegate action implementation to gradle/actions/setup-gradle
  • e1ada08 Bump the github-actions group with 1 update (#1047)
  • a8e3e5e Apply dependency version updates
  • 2be01ca Build outputs
  • a00827e Bump the npm-dependencies group with 7 updates
  • ad80850 Bump the github-actions group with 2 updates
  • bd6d0a7 Configure explicit java version for config-cache test
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=gradle/gradle-build-action&package-manager=github_actions&previous-version=2&new-version=3)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR:
- `@dependabot rebase` will rebase this PR
- `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it
- `@dependabot merge` will merge this PR after your CI passes on it
- `@dependabot squash and merge` will squash and merge this PR after your CI passes on it
- `@dependabot cancel merge` will cancel a previously requested merge and block automerging
- `@dependabot reopen` will reopen this PR if it is closed
- `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually
- `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency
- `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself)
- `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/publish-java-apidocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-java-apidocs.yml b/.github/workflows/publish-java-apidocs.yml index 708842e59f9f2..3e553049a186e 100644 --- a/.github/workflows/publish-java-apidocs.yml +++ b/.github/workflows/publish-java-apidocs.yml @@ -30,7 +30,7 @@ jobs: java-version: '11' distribution: 'adopt' - name: Build with Gradle - uses: gradle/gradle-build-action@v2 + uses: gradle/gradle-build-action@v3 with: build-root-directory: java gradle-executable: java/gradlew From aaf32fb1b1d1cb3a3a4250cc2ee132576b7ddf00 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:15:16 +0000 Subject: [PATCH 038/207] phi2 conversion/optimization script (#19338) ### Description This PR adds onnx conversion script for dynamo exported phi2, optimization script, and inference example script A readme file is added as documentation. https://github.com/microsoft/onnxruntime/tree/wangye/phi2_doc/onnxruntime/python/tools/transformers/models/phi2#readme ### Motivation and Context --------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- cmake/onnxruntime_python.cmake | 7 + .../python/tools/symbolic_shape_infer.py | 28 + .../tools/transformers/dynamo_onnx_helper.py | 92 ++ .../python/tools/transformers/float16.py | 8 +- .../tools/transformers/fusion_options.py | 15 + .../tools/transformers/models/phi2/README.md | 119 +++ .../transformers/models/phi2/__init__.py | 12 + .../models/phi2/convert_to_onnx.py | 458 ++++++++++ .../models/phi2/inference_example.py | 215 +++++ .../transformers/models/phi2/requirements.txt | 3 + .../python/tools/transformers/onnx_model.py | 5 + .../tools/transformers/onnx_model_phi.py | 839 ++++++++++++++++++ .../python/tools/transformers/optimizer.py | 2 + setup.py | 1 + 14 files changed, 1801 insertions(+), 3 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/dynamo_onnx_helper.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/README.md create mode 100644 onnxruntime/python/tools/transformers/models/phi2/__init__.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/inference_example.py create mode 100644 onnxruntime/python/tools/transformers/models/phi2/requirements.txt create mode 100644 onnxruntime/python/tools/transformers/onnx_model_phi.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 456344aa34d95..3f20787e87425 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -473,6 +473,9 @@ file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py" ) +file(GLOB onnxruntime_python_transformers_models_phi2_src CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/phi2/*.py" +) file(GLOB onnxruntime_python_transformers_models_stable_diffusion_src CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/python/tools/transformers/models/stable_diffusion/*.py" ) @@ -543,6 +546,7 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/gpt2 COMMAND ${CMAKE_COMMAND} -E make_directory 
$/onnxruntime/transformers/models/llama COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/longformer + COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/phi2 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/stable_diffusion COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/t5 COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/transformers/models/whisper @@ -646,6 +650,9 @@ add_custom_command( COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_longformer_src} $/onnxruntime/transformers/models/longformer/ + COMMAND ${CMAKE_COMMAND} -E copy + ${onnxruntime_python_transformers_models_phi2_src} + $/onnxruntime/transformers/models/phi2/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_transformers_models_stable_diffusion_src} $/onnxruntime/transformers/models/stable_diffusion/ diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 9823e8264e17b..31f3a3a2b30d6 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -205,6 +205,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFastGelu": self._infer_GemmFastGelu, "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "GroupQueryAttention": self._infer_GroupQueryAttention, "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, @@ -471,6 +472,7 @@ def _onnx_infer_single_node(self, node): "PythonOp", "MultiHeadAttention", "GroupNorm", + "GroupQueryAttention", "SkipGroupNorm", "BiasSplitGelu", "BiasAdd", @@ -2409,6 +2411,32 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_GroupQueryAttention(self, node): # noqa: N802 + output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type + + past_shape = self._try_get_shape(node, 3) + if past_shape is not None: + vi = self.known_vi_[node.output[1]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + vi = self.known_vi_[node.output[2]] + vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape)) + + if node.input[1] != "" and node.input[2] != "": + self._propagate_shape_and_type(node, 0, 0) + else: + # combined qkv: (batch_size, sequence_length, num_heads * head_size + 2 * kv_num_heads * head_size) + assert node.input[1] == "" and node.input[2] == "" + num_heads = get_attribute(node, "num_heads") + kv_num_heads = get_attribute(node, "kv_num_heads") + query_shape = self._get_shape(node, 0) + if query_shape is not None: + hidden_size = query_shape[2] + if isinstance(hidden_size, int): + head_size = int(hidden_size / (num_heads + 2 * kv_num_heads)) + query_shape[2] = num_heads * head_size + vi = self.known_vi_[node.output[0]] + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, query_shape)) + def _infer_SkipGroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node, 0, 0) if len(node.output) > 1: diff --git a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py new file mode 100644 index 0000000000000..bca5ace916082 --- /dev/null +++ b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py @@ -0,0 +1,92 @@ +# 
------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging + +import onnx + + +class DynamoOnnxHelper: + """ + Helper class for processing ONNX models exported by torch Dynamo. + """ + + def __init__(self, model: onnx.ModelProto): + self.model = model + + def update_edges(self, edge_mapping: dict) -> None: + """ + Updates the edges in the model according to the given mapping. + """ + for node in self.model.graph.node: + for i in range(len(node.input)): + if node.input[i] in edge_mapping: + node.input[i] = edge_mapping[node.input[i]] + for i in range(len(node.output)): + if node.output[i] in edge_mapping: + node.output[i] = edge_mapping[node.output[i]] + + for graph_input in self.model.graph.input: + if graph_input.name in edge_mapping: + graph_input.name = edge_mapping[graph_input.name] + for graph_output in self.model.graph.output: + if graph_output.name in edge_mapping: + graph_output.name = edge_mapping[graph_output.name] + + def unroll_function(self, func_name: str) -> None: + """ + Unrolls the function with the given name in the model. + """ + logging.info(f"Unrolling function {func_name}...") + nodes_to_remove = [] + nodes_to_add = [] + edges_to_remove = [] + edges_to_add = [] + for node in self.model.graph.node: + if node.op_type == func_name: + nodes_to_remove.append(node) + edges_to_remove.extend(list(node.input) + list(node.output)) + + func_to_remove = None + for f in self.model.functions: + if f.name == func_name: + nodes_to_add.extend(list(f.node)) + edges_to_add.extend(list(f.input) + list(f.output)) + func_to_remove = f + + assert len(edges_to_remove) == len(edges_to_add) + + for node in nodes_to_remove: + self.model.graph.node.remove(node) + for node in nodes_to_add: + self.model.graph.node.append(node) + if func_to_remove is not None: + self.model.functions.remove(func_to_remove) + + edge_mapping = {} + for i in range(len(edges_to_remove)): + k = edges_to_remove[i] + v = edges_to_add[i] + if k != v: + edge_mapping[k] = v + + return self.update_edges(edge_mapping) + + def remove_dropout_layer(self) -> None: + """ + Removes the dropout layer in the model. + """ + logging.info("Removing dropout layer...") + edge_mapping = {} + nodes_to_remove = [] + for node in self.model.graph.node: + if node.op_type.find("Dropout") != -1: + assert len(node.input) == 1 + assert len(node.output) == 1 + edge_mapping[node.output[0]] = node.input[0] + nodes_to_remove.append(node) + for node in nodes_to_remove: + self.model.graph.node.remove(node) + + self.update_edges(edge_mapping) diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index f680a15fc2c1b..48c79b1d5fa0f 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -174,6 +174,7 @@ def convert_float_to_float16( node_block_list=None, force_fp16_initializers=False, force_fp16_inputs=None, + use_bfloat16_as_blocked_nodes_dtype=False, ): """Convert tensor float type in the input ONNX model to tensor float16. 
@@ -436,6 +437,7 @@ def convert_float_to_float16( node.input[i] = output_name break + accuracy_type = TensorProto.BFLOAT16 if use_bfloat16_as_blocked_nodes_dtype else TensorProto.FLOAT # process the nodes in block list that doesn't support tensor(float16) for node in node_list: # if input's name is in the value_info_list meaning input is tensor(float16) type, @@ -450,10 +452,10 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) output_name = node.name + "_input_cast_" + str(i) new_value_info.name = output_name - new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = accuracy_type # add Cast node (from tensor(float16) to tensor(float) before current node node_name = node.name + "_input_cast" + str(i) - new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)] + new_node = [helper.make_node("Cast", [input_name], [output_name], to=accuracy_type, name=node_name)] model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name @@ -469,7 +471,7 @@ def convert_float_to_float16( new_value_info.CopyFrom(value_info) input_name = node.name + "_output_cast_" + str(i) new_value_info.name = input_name - new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + new_value_info.type.tensor_type.elem_type = accuracy_type # add Cast node (from tensor(float) to tensor(float16) after current node node_name = node.name + "_output_cast" + str(i) new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)] diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index b9b92d2fe8a00..c65464a3069c5 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. 
# -------------------------------------------------------------------------- from argparse import ArgumentParser +from enum import Enum class AttentionMaskFormat: @@ -19,6 +20,15 @@ class AttentionMaskFormat: NoMask = 3 +class AttentionOpType(Enum): + Attention = "Attention" + MultiHeadAttention = "MultiHeadAttention" + GroupQueryAttention = "GroupQueryAttention" + + def __str__(self): + return self.value + + class FusionOptions: """Options of fusion in graph optimization""" @@ -57,6 +67,8 @@ def __init__(self, model_type): elif model_type == "vit": self.attention_mask_format = AttentionMaskFormat.NoMask + self.attention_op_type = None + # options for stable diffusion if model_type in ["unet", "vae", "clip"]: self.enable_nhwc_conv = True @@ -76,6 +88,9 @@ def use_raw_attention_mask(self, use_raw_mask=True): def disable_attention_mask(self): self.attention_mask_format = AttentionMaskFormat.NoMask + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attention_op_type = attn_op_type + @staticmethod def parse(args): options = FusionOptions(args.model_type) diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md new file mode 100644 index 0000000000000..526fdc3dd7863 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -0,0 +1,119 @@ +# Phi2 Optimizations +## Prerequisites +A Linux machine for [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter)\ +Install onnx, onnxscript and transformers by running +```bash +pip install -r requirements.txt +``` +To export ONNX, PyTorch version 2.2.0 or higher is required. The [official website](https://pytorch.org/) offers packages compatible with CUDA 11.8 and 12.1. Please select the appropriate version according to your needs. 
+\ +\ +**There are two options to run the conversion script:**\ +_From source:_ +```bash +pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu +git clone git@github.com:microsoft/onnxruntime.git +cd onnxruntime/onnxruntime/python/tools/transformers +python -m models.phi2.convert_to_onnx -h +``` +_From wheel:_ \ +Install [ORT nightly package](https://onnxruntime.ai/docs/install/#inference-install-table-for-all-languages) +```bash +python -m onnxruntime.transformers.models.phi2.convert_to_onnx -h +``` + +## Export optimized phi2 onnx model for different scenarios +**Export FP32 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_gpu +``` +\ +**Export FP16 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu +``` +\ +**Export INT4 ONNX model for Nvidia GPUs** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_gpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu +``` +\ +**Export FP16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x +``` +\ +**Export INT4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_gpu_sm8x +``` +\ +**Export FP32 ONNX model for CPU** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_cpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu +``` +\ +**Export INT4 ONNX model for CPU** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --int4_cpu +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --int4_cpu +``` +\ +**Export all at once** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp32_cpu --int4_cpu --fp32_gpu --fp16_gpu --int4_gpu --fp16_gpu_sm8x --int4_gpu_sm8x +``` +## Run example with ORT +**(e.g) Export FP16 and INT4 ONNX models for Nvidia GPUs with CUDA architecture SM=80~89 and run examples.** \ +_From source:_ +``` +python -m models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example +``` +_From wheel:_ +``` +python -m onnxruntime.transformers.models.phi2.convert_to_onnx --fp16_gpu_sm8x --int4_gpu_sm8x --run_example +``` +The inference example currently supports all models running on CUDA. + +## Limitations +- TorchDynamo-based ONNX Exporter only supports Linux. +- The program may not run as expected if the machine has limited memory. e.g Dynamo export may use ~11.6GB; Optimization may use ~4.5GB for each. 
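The README above ends at running the bundled inference example. As a quick, hedged sanity check that an export succeeded, the optimized model can simply be loaded with onnxruntime and its inputs inspected. The sketch below is illustrative only, not part of the patch; it assumes the default `--output_dir` (`phi2_onnx_models`), the `fp16_gpu_sm8x` artifact name produced by convert_to_onnx.py, and an installed onnxruntime-gpu.

```python
# Minimal load check for an exported phi2 decoder (illustrative only).
# Paths and provider choice are assumptions based on the defaults in
# convert_to_onnx.py; adjust them to match your actual export.
import onnxruntime as ort

session = ort.InferenceSession(
    "phi2_onnx_models/phi2_decoder_fp16_gpu_sm8x.onnx",
    providers=[("CUDAExecutionProvider", {"device_id": 0})],
)

# Successful session creation confirms the graph and its external data load;
# printing the inputs shows the expected input_ids / attention_mask / past_* layout.
for model_input in session.get_inputs():
    print(model_input.name, model_input.shape, model_input.type)
```

For full token generation, use the inference_example.py added later in this patch, which also handles KV-cache buffer sharing and greedy sampling.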
diff --git a/onnxruntime/python/tools/transformers/models/phi2/__init__.py b/onnxruntime/python/tools/transformers/models/phi2/__init__.py new file mode 100644 index 0000000000000..e80f36a391fe1 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/__init__.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +import sys + +sys.path.append(os.path.dirname(__file__)) + +transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", "..")) +if transformers_dir not in sys.path: + sys.path.append(transformers_dir) diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py new file mode 100644 index 0000000000000..ac3ca40e41be0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -0,0 +1,458 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import argparse +import logging +import os +from pathlib import Path + +import onnx +import torch +from benchmark_helper import Precision +from fusion_options import AttentionOpType +from transformers import AutoConfig, AutoModelForCausalLM + +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer + + +class ConvertPhi2ToONNX: + def __init__( + self, + device: torch.device, + model_class: str = "microsoft/phi-2", + cache_dir: str = "./cache", + ): + self.model_class = model_class + self.device = device + self.cache_dir = cache_dir + self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir) + self.phi_model = None + self.batch_size = 2 + self.sequence_length = 8 + self.attn_op_type = None + self.precision = None + self.block_size = 16 + self.accuracy_level = None + + def set_quantization_params(self, block_size: int, accuracy_level: int | None): + self.block_size = block_size + self.accuracy_level = accuracy_level + + def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision): + self.attn_op_type = attn_op_type + self.precision = precision + + def erase_onnx_model(self, onnx_path: str) -> None: + assert onnx_path.endswith(".onnx") + if not os.path.exists(onnx_path): + return + + model = onnx.load_model(onnx_path, load_external_data=False) + onnx_data_path = None + for initializer in model.graph.initializer: + if initializer.data_location == 1 and initializer.external_data[0].key == "location": + onnx_data_path = "./" + initializer.external_data[0].value + break + logging.info(f"Erasing {onnx_path}...") + os.remove(onnx_path) + if onnx_data_path is not None: + onnx_data_path = os.path.join(Path(onnx_path).parent, onnx_data_path) + logging.info(f"Erasing {onnx_data_path}...") + os.remove(onnx_data_path) + + def get_phi2_torch_model(self): + logging.info("Loading phi2 torch model...") + if self.phi_model is not None: + return + self.phi_model = AutoModelForCausalLM.from_pretrained( + self.model_class, trust_remote_code=True, cache_dir=self.cache_dir + ) + self.phi_model.eval() + self.phi_model.to(self.device) + + def 
get_phi2_torch_inputs(self, batch_size: int, sequence_length: int): + input_ids = torch.randint( + low=0, + high=self.phi_config.vocab_size, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=self.device, + ) + self.get_phi2_torch_model() + torch_inputs = self.phi_model.prepare_inputs_for_generation( + input_ids, past_key_values=self.phi_model(input_ids, use_cache=True)["past_key_values"] + ) + return torch_inputs["input_ids"], torch_inputs["attention_mask"], torch_inputs["past_key_values"] + + def dynamo_export(self, onnx_path: str): + input_ids, attention_mask, past_key_values = self.get_phi2_torch_inputs(self.batch_size, self.sequence_length) + self.phi_model(input_ids, attention_mask=attention_mask, past_key_values=past_key_values) + + from torch._dynamo import config + + config.capture_scalar_outputs = True + + logging.info("Exporting Phi2 torch model to ONNX...") + torch.onnx.dynamo_export( + self.phi_model, + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + export_options=torch.onnx.ExportOptions(dynamic_shapes=True), + ).save(onnx_path) + onnx.checker.check_model(onnx_path) + onnx.shape_inference.infer_shapes_path(onnx_path) + + def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): + from fusion_options import FusionOptions + from optimizer import optimize_model + + optimization_options = FusionOptions("phi") + optimization_options.set_attention_op_type(self.attn_op_type) + optimizer = optimize_model( + onnx_path, + model_type="phi", + num_heads=self.phi_config.num_attention_heads, + hidden_size=self.phi_config.hidden_size, + opt_level=0, + optimization_options=optimization_options, + only_onnxruntime=False, + ) + + fused_op_count = optimizer.get_fused_operator_statistics() + if optimizer.is_fully_optimized(fused_op_count): + logging.info("Model is fully optimized.") + else: + logging.info("Model is not fully optimized.") + + if self.precision == Precision.FLOAT32: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + + if ( + self.precision == Precision.FLOAT16 or self.precision == Precision.INT4 + ) and self.attn_op_type != AttentionOpType.MultiHeadAttention: + # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. 
+ node_block_list = [ + "GroupQueryAttention_29", + "GroupQueryAttention_30", + "GroupQueryAttention_31", + "Attention_29", + "Attention_30", + "Attention_31", + ] + logging.info("Converting onnx model to float16/bfloat16...") + optimizer.convert_float_to_float16( + keep_io_types=False, + node_block_list=node_block_list, + use_symbolic_shape_infer=True, + use_bfloat16_as_blocked_nodes_dtype=self.attn_op_type == AttentionOpType.GroupQueryAttention, + ) + logging.info("Converting onnx model to float16/bfloat16 done.") + + if self.precision == Precision.FLOAT16: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + else: + assert self.precision == Precision.INT4 + quant = MatMul4BitsQuantizer( + model=optimizer.model, + block_size=self.block_size, + is_symmetric=True, + accuracy_level=self.accuracy_level, + ) + quant.process() + quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True) + + +def parse_arguments(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--fp32_cpu", + required=False, + action="store_true", + help="Generate fp32 ONNX model for CPU", + ) + + parser.add_argument( + "--int4_cpu", + required=False, + action="store_true", + help="Generate int4 ONNX model for CPU", + ) + + parser.add_argument( + "--fp32_gpu", + required=False, + action="store_true", + help="Generate fp32 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_gpu", + required=False, + action="store_true", + help="Generate fp16 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--int4_gpu", + required=False, + action="store_true", + help="Generate int4 ONNX model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_gpu_sm8x", + required=False, + action="store_true", + help="Generate fp16 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", + ) + + parser.add_argument( + "--int4_gpu_sm8x", + required=False, + action="store_true", + help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", + ) + + parser.add_argument( + "--overwrite", + required=False, + action="store_true", + help="Overwrite existing ONNX models", + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default="./cache", + help="The cache directory for the pytorch model", + ) + + parser.add_argument( + "--device_id", + required=False, + type=int, + default=0, + help="The device id for the pytorch model", + ) + + parser.add_argument( + "--run_example", + required=False, + action="store_true", + help="Run ORT inference example", + ) + + parser.add_argument( + "--skip_export", + required=False, + action="store_true", + help="Skip exporting ONNX model", + ) + + parser.add_argument( + "--output_dir", + type=str, + help="The output directory for the ONNX models", + default="phi2_onnx_models", + ) + + parser.add_argument( + "--block_size", + required=False, + default=16, + type=int, + help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", + ) + + parser.add_argument( + "--int4_accuracy_level", + required=False, + type=int, + help="Accuracy level of the 4-bit quantized MatMul computation. 
" + "Refer to the MatMulNBits contrib op's 'accuracy_level' attribute for details " + "(https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#commicrosoftmatmulnbits).", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + + device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu") + + converter = ConvertPhi2ToONNX(device, cache_dir=args.cache_dir) + converter.set_quantization_params(args.block_size, args.int4_accuracy_level) + + output_dir = args.output_dir + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + original_onnx_path = os.path.join(output_dir, "phi2_original.onnx") + + if not args.skip_export: + if not os.path.exists(original_onnx_path) or args.overwrite: + converter.dynamo_export(original_onnx_path) + + model_type_to_args = { + "fp32_cpu": ( + AttentionOpType.MultiHeadAttention, + Precision.FLOAT32, + os.path.join(output_dir, "phi2_decoder_fp32_cpu.onnx"), + ), + "int4_cpu": ( + AttentionOpType.MultiHeadAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_cpu.onnx"), + ), + "fp32_gpu": ( + AttentionOpType.Attention, + Precision.FLOAT32, + os.path.join(output_dir, "phi2_decoder_fp32_gpu.onnx"), + ), + "fp16_gpu": ( + AttentionOpType.Attention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_gpu.onnx"), + ), + "int4_gpu": (AttentionOpType.Attention, Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu.onnx")), + "fp16_gpu_sm8x": ( + AttentionOpType.GroupQueryAttention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_gpu_sm8x.onnx"), + ), + "int4_gpu_sm8x": ( + AttentionOpType.GroupQueryAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"), + ), + } + + if not args.skip_export: + from multiprocessing import Process + + def run_optimize_phi2_onnx( + converter: ConvertPhi2ToONNX, + original_onnx_path: str, + attention_type: AttentionOpType, + precision: Precision, + optimized_onnx_path: str, + ): + converter.init_attn_type_and_precision(attention_type, precision) + converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path) + + processes = [] + if args.fp32_cpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_cpu"]) + ) + ) + + if args.int4_cpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_cpu"]) + ) + ) + + if args.fp32_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp32_gpu"]) + ) + ) + + if args.fp16_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu"]) + ) + ) + + if args.int4_gpu: + processes.append( + Process( + target=run_optimize_phi2_onnx, args=(converter, original_onnx_path, *model_type_to_args["int4_gpu"]) + ) + ) + + if args.fp16_gpu_sm8x: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["fp16_gpu_sm8x"]), + ) + ) + + if args.int4_gpu_sm8x: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["int4_gpu_sm8x"]), + ) + ) + + [p.start() for p in processes] + [p.join() for p in processes] + + if args.run_example: + from inference_example import run_phi2 + + if 
args.fp16_gpu_sm8x: + logging.info("Running fp16_gpu_sm8x example...") + run_phi2( + onnx_model_path=model_type_to_args["fp16_gpu_sm8x"][2], + use_buffer_share=True, + device_id=args.device_id, + use_step=True, + ) + if args.int4_gpu_sm8x: + logging.info("Running int4_gpu_sm8x example...") + run_phi2( + onnx_model_path=model_type_to_args["int4_gpu_sm8x"][2], + use_buffer_share=True, + device_id=args.device_id, + use_step=True, + ) + if args.fp32_gpu: + logging.info("Running fp32_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["fp32_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + use_fp16=False, + ) + if args.fp16_gpu: + logging.info("Running fp16_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["fp16_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + ) + if args.int4_gpu: + logging.info("Running int4_gpu example...") + run_phi2( + onnx_model_path=model_type_to_args["int4_gpu"][2], + use_buffer_share=False, + device_id=args.device_id, + packed_kv=True, + ) + if args.fp32_cpu or args.int4_cpu: + raise NotImplementedError("CPU inference example is not implemented yet.") + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py new file mode 100644 index 0000000000000..28828ffb853cb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py @@ -0,0 +1,215 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import numpy as np +import torch +from transformers import AutoTokenizer + +import onnxruntime as ort + +pt_to_np = { + "torch.int32": np.int32, + "torch.int64": np.int64, + "torch.float32": np.float32, + "torch.float16": np.float16, +} + + +class ORTGenerator: + def __init__(self, decoder_path): + self.onnx_decoder_path = decoder_path + self.num_heads = 32 + self.head_size = 80 + self.num_layers = 32 + self.max_sequence_length = 2048 + + def get_initial_inputs_and_outputs(self, encodings_dict): + self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 + + input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) + attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) + step = torch.tensor([0], device=self.device, dtype=torch.int64) + + inputs = { + "input_ids": input_ids.contiguous(), + "attention_mask": attention_mask.contiguous(), + } + + if self.use_step: + inputs["step"] = step.contiguous() + + batch_size, sequence_length = input_ids.shape + + past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 + past_shape = ( + (2, batch_size, self.num_heads, past_seq_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, past_seq_length, self.head_size) + ) + for i in range(self.num_layers): + past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) + inputs.update( + {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} + ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + + logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) + outputs = {"logits": 
logits.contiguous()} + + if not self.use_buffer_share: + present_shape = ( + (2, batch_size, self.num_heads, sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + return inputs, outputs + + def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: dict): + io_binding = model.io_binding() + device = None + + for k, v in inputs.items(): + io_binding.bind_input( + name=k, + device_type=v.device.type, + device_id=0 if v.device.type == "cpu" else v.device.index, + element_type=pt_to_np[repr(v.dtype)], + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + device = v.device + + for output in model.get_outputs(): + name = output.name + if self.use_buffer_share and "present" in name: + v = inputs[name.replace("present", "past")] + io_binding.bind_output( + name=name, + device_type=v.device.type, + device_id=v.device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + else: + v = outputs[name] + io_binding.bind_output( + name=name, + device_type=device.type, + device_id=0 if device.type == "cpu" else device.index, + element_type=(np.float16 if self.use_fp16 else np.float32), + shape=tuple(v.shape), + buffer_ptr=v.data_ptr(), + ) + + return io_binding + + def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False): + sess_options = ort.SessionOptions() + ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider" + self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) + + self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu") + self.use_fp16 = use_fp16 + self.use_buffer_share = use_buffer_share + self.packed_kv = packed_kv + self.use_step = use_step + + self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) + self.tokenizer.pad_token = "[PAD]" + + def generate(self, prompt, max_length): + encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) + + inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) + + all_token_ids = inputs["input_ids"].clone() + batch_size, sequence_length = all_token_ids.shape + + current_length = sequence_length + has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + + while current_length < max_length: + io_binding = self.apply_io_binding(self.sess, inputs, outputs) + + io_binding.synchronize_inputs() + self.sess.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + + # Sample with argmax (greedy search) + next_token_logits = outputs["logits"][:, -1, :] + next_tokens = torch.argmax(next_token_logits, dim=-1) + + # Check if we previously reached EOS token id or if generated token id is EOS token id + has_eos = has_eos | next_tokens == self.tokenizer.eos_token_id + + # Determine which new tokens to add to list of all token ids + # Add EOS token ids for batch entries that ended early (ragged batching scenario where some batch entries ended early and some haven't) + tokens_to_add = next_tokens.masked_fill(has_eos, 
self.tokenizer.eos_token_id).reshape([batch_size, 1]) + all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1) + + # Return early if all batch entries have reached EOS token id + if torch.all(has_eos): + break + + # Update inputs for next inference run + current_length += 1 + inputs["input_ids"] = tokens_to_add.to(torch.int32) + if self.use_step: + inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) + inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to( + torch.int32 + ) + + # Set logits to zeros for next inference run and re-use memory buffer + if outputs["logits"].shape[1] != 1: + outputs["logits"] = outputs["logits"][:, :1, :].contiguous() + outputs["logits"].zero_() + + if not self.use_buffer_share: + for i in range(self.num_layers): + if not self.packed_kv: + inputs[f"past_key_{i}"] = outputs[f"present_key_{i}"] + inputs[f"past_value_{i}"] = outputs[f"present_value_{i}"] + else: + inputs[f"past_{i}"] = outputs[f"present_{i}"] + + new_sequence_length = inputs["attention_mask"].shape[1] + present_shape = ( + (2, batch_size, self.num_heads, new_sequence_length, self.head_size) + if self.packed_kv + else (batch_size, self.num_heads, new_sequence_length, self.head_size) + ) + for i in range(self.num_layers): + present = torch.zeros(present_shape, device=self.device, dtype=self.torch_dtype) + outputs.update( + {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} + ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) + return texts + + +def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False): + prompt = [ + '''```python + def print_prime(n): + """ + Print all primes between 1 and n + """''' + ] + + generator = ORTGenerator(onnx_model_path) + generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step) + texts = generator.generate(prompt, max_length=200) + + for i in range(len(texts)): + print("Prompt: ", prompt[i]) + print("Texts: ", texts[i]) diff --git a/onnxruntime/python/tools/transformers/models/phi2/requirements.txt b/onnxruntime/python/tools/transformers/models/phi2/requirements.txt new file mode 100644 index 0000000000000..af6f441c149d0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/phi2/requirements.txt @@ -0,0 +1,3 @@ +onnx>=1.15.0 +transformers>=4.36.2 +onnxscript>=0.1.0.dev20240126 diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 9d1066b6e372b..0e20b1f871645 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -82,6 +82,10 @@ def output_name_to_node(self): output_name_to_node[output_name] = node return output_name_to_node + def functions(self): + all_functions = [list(self.model.functions)] + return all_functions + def nodes(self): all_nodes = [] for graph in self.graphs(): @@ -733,6 +737,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): "node_block_list", "force_fp16_initializers", "force_fp16_inputs", + "use_bfloat16_as_blocked_nodes_dtype", ] if key in kwargs } diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py new file mode 100644 index 0000000000000..df8830b0d0495 --- /dev/null 
+++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -0,0 +1,839 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from logging import getLogger +from typing import List, Optional + +import numpy as np +from dynamo_onnx_helper import DynamoOnnxHelper +from fusion_base import Fusion +from fusion_options import AttentionOpType, FusionOptions +from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization +from fusion_utils import NumpyHelper +from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = getLogger(__name__) + + +class ProcessGemmWFunc: + def __call__(self, x): + return np.transpose(x, (1, 0)) + + +class ProcessMatMulQFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[0], (1, 0)) + + +class ProcessMatMulKFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[1], (1, 0)) + + +class ProcessMatMulVFunc: + def __call__(self, x): + return np.transpose(np.split(x, 3, 0)[2], (1, 0)) + + +class ProcessBiasQFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[0] + return x + + +class ProcessBiasKFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[1] + return x + + +class ProcessBiasVFunc: + def __call__(self, x): + x = np.split(x, 3, -1)[2] + return x + + +class ProcessRotCacheFunc: + def __call__(self, x): + # half rotary embedding + assert len(x.shape) == 2 + if x.shape[1] == 32: + return x[:, 0:16] + return x + + +# TODO: move to a seperate file +class Fission(Fusion): + def __init__( + self, + model: OnnxModel, + nodes_to_find: List[str], + ): + super().__init__(model, "DONOTUSE", nodes_to_find) + + def set_attention_op_type(self, attn_op_type: AttentionOpType): + self.attn_op_type = attn_op_type + + def get_uname(self, layer_id, name): + return name + "_" + str(layer_id) + + def get_io_by_name(self, node, name): + for input in node.input: + if input == name or input.endswith(name) or input.startswith(name): + return input + for output in node.output: + if output == name or output.endswith(name) or output.startswith(name): + return output + raise Exception(f"input {name} not found in node {node.name}") + + def process_initializer(self, initializer_name, functor, custom_name=None): + i = self.model.get_initializer(initializer_name) + i_np_array = NumpyHelper.to_array(i) + processed_i_np_array = functor(i_np_array) + new_tensor = helper.make_tensor( + initializer_name + "_processed" if custom_name is None else custom_name, + data_type=TensorProto.FLOAT, + dims=processed_i_np_array.shape, + vals=processed_i_np_array.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(new_tensor, self.this_graph_name) + return new_tensor.name + + def add_fp32_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT + + def add_int64_value_info(self, name): + new_value_info = self.model.graph().value_info.add() + new_value_info.name = name + new_value_info.type.tensor_type.elem_type = TensorProto.INT64 + + def replace_fp32_value_info(self, name, shape): + for value_info in self.model.graph().value_info: + if value_info.name == name: + self.model.graph().value_info.remove(value_info) + break + new_value_info = 
helper.make_tensor_value_info( + name, + elem_type=TensorProto.FLOAT, + shape=shape, + ) + self.model.graph().value_info.extend([new_value_info]) + + def set_unique_name_and_add_nodes( + self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str] + ): + for new_node in subgraph_nodes: + for i, name in enumerate(new_node.input): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.input[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.input[i]) + for i, name in enumerate(new_node.output): + if name == "": + continue + elif name not in layer_known_edges_names: + new_node.output[i] = self.get_uname(layer_id, name) + self.add_fp32_value_info(new_node.output[i]) + new_node.name = self.get_uname(layer_id, new_node.name) + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + node = helper.make_node( + "LayerNormalization", + inputs=inputs, + outputs=outputs, + name=prefix + "_LayerNormalization", + epsilon=9.999999747378752e-06, + ) + return [node] + + def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 3 + assert len(outputs) == 1 + matmul = helper.make_node( + "MatMul", + inputs=[inputs[0], inputs[1]], + outputs=[prefix + "matmul_out"], + name=prefix + "MatMul", + ) + add = helper.make_node( + "Add", + inputs=[prefix + "matmul_out", inputs[2]], + outputs=outputs, + name=prefix + "Bias", + ) + return [matmul, add] + + def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32): + assert len(inputs) == 4 + assert len(outputs) == 1 + node = helper.make_node( + "RotaryEmbedding", + inputs=inputs, + outputs=outputs, + name=prefix + "RotaryEmbedding", + domain="com.microsoft", + rotary_embedding_dim=rot_dim, + num_heads=num_heads, + ) + return [node] + + def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 1 + assert len(outputs) == 1 + node = helper.make_node( + "FastGelu", + inputs=inputs, + outputs=outputs, + name=prefix + "FastGelu", + domain="com.microsoft", + ) + return [node] + + def add(self, inputs: List[str], outputs: List[str], prefix: str = ""): + assert len(inputs) == 2 + assert len(outputs) == 1 + node = helper.make_node( + "Add", + inputs=inputs, + outputs=outputs, + name=prefix + "Add", + ) + return [node] + + def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 8 + assert len(outputs) == 3 + node = helper.make_node( + "MultiHeadAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "MultiHeadAttention", + domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + ) + return [node] + + def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 7 + assert len(outputs) == 3 + node = helper.make_node( + "GroupQueryAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "GroupQueryAttention", + domain="com.microsoft", + num_heads=num_heads, + kv_num_heads=num_heads, + ) + return [node] + + def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32): + assert len(inputs) == 5 + assert len(outputs) == 2 + node = helper.make_node( + "Attention", + inputs=inputs, + outputs=outputs, + name=prefix + "Attention", + 
domain="com.microsoft", + num_heads=num_heads, + unidirectional=1, + do_rotary=1, + rotary_embedding_dim=32, + ) + return [node] + + +class Phi2PreProcessor(DynamoOnnxHelper): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.num_hidden_layers = 32 + self.num_attention_heads = num_heads + self.hidden_size = hidden_size + + self.phi2_edge_dict = self.get_phi2_edge_dict() + self.func_name = "modeling_phi_PhiModel_model_1" + + def get_phi2_edge_dict(self) -> dict: + edge_dict = {} + edge_dict["lm_head_1"] = "logits" + edge_dict["l_input_ids_"] = "input_ids" + edge_dict["key_states"] = "past_key_0" + edge_dict["value_states"] = "past_value_0" + for i in range(self.num_hidden_layers): + edge_dict[f"key_states_{i}"] = f"past_key_{i}" + edge_dict[f"value_states_{i}"] = f"past_value_{i}" + edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" + edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + return edge_dict + + def simplify_phi2_op_type(self): + phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers" + for node in self.model.graph.node: + index = node.op_type.find(phi2_transformer_layer_name) + if index != -1: + node.op_type = node.op_type[index:] + + def process_graph_io(self, attn_op_type: AttentionOpType): + self.use_attn = attn_op_type == AttentionOpType.Attention + graph = self.model.graph + new_inputs = [] + for vi in graph.input: + if "input_ids" in vi.name: + vi_iid = helper.make_tensor_value_info( + vi.name, + elem_type=TensorProto.INT32, + shape=["batch_size", "seq_len"], + ) + vi_pid = helper.make_tensor_value_info( + "step", + elem_type=TensorProto.INT64, + shape=[1], + ) + vi_mask = helper.make_tensor_value_info( + "attention_mask", + elem_type=TensorProto.INT32, + shape=["batch_size", "seq_len"], + ) + new_inputs.extend([vi_iid, vi_pid, vi_mask]) + if not self.use_attn: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("past_key", "past"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "past_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + + graph.ClearField("input") + graph.input.extend(new_inputs) + + new_outputs = [] + for i, vi in enumerate(graph.output): + if i == 0: + new_outputs.extend([vi]) + else: + if not self.use_attn: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + else: + if "present_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("present_key", "present"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) + + graph.ClearField("output") + graph.output.extend(new_outputs) + + def preprocess_onnx(self, attn_op_type: AttentionOpType): + function_name = None + for func in self.model.functions: + if 
func.name.endswith(self.func_name): + function_name = func.name + break + assert function_name is not None + self.unroll_function(function_name) + self.update_edges(self.phi2_edge_dict) + self.simplify_phi2_op_type() + self.remove_dropout_layer() + self.process_graph_io(attn_op_type) + + +class FissionTransformerEmbeddingPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 2 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + embedding = self.get_io_by_name(node, "embed_tokens.weight") + + layer_known_edges_names = [input, output, embedding] + + subgraph_nodes = [ + helper.make_node( + "Gather", + inputs=[embedding, input], + outputs=[output], + name="Embedding_Gather", + ), + ] + + self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names) + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerLayerNormPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 3 + assert len(node.output) == 1 + + input = node.input[0] + output = node.output[0] + + ln_weight = self.get_io_by_name(node, "final_layernorm.weight") + ln_bias = self.get_io_by_name(node, "final_layernorm.bias") + + layer_known_edges_names = [input, output, ln_weight, ln_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerCausalLMHeadPhi(Fission): + def __init__( + self, + model: OnnxModel, + ): + super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"]) + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + logger.info("Optimizing %s...", node.name) + + assert len(node.input) == 5 + assert len(node.output) == 1 + + input = node.input[2] + output = node.output[0] + + fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_io_by_name(node, "lm_head.bias") + + layer_known_edges_names = [input, output, fc_weight, fc_bias] + + subgraph_nodes = [] + subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_")) + + self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names) + + self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class FissionTransformerBlockPhi(Fission): + def __init__( + self, + model: OnnxModel, + num_heads: int, + ): + self.num_heads = num_heads + max_num_layers = 32 + self.func_to_layer_id = {} + nodes_to_find = [] + for layer in range(max_num_layers): + func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1" + nodes_to_find.append(func_name) + 
self.func_to_layer_id[func_name] = layer + + super().__init__(model, nodes_to_find) + + def get_layer_id(self, node): + return self.func_to_layer_id[node.op_type] + + def get_gqa_aux_nodes(self): + gqa_aux_nodes = [ + helper.make_node( + "Cast", + inputs=["attention_mask"], + outputs=["mask_int64"], + name="Cast_gqa_aux_0", + to=TensorProto.INT64, + ), + helper.make_node( + "ReduceSum", + inputs=["mask_int64", "one"], + outputs=["mask_row_sums"], + name="ReduceSum_gqa_aux", + ), + helper.make_node( + "Sub", + inputs=["mask_row_sums", "one"], + outputs=["seqlens_k_int64"], + name="Sub_gqa_aux", + ), + helper.make_node( + "Cast", + inputs=["seqlens_k_int64"], + outputs=["seqlens_k"], + name="Cast_gqa_aux_1", + to=TensorProto.INT32, + ), + helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"), + helper.make_node( + "Gather", + inputs=["mask_shape", "one"], + outputs=["total_seq_len_int64"], + name="Gather_gqa_aux_0", + axis=0, + ), + helper.make_node( + "Cast", + inputs=["total_seq_len_int64"], + outputs=["total_sequence_length"], + name="Cast_gqa_aux_2", + to=TensorProto.INT32, + ), + ] + return gqa_aux_nodes + + def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name): + q_weight = self.model.get_initializer(q_w) + k_weight = self.model.get_initializer(k_w) + v_weight = self.model.get_initializer(v_w) + qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0)) + kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0)) + vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0)) + qkv_weight = np.stack((qw, kw, vw), axis=1) + + q_bias = self.model.get_initializer(q_b) + k_bias = self.model.get_initializer(k_b) + v_bias = self.model.get_initializer(v_b) + qb = NumpyHelper.to_array(q_bias) + kb = NumpyHelper.to_array(k_bias) + vb = NumpyHelper.to_array(v_bias) + qkv_bias = np.stack((qb, kb, vb), axis=0) + + hidden_size = qkv_weight.shape[0] + + weight = helper.make_tensor( + weight_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size, hidden_size * 3], + vals=qkv_weight.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(weight, self.this_graph_name) + + bias = helper.make_tensor( + bias_name, + data_type=TensorProto.FLOAT, + dims=[hidden_size * 3], + vals=qkv_bias.flatten().tobytes(), + raw=True, + ) + self.model.add_initializer(bias, self.this_graph_name) + + self.add_fp32_value_info(weight.name) + self.add_fp32_value_info(bias.name) + + return weight_name, bias_name + + def fuse( + self, + node, + input_name_to_nodes, + output_name_to_node, + ): + logger.info("Optimizing %s...", node.name) + + logger.info(f"AttentionOpType: {self.attn_op_type}") + + layer_id = self.get_layer_id(node) + + i_hidden_states = node.input[0] + i_key_cache = self.get_io_by_name(node, "past_key") + i_value_cache = self.get_io_by_name(node, "past_value") + + o_hidden_states = node.output[3] + o_key_cache = self.get_io_by_name(node, "present_key") + o_value_cache = self.get_io_by_name(node, "present_value") + + ln_weight = self.get_io_by_name(node, "input_layernorm.weight") + ln_bias = self.get_io_by_name(node, "input_layernorm.bias") + + attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( + None, + None, + None, + None, + None, + None, + ) + attn_qkv_weight, attn_qkv_bias = None, None + cos_cache, sin_cache = None, None + + if self.attn_op_type != AttentionOpType.Attention: + attn_q_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + ) + 
attn_k_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + ) + attn_v_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + ) + attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias") + + cos_cache = self.process_initializer( + self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + ) + sin_cache = self.process_initializer( + self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + ) + else: + attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm( + self.get_io_by_name(node, "self_attn.q_proj.weight"), + self.get_io_by_name(node, "self_attn.k_proj.weight"), + self.get_io_by_name(node, "self_attn.v_proj.weight"), + self.get_io_by_name(node, "self_attn.q_proj.bias"), + self.get_io_by_name(node, "self_attn.k_proj.bias"), + self.get_io_by_name(node, "self_attn.v_proj.bias"), + self.get_uname(layer_id, "attn_qkv_weight"), + self.get_uname(layer_id, "attn_qkv_bias"), + ) + + attn_out_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() + ) + attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias") + + mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) + mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) + mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias") + mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias") + + layer_known_edges_names = [] + layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) + layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache]) + layer_known_edges_names.extend([ln_weight, ln_bias]) + if self.attn_op_type != AttentionOpType.Attention: + layer_known_edges_names.extend( + [ + attn_q_weight, + attn_q_bias, + attn_k_weight, + attn_k_bias, + attn_v_weight, + attn_v_bias, + cos_cache, + sin_cache, + ] + ) + else: + layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias]) + layer_known_edges_names.extend( + [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias] + ) + layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"]) + + subgraph_nodes = [] + subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"])) + subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_")) + subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_")) + subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"])) + subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_")) + subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1")) + subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2")) + if self.attn_op_type != AttentionOpType.Attention: + subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + subgraph_nodes.extend(self.rotary(["query", 
"step", cos_cache, sin_cache], ["query_rot"], "Q_")) + subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_")) + if self.attn_op_type == AttentionOpType.MultiHeadAttention: + subgraph_nodes.extend( + self.mha( + ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + elif self.attn_op_type == AttentionOpType.GroupQueryAttention: + subgraph_nodes.extend( + self.gqa( + [ + "query_rot", + "key_rot", + "value", + i_key_cache, + i_value_cache, + "seqlens_k", + "total_sequence_length", + ], + ["attn_out", o_key_cache, o_value_cache], + ) + ) + if layer_id == 0: + gqa_aux_nodes = self.get_gqa_aux_nodes() + for new_node in gqa_aux_nodes: + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.model.add_initializer( + numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name + ) + else: + past_name = f"past_{layer_id}" + present_name = f"present_{layer_id}" + layer_known_edges_names.extend([past_name, present_name]) + subgraph_nodes.extend( + self.attention( + ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name] + ) + ) + + self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names) + + self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"]) + + self.nodes_to_remove.append(node) + self.prune_graph = True + + +class PhiOnnxModel(OnnxModel): + def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): + super().__init__(model) + self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size) + self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads) + self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self) + self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self) + self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self) + + def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): + assert options is not None + attn_op_type = options.attention_op_type + + self.fission_transformer_block.set_attention_op_type(attn_op_type) + + self.phi2_preprocessor.preprocess_onnx(attn_op_type) + + self.fission_transformer_block.apply() + self.fission_transformer_layernorm.apply() + self.fission_causal_lm_head.apply() + self.fission_transformer_embedding.apply() + + super().prune_graph() + + # SLN ctor is placed here intentionally to delay the symbolic shape inference + self.fuse_sln = FusionSkipLayerNormalization(self) + self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self) + self.fuse_sln.apply() + self.fuse_bias_sln.apply() + + def get_fused_operator_statistics(self): + """ + Returns node count of fused operators. + """ + op_count = {} + ops = [ + "Attention", + "MultiHeadAttention", + "GroupQueryAttention", + "Gelu", + "BiasGelu", + "FastGelu", + "LayerNormalization", + "SkipLayerNormalization", + ] + for op in ops: + nodes = self.get_nodes_by_op_type(op) + op_count[op] = len(nodes) + + logger.info(f"Optimized operators: {op_count}") + return op_count + + def is_fully_optimized(self, fused_op_count=None): + """ + Returns True when the model is fully optimized. 
+ """ + if fused_op_count is None: + fused_op_count = self.get_fused_operator_statistics() + + def op_count(op_name: str): + return fused_op_count.get(op_name) or 0 + + attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention") + gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") + layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") + + is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention) + + if layer_norm == 0: + logger.debug("Layer Normalization not fused") + + if gelu == 0: + logger.debug("Gelu (or FastGelu) not fused") + + if attention == 0: + logger.warning("Attention (or MultiHeadAttention) not fused") + + return is_perfect diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index ba61f4f6e43ba..ce0be6b3449ed 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -34,6 +34,7 @@ from onnx_model_clip import ClipOnnxModel from onnx_model_conformer import ConformerOnnxModel from onnx_model_gpt2 import Gpt2OnnxModel +from onnx_model_phi import PhiOnnxModel from onnx_model_t5 import T5OnnxModel from onnx_model_tnlr import TnlrOnnxModel from onnx_model_unet import UnetOnnxModel @@ -58,6 +59,7 @@ "vae": (VaeOnnxModel, "pytorch", 1), # UAE in Stable Diffusion "vit": (BertOnnxModel, "pytorch", 1), "conformer": (ConformerOnnxModel, "pytorch", 1), + "phi": (PhiOnnxModel, "pytorch", 0), } diff --git a/setup.py b/setup.py index 67d34b065ad03..03e1cb75ba581 100644 --- a/setup.py +++ b/setup.py @@ -419,6 +419,7 @@ def finalize_options(self): "onnxruntime.transformers.models.gpt2", "onnxruntime.transformers.models.llama", "onnxruntime.transformers.models.longformer", + "onnxruntime.transformers.models.phi2", "onnxruntime.transformers.models.t5", "onnxruntime.transformers.models.stable_diffusion", "onnxruntime.transformers.models.whisper", From d120104dcd1e0942ca0856a096c55be05483d1d6 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 5 Feb 2024 13:11:37 -0800 Subject: [PATCH 039/207] add ATen support for bicubic interpolation (#19380) ### Description Add ATen fallback support for bicubic interpolation algorithm. ### Motivation and Context Required for facebook/dinov2 model architecture as part of ONNX Runtime integration with AML Vision models. 
--- .../python/tools/symbolic_shape_infer.py | 1 + .../ortmodule/_custom_gradient_registry.py | 7 ++++- .../ortmodule/_custom_op_symbolic_registry.py | 13 +++++++++ .../python/orttraining_test_ortmodule_api.py | 28 +++++++++++++++++++ 4 files changed, 48 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 31f3a3a2b30d6..e7b7074783162 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -241,6 +241,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, + "upsample_bicubic2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 77317242727b4..4883075112dcb 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -241,7 +241,7 @@ def native_group_norm_gradient(): # are available for all versions, though they are not that convienent to use. def _upsample_gradient(backward_fn, dims): scales = ["" for _ in range(dims)] - if "bilinear" in backward_fn: + if "bicubic" in backward_fn: scales = ["I(2)", *scales] return [ ("Shape", ["I(0)"], ["Shape_X"]), @@ -271,3 +271,8 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) + + +@register_gradient("org.pytorch.aten", "ATen", "upsample_bicubic2d", "vec") +def upsample_bicubic2d_gradient(): + return _upsample_gradient("upsample_bicubic2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 99e8851b6a697..9288027f0188c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,3 +808,16 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") + + +@register_symbolic("upsample_bicubic2d") +def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): + return g.op( + "org.pytorch.aten::ATen", + input, + output_size, + align_corners, + scale_factors, + operator_s="upsample_bicubic2d", + overload_name_s="vec", + ) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 938d33cc9a714..6a6832e06330a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1805,6 +1805,34 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) +def test_aten_upsample_bicubic(): + class _NeuralNetUpsampleBicubic(torch.nn.Module): + def 
__init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.interpolate(input, size=(8, 12), mode="bicubic") + + device = "cuda" + pt_model = _NeuralNetUpsampleBicubic().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = torch.randn([2, 4, 6, 8], dtype=torch.float, device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + + def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D): From d2d9b5b5f9639a419412fc68174441330521af91 Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:53:39 -0800 Subject: [PATCH 040/207] fix output shape inference packed gqa (#19374) ### Description fix output shape inference packed gqa --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 8583474a1e391..8bf013ed009d5 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,6 +259,16 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); + } else { + ONNX_NAMESPACE::TensorShapeProto output_shape; + int64_t num_heads = getAttribute(ctx, "num_heads", 0); + int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0); + int64_t hidden_size = query_dims[2].dim_value(); + int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads); + *output_shape.add_dim() = query_dims[0]; + *output_shape.add_dim() = query_dims[1]; + output_shape.add_dim()->set_dim_value(head_size * num_heads); + updateOutputShape(ctx, 0, output_shape); } } From 06a84c8a0d738cf6c45a173e7283aa7dd1b6c2ba Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Mon, 5 Feb 2024 17:33:58 -0500 Subject: [PATCH 041/207] Enable DML on Windows and CUDA on Linux for Node.js binding (#19274) This pull request includes modifications to the `c-api-cpu.yml` Azure Pipelines configuration file. The changes mainly revolve around the Node.js packaging stage and the handling of Node.js artifacts. The most significant changes include renaming the Node.js packaging stage, adding a new dependency to the stage, changing artifact names, adding a new script to list Node.js artifacts, and updating the source folder for copying NuGet binaries. Changes in Node.js packaging: * [`tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml`](diffhunk://#diff-00815920cc190d10fdebceac0c3a4b8a59e408684ae38177dfe7f96cae276c59L503-R508): Renamed the Node.js packaging stage from `Nodejs_Packaging_CPU` to `Nodejs_Packaging` and added `Windows_CI_GPU_DML_Dev` as a new dependency to the stage. 
Changes in handling of Node.js artifacts: * [`tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml`](diffhunk://#diff-00815920cc190d10fdebceac0c3a4b8a59e408684ae38177dfe7f96cae276c59L568-R569): Changed the artifact name from `drop-onnxruntime-nodejs-win-x64` to `drop-onnxruntime-nodejs-win-x64-dml` in the task to download pipeline artifacts for Windows x64. * [`tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml`](diffhunk://#diff-00815920cc190d10fdebceac0c3a4b8a59e408684ae38177dfe7f96cae276c59R595-R598): Added a new script to list Node.js artifacts from the directory `$(Build.BinariesDirectory)/nodejs-artifacts/win32/x64/`. * [`tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml`](diffhunk://#diff-00815920cc190d10fdebceac0c3a4b8a59e408684ae38177dfe7f96cae276c59L635-R640): Updated the source folder from `$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64\lib` to `$(Build.BinariesDirectory)\nodejs-artifacts\win32\x64` in the task to copy NuGet binaries to the directory `$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64`. --------- Co-authored-by: Yulong Wang <7679871+fs-eire@users.noreply.github.com> --- .../nodejs/templates/test_linux.yml | 3 +- .../nodejs/templates/test_macos.yml | 2 +- .../nodejs/templates/test_win.yml | 2 +- .../azure-pipelines/templates/c-api-cpu.yml | 57 +++++++------------ 4 files changed, 23 insertions(+), 41 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 864d1002a90fc..7b03c0e82f4bb 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -4,7 +4,7 @@ parameters: stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: @@ -18,4 +18,3 @@ stages: value: '$(Build.BinariesDirectory)' steps: - template: test.yml - diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 871d7894e5315..dc52e9a22f05b 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -3,7 +3,7 @@ parameters: stages: - stage: Nodejs_Test_MacOS_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml index c823ac788f925..9b3c61b2d3d85 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml @@ -4,7 +4,7 @@ parameters: stages: - stage: Nodejs_Test_${{ parameters.StageSuffix }} dependsOn: - - Nodejs_Packaging_CPU + - Nodejs_Packaging condition: succeeded() jobs: - job: diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 8bdb395c00dc3..3bcac799d7cf9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -501,12 +501,13 @@ stages: displayName: 'Clean Agent Directories' condition: always() -- stage: 
Nodejs_Packaging_CPU +- stage: Nodejs_Packaging dependsOn: + - Windows_CI_GPU_DML_Dev + - Windows_CI_GPU_DML_Dev_arm64 - Linux_C_API_Packaging_CPU + - Linux_C_API_Packaging_GPU_TensorRT_x64 - MacOS_C_API_Package_Publish - - Windows_Packaging_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_CPU_arm64_${{ parameters.BuildVariant }} condition: succeeded() jobs: - job: @@ -533,18 +534,6 @@ stages: workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Testing: force EOL to lf on windows for /js/**' - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Win x64)' - inputs: - artifactName: 'onnxruntime-win-x64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - NuGet (Win ARM64)' - inputs: - artifactName: 'onnxruntime-win-arm64' - targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (OSX)' inputs: @@ -554,7 +543,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (Linux x64)' inputs: - artifactName: 'onnxruntime-linux-x64' + artifactName: 'onnxruntime-linux-x64-tensorrt' targetPath: '$(Build.BinariesDirectory)/nuget-artifact' - task: DownloadPipelineArtifact@0 @@ -566,13 +555,13 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-x64' + artifactName: 'drop-onnxruntime-nodejs-win-x64-dml' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/x64/' - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Win ARM64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-win-arm64' + artifactName: 'drop-onnxruntime-nodejs-win-arm64-dml' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/win32/arm64/' - task: DownloadPipelineArtifact@0 @@ -590,7 +579,7 @@ stages: - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - Nodejs (Linux x64)' inputs: - artifactName: 'drop-onnxruntime-nodejs-linux-x64' + artifactName: 'drop-onnxruntime-nodejs-linux-x64-tensorrt' targetPath: '$(Build.BinariesDirectory)/nodejs-artifacts/linux/x64/' - task: DownloadPipelineArtifact@0 @@ -631,38 +620,32 @@ stages: # Node.js binding win32/x64 - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64\lib' - Contents: '*.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64' - - task: CopyFiles@2 - displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' + displayName: 'Copy binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\win32\x64' - Contents: '*.node' + Contents: | + *.dll + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\x64' # Node.js binding win32/arm64 - task: CopyFiles@2 - displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-arm64\lib' - Contents: '*.dll' - TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' - - task: CopyFiles@2 - displayName: 'Copy 
nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' + displayName: 'Copy binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64\' inputs: SourceFolder: '$(Build.BinariesDirectory)\nodejs-artifacts\win32\arm64' - Contents: '*.node' + Contents: | + *.dll + *.node TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\arm64' # Node.js binding linux/x64 - task: CopyFiles@2 displayName: 'Copy nuget binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64\lib' - Contents: 'libonnxruntime.so.*' + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-linux-x64-tensorrt\lib' + Contents: | + libonnxruntime.so.* + libonnxruntime_providers_*.so TargetFolder: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64' - task: CopyFiles@2 displayName: 'Copy nodejs binaries to: $(Build.SourcesDirectory)\js\node\bin\napi-v3\linux\x64\' From d2d9b5b5f9639a419412fc68174441330521af91 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Mon, 5 Feb 2024 14:35:57 -0800 Subject: [PATCH 042/207] [CUDA][ROCm][Training] Fix cuda/rocm provider info hash (#19398) When I tested a new provider option, the training pipeline failed. I found that training uses the hash code of the provider info to try to get a cached provider instance. If a provider option is not used in hashing, the provider instance fetched from the cache might have a different configuration for that option. Here I fix the hashing to use all provider options (except the default Arena config, which cannot be set from the Python API since training is used with PyTorch in most cases). Fixed a few obvious typos in the touched files. Added regression test cases. 
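To make the caching issue concrete, here is a small illustration in Python (illustrative only; the class, cache, and hash functions are hypothetical stand-ins, not ONNX Runtime APIs — the actual fix is the C++ `std::hash` specializations in this diff). When an option such as `prefer_nhwc` is left out of the hash, two different configurations produce the same cache key and the second lookup silently reuses a provider instance created with the other configuration:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ProviderInfo:               # hypothetical stand-in for the provider info struct
    device_id: int
    prefer_nhwc: bool
    gpu_mem_limit: int

def old_hash(info):               # prefer_nhwc not included in the hash
    return hash((info.device_id, info.gpu_mem_limit))

def new_hash(info):               # every option participates in the hash
    return hash((info.device_id, info.prefer_nhwc, info.gpu_mem_limit))

cache = {}

def get_provider(info, hash_fn):  # stand-in for the training-side provider cache
    key = hash_fn(info)
    if key not in cache:
        cache[key] = info         # pretend this creates an execution provider
    return cache[key]

a = ProviderInfo(device_id=0, prefer_nhwc=False, gpu_mem_limit=2**30)
b = ProviderInfo(device_id=0, prefer_nhwc=True, gpu_mem_limit=2**30)

# Old behaviour: b gets a's instance even though prefer_nhwc differs.
assert get_provider(a, old_hash) is get_provider(b, old_hash)

# Fixed behaviour: distinct configurations map to distinct instances.
cache.clear()
assert get_provider(a, new_hash) is not get_provider(b, new_hash)
```

The actual C++ change packs the boolean options into bit positions of a single value and then folds in the numeric fields and the raw stream/allocator pointers with `HashCombine`; only the default memory arena config is intentionally left out of the hash.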
--- .../cuda/cuda_execution_provider_info.cc | 17 +++++---- .../cuda/cuda_execution_provider_info.h | 38 +++++++++++++++---- .../rocm/rocm_execution_provider_info.h | 34 +++++++++++++---- .../test/python/onnxruntime_test_python.py | 4 ++ .../python/orttraining_python_module.cc | 20 ++-------- 5 files changed, 75 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 7b507296d5982..81ddc38820914 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -31,8 +31,9 @@ constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; -constexpr const char* kPreferNCHWMode = "prefer_nhwc"; -constexpr const char* KUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kPreferNHWCMode = "prefer_nhwc"; +constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; + } // namespace provider_option_names } // namespace cuda @@ -112,8 +113,8 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableCudaGraph, info.enable_cuda_graph) .AddAssignmentToReference(cuda::provider_option_names::kCudnnConv1dPadToNc1d, info.cudnn_conv1d_pad_to_nc1d) .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) - .AddAssignmentToReference(cuda::provider_option_names::kPreferNCHWMode, info.prefer_nhwc) - .AddAssignmentToReference(cuda::provider_option_names::KUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) + .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -164,8 +165,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, - {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, - {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; @@ -185,8 +186,8 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op_enable)}, {cuda::provider_option_names::kTunableOpTuningEnable, 
MakeStringWithClassicLocale(info.tunable_op_tuning_enable)}, {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, - {cuda::provider_option_names::kPreferNCHWMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, - {cuda::provider_option_names::KUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, + {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index b286f5a9161b0..04eea2f6c8e94 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -83,12 +83,36 @@ struct CUDAExecutionProviderInfo { } // namespace onnxruntime template <> -struct std::hash<::onnxruntime::cuda::TunableOpInfo> { - size_t operator()(const ::onnxruntime::cuda::TunableOpInfo& info) const { - size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enable, seed_and_value); - onnxruntime::HashCombine(info.tuning_enable, seed_and_value); - onnxruntime::HashCombine(info.max_tuning_duration_ms, seed_and_value); - return seed_and_value; +struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { + size_t operator()(const ::onnxruntime::CUDAExecutionProviderInfo& info) const { + size_t value{0xbc9f1d34}; // seed + + // Bits: device_id (16), arena_extend_strategy/cudnn_conv_algo_search (reserved 2), boolean options (1 each) + size_t data = static_cast(info.device_id) ^ + (static_cast(info.arena_extend_strategy) << 16) ^ + (static_cast(info.cudnn_conv_algo_search) << 18) ^ + (static_cast(info.do_copy_in_default_stream) << 20) ^ + (static_cast(info.has_user_compute_stream) << 21) ^ + (static_cast(info.cudnn_conv_use_max_workspace) << 22) ^ + (static_cast(info.enable_cuda_graph) << 23) ^ + (static_cast(info.tunable_op.enable) << 24) ^ + (static_cast(info.tunable_op.tuning_enable) << 25) ^ + (static_cast(info.cudnn_conv1d_pad_to_nc1d) << 26) ^ + (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ + (static_cast(info.prefer_nhwc) << 28) ^ + (static_cast(info.use_ep_level_unified_stream) << 29); + onnxruntime::HashCombine(data, value); + + onnxruntime::HashCombine(info.gpu_mem_limit, value); + onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + + // Memory pointers + onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + + // The default memory arena cfg is not used in hashing right now. 
+ return value; } }; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index 2f549cc1ac143..c245b18057ca7 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -74,12 +74,32 @@ struct ROCMExecutionProviderInfo { } // namespace onnxruntime template <> -struct std::hash<::onnxruntime::rocm::TunableOpInfo> { - size_t operator()(const ::onnxruntime::rocm::TunableOpInfo& info) const { - size_t seed_and_value{0xbc9f1d34}; - onnxruntime::HashCombine(info.enable, seed_and_value); - onnxruntime::HashCombine(info.tuning_enable, seed_and_value); - onnxruntime::HashCombine(info.max_tuning_duration_ms, seed_and_value); - return seed_and_value; +struct std::hash<::onnxruntime::ROCMExecutionProviderInfo> { + size_t operator()(const ::onnxruntime::ROCMExecutionProviderInfo& info) const { + size_t value{0xbc9f1d34}; // seed + + // Bits: device_id (16), arena_extend_strategy/miopen_conv_exhaustive_search (reserved 2), boolean options (1 each) + size_t data = static_cast(info.device_id) ^ + (static_cast(info.arena_extend_strategy) << 16) ^ + (static_cast(info.miopen_conv_exhaustive_search) << 18) ^ + (static_cast(info.do_copy_in_default_stream) << 20) ^ + (static_cast(info.has_user_compute_stream) << 21) ^ + (static_cast(info.miopen_conv_use_max_workspace) << 22) ^ + (static_cast(info.enable_hip_graph) << 23) ^ + (static_cast(info.tunable_op.enable) << 24) ^ + (static_cast(info.tunable_op.tuning_enable) << 25); + onnxruntime::HashCombine(data, value); + + onnxruntime::HashCombine(info.gpu_mem_limit, value); + onnxruntime::HashCombine(info.tunable_op.max_tuning_duration_ms, value); + + // Memory pointers + onnxruntime::HashCombine(reinterpret_cast(info.user_compute_stream), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.alloc), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.free), value); + onnxruntime::HashCombine(reinterpret_cast(info.external_allocator_info.empty_cache), value); + + // The default memory arena cfg is not used in hashing right now. 
+ return value; } }; diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 68e441c87860e..5b41806b646af 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -414,6 +414,8 @@ def test_get_and_set_option_with_values(option_name, option_values): str(option_value), ) + test_get_and_set_option_with_values("enable_cuda_graph", ["1", "0"]) + test_get_and_set_option_with_values("arena_extend_strategy", ["kNextPowerOfTwo", "kSameAsRequested"]) test_get_and_set_option_with_values("cudnn_conv_algo_search", ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"]) @@ -553,6 +555,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) + test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) + run_rocm_options_test() def test_invalid_set_providers(self): diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 55cd2af2d0219..b0d1ed50af126 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -47,7 +47,7 @@ void addObjectMethodsForLazyTensor(py::module& m); #endif bool InitArray(); -bool GetDyanmicExecutionProviderHash( +bool GetDynamicExecutionProviderHash( const std::string& ep_shared_lib_path, const ProviderOptions& provider_options, size_t& hash, @@ -87,13 +87,7 @@ bool GetProviderInstanceHash(const std::string& type, if (auto* cuda_provider_info = TryGetProviderInfo_CUDA()) { const CUDAExecutionProviderInfo info = GetCudaExecutionProviderInfo(cuda_provider_info, provider_options_map); - hash = static_cast(info.device_id) ^ - info.gpu_mem_limit ^ - (static_cast(info.arena_extend_strategy) << 16) ^ - (static_cast(info.cudnn_conv_algo_search) << 18) ^ - (static_cast(info.do_copy_in_default_stream) << 20) ^ - (static_cast(info.has_user_compute_stream) << 22) ^ - std::hash{}(info.tunable_op); + hash = std::hash{}(info); return true; } #endif @@ -102,13 +96,7 @@ bool GetProviderInstanceHash(const std::string& type, if (auto* rocm_provider_info = TryGetProviderInfo_ROCM()) { const ROCMExecutionProviderInfo info = GetRocmExecutionProviderInfo(rocm_provider_info, provider_options_map); - hash = static_cast(info.device_id) ^ - info.gpu_mem_limit ^ - (static_cast(info.arena_extend_strategy) << 16) ^ - (static_cast(info.miopen_conv_exhaustive_search) << 18) ^ - (static_cast(info.do_copy_in_default_stream) << 20) ^ - (static_cast(info.has_user_compute_stream) << 22) ^ - std::hash{}(info.tunable_op); + hash = std::hash{}(info); return true; } #endif @@ -128,7 +116,7 @@ bool GetProviderInstanceHash(const std::string& type, provider_options.insert(option); } } - return GetDyanmicExecutionProviderHash(shared_lib_path_it->second, provider_options, hash); + return GetDynamicExecutionProviderHash(shared_lib_path_it->second, provider_options, hash); } } } From a4cfdc1c28ac95ec6fd0667e856b6a6b8dd1020c Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 5 Feb 2024 22:58:35 -0800 Subject: [PATCH 043/207] update comments for nodejs binding artifact preparation. 
(#19425) ### Description document update as a following-up for #19274 --- .../azure-pipelines/templates/c-api-cpu.yml | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 3bcac799d7cf9..1ba0b02560aca 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -534,6 +534,50 @@ stages: workingDirectory: '$(Build.SourcesDirectory)' displayName: 'Testing: force EOL to lf on windows for /js/**' + ################################################################## + # Node.js binding artifacts preparation + # + # This stage prepares Node.js binding artifacts for publishing. The artifacts support the following platforms: + # - Windows x64 with DML support + # - Windows arm64 with DML support + # - Linux x64 with TensorRT support + # - Linux arm64 (CPU only) + # - macOS x64 (CPU only) + # - macOS arm64 (CPU only) + # + # ORT Node.js binding artifacts contain 2 parts: + # 1. ONNX Runtime native shared libraries and their dependencies + # - Windows (x64, arm64): + # - onnxruntime.dll + # - DirectML.dll + # - Linux (x64, arm64): + # - libonnxruntime.so{.version} + # - libonnxruntime_providers_shared.so + # - libonnxruntime_providers_{provider}.so + # - macOS (x64, arm64): + # - libonnxruntime.dylib + # 2. ONNX Runtime Node.js binding + # - onnxruntime_binding.node + # + # For windows platform, the artifact is named as 'onnxruntime-nodejs-win-x64-dml' for x64, and + # 'onnxruntime-nodejs-win-arm64-dml' for arm64. Each artifact contains both (1) and (2). + # + # For Linux and macOS platforms, (1) and (2) are packed into separate artifacts. + # The following artifacts contain (1): + # - onnxruntime-osx + # - onnxruntime-linux-x64-tensorrt + # - onnxruntime-linux-aarch64 + # The following artifacts contain (2): + # - drop-onnxruntime-nodejs-linux-x64-tensorrt + # - drop-onnxruntime-nodejs-linux-aarch64 + # - drop-onnxruntime-nodejs-osx-x86_64 + # - drop-onnxruntime-nodejs-osx-arm64 + # + # All binary artifacts will eventually be put into folder before packaging 'onnxruntime-node': + # $(Build.SourcesDirectory)\js\node\bin\napi-v3\{os}\{cpu_arch}\ + # + # {os} is one of 'win32', 'darwin', 'linux' and {cpu_arch} is one of 'x64', 'arm64'. + - task: DownloadPipelineArtifact@0 displayName: 'Download Pipeline Artifact - NuGet (OSX)' inputs: From 5ff27ef02a1b8d8668c6a9f4da2b7a578f4d9f05 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 6 Feb 2024 09:07:31 -0800 Subject: [PATCH 044/207] [js/webgpu] support customop FastGelu (#19392) ### Description Support WebGPU custom operator FastGelu. 
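The kernel evaluates the tanh approximation of GELU, `y = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`, with an optional bias added to `x` first (indexed cyclically when the bias is shorter than the input). A NumPy reference sketch of the same math is shown below; it is handy for checking the expected values in `fast-gelu.jsonc` (e.g. FastGelu(1.0) ≈ 0.841192), and the helper name is illustrative, not part of this change.

```python
# Reference sketch of the tanh-approximated GELU computed by the WebGPU kernel.
import numpy as np

def fast_gelu_reference(x, bias=None):
    x = np.asarray(x, dtype=np.float32)
    if bias is not None:
        # FastGelu applies the bias cyclically over the flattened input.
        bias = np.asarray(bias, dtype=np.float32)
        flat = x.reshape(-1)
        x = (flat + bias[np.arange(flat.size) % bias.size]).reshape(x.shape)
    c = np.float32(np.sqrt(2.0 / np.pi))  # 0.7978845608..., and 0.044715 * c = 0.035677...
    return 0.5 * x * (1.0 + np.tanh(c * (x + np.float32(0.044715) * x**3)))

print(fast_gelu_reference([0.1, 0.5, 1.0]))  # ~[0.05398, 0.34571, 0.84119], matching fast-gelu.jsonc
```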
--- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + .../wasm/jsep/webgpu/ops/bias-split-gelu.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts | 69 ++++++ js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 33 ++- js/web/test/data/ops/fast-gelu.jsonc | 211 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + onnxruntime/contrib_ops/js/fast_gelu.cc | 23 ++ onnxruntime/contrib_ops/js/fast_gelu.h | 17 ++ .../contrib_ops/js/js_contrib_kernels.cc | 2 + 10 files changed, 353 insertions(+), 8 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts create mode 100644 js/web/test/data/ops/fast-gelu.jsonc create mode 100644 onnxruntime/contrib_ops/js/fast_gelu.cc create mode 100644 onnxruntime/contrib_ops/js/fast_gelu.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2557971eb4ded..b21af8e715db3 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -41,6 +41,7 @@ Do not modify directly.* | Erf | ai.onnx(9-12,13+) | | | Exp | ai.onnx(6-12,13+) | | | Expand | ai.onnx(8-12,13+) | | +| FastGelu | com.microsoft(1+) | | | Flatten | ai.onnx(1-8,9-10,11-12,13+) | | | Floor | ai.onnx(6-12,13+) | | | FusedConv | com.microsoft(1+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index d737a28654220..ac08c5fb1f7ab 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -13,6 +13,7 @@ import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose' import {cumsum, parseCumSumAttributes} from './ops/cumsum'; import {einsum, parseEinsumAttributes} from './ops/einsum'; import {expand} from './ops/expand'; +import {fastGelu} from './ops/fast-gelu'; import {gather, parseGatherAttributes} from './ops/gather'; import {gatherElements, parseGatherElementsAttributes} from './ops/gather-elements'; import {gemm, parseGemmAttributes} from './ops/gemm'; @@ -72,6 +73,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Erf', [unaryOps.erf]], ['Exp', [unaryOps.exp]], ['Expand', [expand]], + ['FastGelu', [fastGelu]], ['Floor', [unaryOps.floor]], ['FusedConv', [conv, parseConvAttributes]], ['Gather', [gather, parseGatherAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts index a81a7a8f1df5c..089fecd758e30 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -43,7 +43,7 @@ const createBiasSplitGeluProgramInfo = (inputs: readonly TensorView[]): ProgramI ${shaderHelper.declareVariables(input, bias, output)} - ${erfImpl(`vec4<${dataType}>`, dataType)} + ${erfImpl(dataType)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts new file mode 100644 index 0000000000000..f50a6a3f011fe --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/fast-gelu.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, ProgramInfo} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType, UniformsArrayType, WORKGROUP_SIZE} from './common'; +import * as unary from './unary-op'; + +// GELU is defined as Y=0.5*X*(1+tanh(0.797885*X+0.035677*X*X*X)), where X may pre-add a bias. + +const createFastGeluProgramInfo = (inputTensors: readonly TensorView[]): ProgramInfo => { + const dataType = inputTensors[0].dataType; + const outputSize = ShapeUtil.size(inputTensors[0].dims); + const biasLength = ShapeUtil.size(inputTensors[1].dims); + // can only use vec4 when bias length is multiple of 4 + const useVec4 = biasLength % 4 === 0; + const getShaderSource = (shaderHelper: ShaderHelper): string => { + const x = inputVariable('x', dataType, [1], 4); + const bias = inputVariable('bias', dataType, [1], 4); + const y = outputVariable('y', dataType, [1], 4); + + const uniforms: UniformsArrayType = [{name: 'output_vec_size', type: 'u32'}, {name: 'bias_size', type: 'u32'}]; + + const singleElementBias = (i: 0|1|2|3) => ` + let bias${i}_offset: u32 = (global_idx * 4 + ${i}) % uniforms.bias_size; + let bias${i} = ${bias.getByOffset(`bias${i}_offset / 4`)}[bias${i}_offset % 4];`; + const biasGetExpression = useVec4 ? + ` + let bias = ${bias.getByOffset('global_idx % (uniforms.bias_size / 4)')};` : + `${singleElementBias(0)}${singleElementBias(1)}${singleElementBias(2)}${singleElementBias(3)} + let bias = ${x.type.value}(bias0, bias1, bias2, bias3);`; + + return `${shaderHelper.registerUniforms(uniforms).declareVariables(x, bias, y)} + + ${unary.fastGeluImpl(tensorTypeToWsglValueType(dataType))} + + ${shaderHelper.mainStart(WORKGROUP_SIZE)} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_vec_size')} + + let x = ${x.getByOffset('global_idx')}; + ${biasGetExpression} + let x_in = x + bias; + ${y.setByOffset('global_idx', unary.fastGeluExpression('x_in'))} + }`; + }; + + return { + name: 'FastGeluWithBias', + shaderCache: {hint: `${useVec4}`, inputDependencies: ['type', 'type']}, + getShaderSource, + getRunData: (inputs) => ({ + outputs: [{dims: inputs[0].dims, dataType: inputs[0].dataType}], + programUniforms: + [{type: DataType.uint32, data: Math.ceil(outputSize / 4)}, {type: DataType.uint32, data: biasLength}], + dispatchGroup: {x: Math.ceil(outputSize / WORKGROUP_SIZE / 4)} + }) + }; +}; + +export const fastGelu = (context: ComputeContext): void => { + if (context.inputs.length < 2 || ShapeUtil.size(context.inputs[1].dims) === 0) { + unary.fastGelu(context); + } else { + context.compute(createFastGeluProgramInfo(context.inputs)); + } +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 1accfac18b876..5f105c745739e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -178,7 +178,7 @@ export const elu = (context: ComputeContext, attributes: AlphaAttributes): void attributes.cacheKey)); }; -export const erfImpl = (dataType: string, varType = 'f32') => ` +export const erfImpl = (varType = 'f32') => ` const r0: ${varType} = 0.3275911; const r1: ${varType} = 0.254829592; const r2: ${varType} = -0.284496736; @@ -186,7 +186,7 @@ const r3: ${varType} = 1.421413741; const r4: ${varType} = -1.453152027; const r5: ${varType} = 1.061405429; -fn erf_vf32(v: ${dataType}) -> ${dataType} { +fn erf_vf32(v: 
vec4<${varType}>) -> vec4<${varType}> { let absv = abs(v); let x = 1.0 / (1.0 + r0 * absv); return sign(v) * (1.0 - ((((r5 * x + r4) * x + r3) * x + r2) * x + r1) * x * exp(-absv * absv)); @@ -194,8 +194,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} { export const erf = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); - context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType))); + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(dataType))); }; export const exp = (context: ComputeContext): void => { @@ -209,8 +208,7 @@ export const floor = (context: ComputeContext): void => { export const gelu = (context: ComputeContext): void => { const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, - erfImpl(`vec4<${dataType}>`, dataType))); + context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`, erfImpl(dataType))); }; export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => { @@ -278,10 +276,31 @@ export const tan = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tan', 'tan')); }; +export const tanhExpression = (a: string) => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`; + export const tanh = (context: ComputeContext): void => { // TODO: revisit after https://github.com/gpuweb/gpuweb/issues/4458 is resolved + context.compute(createElementwiseProgramInfo(context.inputs[0], 'Tanh', tanhExpression)); +}; + +export const fastGeluImpl = (varType = 'f32') => ` +const fast_gelu_a: ${varType} = 0.5; +const fast_gelu_b: ${varType} = 0.7978845608028654; +const fast_gelu_c: ${varType} = 0.035677408136300125; + +fn tanh_v(v: vec4<${varType}>) -> vec4<${varType}> { + return ${tanhExpression('v')}; +} +`; + +export const fastGeluExpression = (x: string) => + `(fast_gelu_a + fast_gelu_a * tanh_v(${x} * (fast_gelu_c * ${x} * ${x} + fast_gelu_b))) * ${x}`; + +export const fastGelu = (context: ComputeContext): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); context.compute(createElementwiseProgramInfo( - context.inputs[0], 'Tanh', a => `sign(${a}) * (1 - exp(-2 * abs(${a}))) / (1 + exp(-2 * abs(${a})))`)); + context.inputs[0], 'FastGelu', fastGeluExpression, fastGeluImpl(dataType), undefined, + context.inputs[0].dataType)); }; export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => { diff --git a/js/web/test/data/ops/fast-gelu.jsonc b/js/web/test/data/ops/fast-gelu.jsonc new file mode 100644 index 0000000000000..2550173e95402 --- /dev/null +++ b/js/web/test/data/ops/fast-gelu.jsonc @@ -0,0 +1,211 @@ +[ + { + "name": "FastGelu test without bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.841192], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.435415, 
0.53057, 0.630432], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0539828, 0.115851, 0.185371, 0.262161, 0.345714, 0.841192, 1.9546, 2.99636, 3.99993, 5, 0.950581, + 1.0617, 1.17393, 1.28671, 1.39957 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + }, + { + "name": "FastGelu test with bias", + "operator": "FastGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "scalar", + "inputs": [ + { + "data": [1], + "dims": [], + "type": "float32" + }, + { + "data": [0.5], + "dims": [], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1.39957], + "dims": [], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 4.39999, 1.39957, 2.58835, 3.69973, 4.8], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[2x4], [3]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], + "dims": [2, 4], + "type": "float32" + }, + { + "data": [1, 2, 3], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.950581, 2.16968, 3.29869, 1.28671, 2.48492, 3.59959, 1.62411, 2.79331], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [2]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2, 3], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.06267, 3.19813, 2.27567, 3.39909, 2.48492, 3.99993, 3.99993, 6, 6, 8, 3.09737, 4.19997, 3.29869, + 4.39999, 3.49938 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[3x5], [7]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, 1.5], + "dims": [3, 5], + "type": "float32" + }, + { + "data": [2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7], + "dims": [7], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2.16968, 2.38072, 2.58835, 2.79331, 2.99636, 3.59959, 4.7, 5.1, 6.2, 7.3, 3.49938, 3.69973, 3.89989, + 4.09996, 3.59959 + ], + "dims": [3, 5], + "type": "float32" + } + ] + }, + { + "name": "[4x4], [8]", + "inputs": [ + { + "data": [0.8, -0.5, 0.0, 1, 1.3, 2.1, -0.2, 1.1, 0.5, 0.2, 0.3, -0.6, 3.1, 2.2, -1.1, 0.0], + "dims": [4, 4], + "type": "float32" + }, + { + "data": [-0.5, 0.6, 1.2, 2.1, 1.3, -1, 0, 3.1], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.185371, 0.0539828, 1.0617, 3.09737, 2.58835, 0.950581, -0.0841486, 4.19997, 0, 0.630432, 1.39957, + 1.39957, 4.39999, 1.0617, -0.149419, 3.09737 + ], + "dims": [4, 4], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 56db28b0a379c..55b21283025c2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1352,6 +1352,7 @@ "equal.jsonc", "exp.jsonc", "expand.jsonc", + "fast-gelu.jsonc", "floor.jsonc", "gather-elements.jsonc", "gemm.jsonc", diff --git a/onnxruntime/contrib_ops/js/fast_gelu.cc b/onnxruntime/contrib_ops/js/fast_gelu.cc new file mode 100644 index 0000000000000..62c538318160d --- /dev/null +++ 
b/onnxruntime/contrib_ops/js/fast_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "fast_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + FastGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/fast_gelu.h b/onnxruntime/contrib_ops/js/fast_gelu.h new file mode 100644 index 0000000000000..68c7892741c66 --- /dev/null +++ b/onnxruntime/contrib_ops/js/fast_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(FastGelu, FastGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 498a9f5679eb5..bd58dded026a6 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,6 +8,7 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); @@ -24,6 +25,7 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From cfe93a1eab4af4c911239ad13d7aeaa1307958b9 Mon Sep 17 00:00:00 2001 From: Yongxin Wang Date: Wed, 7 Feb 2024 02:58:46 +0800 Subject: [PATCH 045/207] fix type annotation for sess_options (#19393) ### Description Changed the type annotation of sess_options in InferenceSession's `__init__` method ### Motivation and Context sess_options is one `SessionOptions`, not a sequence of it. It is passed directly into `C.InferenceSession`, and from the definition of [`C.InferenceSession`](https://github.com/microsoft/onnxruntime/blob/efc17e79de8c1a62eb419d19576ccb90b371b0d0/onnxruntime/python/onnxruntime_pybind_state.cc#L1790), we can see that it is not a sequence: ```cpp py::class_(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc") // In Python3, a Python bytes object will be passed to C++ functions that accept std::string or char* // without any conversion. 
So this init method can be used for model file path (string) and model content (bytes) .def(py::init([](const PySessionOptions& so, const std::string arg, bool is_arg_file_name, bool load_config_from_model = false) { ``` --- onnxruntime/python/onnxruntime_inference_collection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 09f768f53ea65..4106943e8facc 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -358,7 +358,7 @@ class InferenceSession(Session): def __init__( self, path_or_bytes: str | bytes | os.PathLike, - sess_options: Sequence[onnxruntime.SessionOptions] | None = None, + sess_options: onnxruntime.SessionOptions | None = None, providers: Sequence[str | tuple[str, dict[Any, Any]]] | None = None, provider_options: Sequence[dict[Any, Any]] | None = None, **kwargs, From 61b0e04b03ef2ce8b56b5a3a62ea1ae58e52ac73 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:02:24 -0800 Subject: [PATCH 046/207] Bump github/issue-labeler from 3.3 to 3.4 (#19410) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [github/issue-labeler](https://github.com/github/issue-labeler) from 3.3 to 3.4.
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/labeler.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index ce8fb3160954e..936ab0de899a2 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -7,7 +7,7 @@ jobs: triage: runs-on: ubuntu-latest steps: - - uses: github/issue-labeler@v3.3 + - uses: github/issue-labeler@v3.4 with: repo-token: "${{ secrets.GITHUB_TOKEN }}" configuration-path: .github/labeler.yml From c4b49fb7bf340a0d27c7d8e2cb2508cac7f57ccf Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 6 Feb 2024 12:48:39 -0800 Subject: [PATCH 047/207] [CUDA] remove CUBLAS_TENSOR_OP_MATH mode (#19431) This pull request replaces `CUBLAS_TENSOR_OP_MATH` with `CUBLAS_DEFAULT_MATH`. The changes affect several files, including test cases and a Python script for AMD hipify process. ### Motivation and Context CUBLAS_TENSOR_OP_MATH mode is deprecated: https://docs.nvidia.com/cuda/cublas/index.html#cublasmath-t On CUDA versions prior to 11, users are required to set the math mode to CUBLAS_TENSOR_OP_MATH manually to be able to use tensor cores for FP16. On CUDA 11 and CUDA 12, this is no longer required. Since latest ORT only supports CUDA >= 11 so it is safe to remove CUBLAS_TENSOR_OP_MATH from our code base. --- .../cuda/bert/longformer_attention_impl.cu | 1 - onnxruntime/core/providers/cuda/cuda_common.h | 3 +-- .../providers/cuda/test_cases/gemm_options_test.cc | 12 ++++++------ tools/ci_build/amd_hipify.py | 1 - 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu index f00239460071b..c9c66b73b3e9d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention_impl.cu @@ -1005,7 +1005,6 @@ Status LaunchLongformerAttentionKernel( bool disable_compact_memory, bool use_merged_qkv_weights, bool use_half4) { - CublasMathModeSetter helper(device_prop, cublas, CUBLAS_TENSOR_OP_MATH); size_t softmax_workspace_size = GetLongformerSoftmaxWorkspaceSize(element_size, batch_size, num_heads, diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index e9941ce743bc3..41c999bacee13 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -141,8 +141,7 @@ class HalfGemmOptions { } #else cublasMath_t GetMathMode() const { - // CublasMathModeSetter will check whether device has tensor cores later. 
- return CUBLAS_TENSOR_OP_MATH; + return CUBLAS_DEFAULT_MATH; } cudaDataType GetComputeType() const { diff --git a/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc b/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc index 6cac23f14459e..4917701e5197d 100644 --- a/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/gemm_options_test.cc @@ -17,7 +17,7 @@ TEST(CudaGemmOptions, TestDefaultOptions) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_32F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_32F); #endif } @@ -30,7 +30,7 @@ TEST(CudaGemmOptions, TestCompute16F) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_16F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_16F); #endif } @@ -43,7 +43,7 @@ TEST(CudaGemmOptions, NoReducedPrecision) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_MATH_DISALLOW_REDUCED_PRECISION_REDUCTION); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_32F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_32F); #endif } @@ -56,7 +56,7 @@ TEST(CudaGemmOptions, Pedantic) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_PEDANTIC_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_32F_PEDANTIC); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_32F); #endif } @@ -69,7 +69,7 @@ TEST(CudaGemmOptions, Compute16F_Pedantic) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_PEDANTIC_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_16F_PEDANTIC); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_16F); #endif } @@ -82,7 +82,7 @@ TEST(CudaGemmOptions, Compute16F_NoReducedPrecision) { EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUBLAS_COMPUTE_16F); #else - EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_TENSOR_OP_MATH); + EXPECT_EQ(gemm_options.GetMathMode(), CUBLAS_DEFAULT_MATH); EXPECT_EQ(gemm_options.GetComputeType(), CUDA_R_16F); #endif } diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 8ea0481c9b101..e286236ba6447 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -117,7 +117,6 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("HIPBLAS_R_16F", "rocblas_datatype_f16_r") s = s.replace("HIPBLAS_R_32F", "rocblas_datatype_f32_r") s = s.replace("ROCBLAS_GEMM_DEFAULT_TENSOR_OP", "rocblas_gemm_algo_standard") - s = s.replace("ROCBLAS_TENSOR_OP_MATH", "0 /* CUBLAS_TENSOR_OP_MATH is deprecated */") # compatible layer s = s.replace("rocblas_gemm_strided_batched_ex", "_compat_rocblas_gemm_strided_batched_ex") From bedf0eee737f1a8ff21de452887d233308446596 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 6 Feb 2024 13:31:33 -0800 Subject: [PATCH 048/207] 
[CUDA] Add use_tf32 provider option (for FP32 GEMM) (#19357) [TF32](https://blogs.nvidia.com/blog/tensorfloat-32-precision-format/) could help boost performance on GPU of SM >= 80. Sometime, user observes accuracy loss, or need disable TF32 for testing purpose. To disable TF32, it is also possible to set environment variable `NVIDIA_TF32_OVERRIDE = 0`. However, sometime we do not want to use environment variable to avoid impacting other applications, or want to have finer control (like one session using TF32, and another session not). This provider option could help. Here we add a provider option `use_tf32`. When `use_tf32 = 0`, we will disable TF32 for float MatMul/GEMM in cublas. It applies to MatMulNBits, Attention, LongformerAttention, PackedAttention, PackedMultiHeadAttention operators when float GEMM is used internally in the operator. Note that it will not impact other data type, like fp8 gemm could still use TF32 in accumulation. Previously, cublasGemmStridedBatchedHelper does not use TF32 in inference. Here we enabled TF32 by default, so we might observe speed up for FP32 transformers models on SM >= 80. There is another PR that enables the option for cuDNN Conv later. ### Motivation and Context https://github.com/microsoft/onnxruntime/issues/15407 https://github.com/microsoft/onnxruntime/issues/19288 --- .../core/providers/cuda/cuda_context.h | 2 + .../providers/cuda/cuda_provider_options.h | 1 + .../core/providers/cuda/cuda_resource.h | 1 + .../contrib_ops/cpu/bert/attention_common.h | 2 + .../contrib_ops/cuda/bert/attention.cc | 4 +- .../contrib_ops/cuda/bert/attention_impl.cu | 5 +- .../cuda/bert/decoder_attention.cc | 18 ++- .../cuda/bert/decoder_attention_impl.cu | 18 ++- .../cuda/bert/decoder_attention_impl.h | 1 + .../bert/decoder_masked_self_attention.cc | 2 +- .../cuda/bert/longformer_attention.cc | 19 +-- .../cuda/bert/multihead_attention.cc | 2 + .../contrib_ops/cuda/bert/packed_attention.cc | 5 +- .../cuda/bert/packed_attention_impl.cu | 4 +- .../cuda/bert/packed_multihead_attention.cc | 1 + .../bert/packed_multihead_attention_impl.cu | 4 +- .../cuda/bert/relative_attn_bias.cc | 2 +- .../quantization/attention_quantization.cc | 2 + .../cuda/quantization/matmul_bnb4.cc | 3 +- .../cuda/quantization/matmul_nbits.cc | 3 +- .../providers/cuda/cuda_execution_provider.h | 1 + .../cuda/cuda_execution_provider_info.cc | 4 + .../cuda/cuda_execution_provider_info.h | 6 +- onnxruntime/core/providers/cuda/cuda_kernel.h | 4 + .../providers/cuda/cuda_provider_factory.cc | 2 + .../core/providers/cuda/cuda_stream_handle.cc | 3 + .../math/einsum_utils/einsum_auxiliary_ops.cc | 40 ++--- onnxruntime/core/providers/cuda/math/gemm.cc | 6 +- .../core/providers/cuda/math/matmul.cc | 38 +++-- .../providers/cuda/shared_inc/fpgeneric.h | 141 +++++++++++------- onnxruntime/core/providers/rocm/rocm_kernel.h | 4 + .../providers/rocm/shared_inc/fpgeneric.h | 16 +- .../core/session/provider_bridge_ort.cc | 1 + .../kernel_explorer/kernels/cuda/gemm.cu | 14 +- .../contrib_ops/packed_attention_op_test.cc | 3 +- .../test/python/onnxruntime_test_python.py | 2 + 36 files changed, 245 insertions(+), 139 deletions(-) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 1370f5c4c5e10..108173474db46 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -37,6 +37,7 @@ struct CudaContext : public CustomOpContext { bool cudnn_conv1d_pad_to_nc1d = false; bool 
enable_skip_layer_norm_strict_mode = false; bool prefer_nhwc = false; + bool use_tf32 = true; void Init(const OrtKernelContext& kernel_ctx) { cuda_stream = FetchResource(kernel_ctx, CudaResource::cuda_stream_t); @@ -52,6 +53,7 @@ struct CudaContext : public CustomOpContext { cudnn_conv1d_pad_to_nc1d = FetchResource(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t); enable_skip_layer_norm_strict_mode = FetchResource(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t); prefer_nhwc = FetchResource(kernel_ctx, CudaResource::prefer_nhwc_t); + use_tf32 = FetchResource(kernel_ctx, CudaResource::use_tf32_t); } template diff --git a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h index 82bb8ba83be4a..6d53760ab60b5 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_provider_options.h +++ b/include/onnxruntime/core/providers/cuda/cuda_provider_options.h @@ -37,4 +37,5 @@ struct OrtCUDAProviderOptionsV2 { // The strict mode has better accuracy but lower performance. int prefer_nhwc = 0; // make the CUDA EP NHWC preferred int use_ep_level_unified_stream = 0; // flag specifying if ep level stream is used or not + int use_tf32 = 1; // use TF32 }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index c0e6328f27122..1fef077860be3 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -18,4 +18,5 @@ enum CudaResource : int { cudnn_conv1d_pad_to_nc1d_t, enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, + use_tf32_t, }; \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 8afeb874750b4..a34f41d2938c6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -64,6 +64,7 @@ struct AttentionParameters { bool pass_past_in_kv; float mask_filter_value; float scale; + bool use_tf32; AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -82,6 +83,7 @@ struct PackedAttentionParameters { int token_count; bool has_relative_position_bias; bool broadcast_res_pos_bias; + bool use_tf32; }; // Parameters deduced from node attributes and inputs/outputs. 
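For reference, a minimal sketch of how a session could opt out of TF32 once this option is exposed from Python. It assumes the usual provider-options pattern of the Python API; the `use_tf32` key comes from this change, and the model path is only a placeholder.

```python
# Sketch: disable TF32 for FP32 GEMM in a single session, without touching
# the process-wide NVIDIA_TF32_OVERRIDE environment variable.
import onnxruntime as ort

sess = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=[
        ("CUDAExecutionProvider", {"device_id": 0, "use_tf32": "0"}),
        "CPUExecutionProvider",
    ],
)
```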
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index bf6431cf1afb2..7a807342ad685 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -84,6 +84,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + // Use the second dimension from weight for bias to get q_hidden_size when bias is nullptr std::vector bias_dims{weights->Shape().GetDims()[1]}; const TensorShape bias_shape{bias_dims}; @@ -251,7 +253,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool use_fused_cross_attention = false; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 54c9a5da1e9da..c20f42c4d06bc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -461,7 +461,8 @@ Status UnfusedAttention( total_sequence_length, sequence_length, qk_head_size, &alpha, data.k, qk_head_size, present_size_per_batch_k, data.q, qk_head_size, sequence_length * qk_head_size, - &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); + &zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches, + device_prop, parameters.use_tf32)); DUMP_TENSOR_D("Q", data.q, batch_size, num_heads, sequence_length, qk_head_size); DUMP_TENSOR_D("K", data.k, batch_size, num_heads, qk_head_size, sequence_length); @@ -514,7 +515,7 @@ Status UnfusedAttention( v_head_size, sequence_length, total_sequence_length, &one, data.v, v_head_size, present_size_per_batch_v, scratch2, total_sequence_length, sequence_length * total_sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc index 3f703ae3d05e6..ceee17c2a2d01 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention.cc @@ -273,13 +273,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data()), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (h2, h1)*(h1, S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(q_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_query_buffer_p.get()), n, device_prop, 
UseTF32())); // gemm_query_buffer in col-base: (h2, S*B) // calcualte k, v @@ -298,13 +298,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { gemm_kv_buffer_p = GetScratchBuffer(static_cast(batch_size) * 2 * key_sequence_length * hidden_size, @@ -318,13 +318,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(key->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } } else { @@ -342,13 +342,13 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, 1, &one, reinterpret_cast(bias->Data() + hidden_size), n, GetConstOnes(m, Stream(context)), 1, - &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // matmul: (2*h2, h1)*(h1, T_S*B) CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(kv_weights->Data()), n, reinterpret_cast(query->Data()), k, - &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop)); + &one, reinterpret_cast(gemm_kv_buffer_p.get()), n, device_prop, UseTF32())); // gemm_kv_buffer in col-base: (2*h2, T_S*B) } else { kv_sequence_length = cache_sequence_length; @@ -372,6 +372,8 @@ Status DecoderAttention::ComputeInternal(OpKernelContext* context) const { device_prop, #ifdef USE_ROCM GetTuningContext(), +#else + UseTF32(), #endif context->GetComputeStream(), cublas, diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index 1dc22a9c8ea98..e24d9da94c964 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -37,7 +37,8 @@ Status DecoderQkvToContext( T* workspace_buffer, T* output, T* new_key_cache, - T* new_value_cache) { + T* new_value_cache, + bool use_tf32) { const int max_threads_per_block = device_prop.maxThreadsPerBlock; const int BN = batch_size * num_heads; const int BHN = BN * head_size; @@ -128,14 +129,14 @@ Status DecoderQkvToContext( kv_sequence_length, sequence_length, head_size, &alpha, key_cache, head_size, strideA, q, head_size, strideB, - &zero, 
scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_T, CUBLAS_OP_N, kv_sequence_length, sequence_length, head_size, &alpha, k, head_size, strideA, q, head_size, strideB, - &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop)); + &zero, scratch1, kv_sequence_length, temp_matrix_size, BN, device_prop, use_tf32)); } constexpr bool is_unidirectional = false; @@ -163,14 +164,14 @@ Status DecoderQkvToContext( head_size, sequence_length, kv_sequence_length, &one, value_cache, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, head_size, sequence_length, kv_sequence_length, &one, v, head_size, strideA, scratch2, kv_sequence_length, temp_matrix_size, - &zero, scratch3, head_size, strideB, BN, device_prop)); + &zero, scratch3, head_size, strideB, BN, device_prop, use_tf32)); } // scratch3 is BxNxSxH, transpose to output SxBxNxH @@ -180,6 +181,7 @@ Status DecoderQkvToContext( Status LaunchDecoderAttentionKernel( const cudaDeviceProp& device_prop, + bool use_tf32, Stream* stream, cublasHandle_t& cublas, const size_t element_size, @@ -228,7 +230,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } else { return DecoderQkvToContext( device_prop, @@ -254,7 +257,8 @@ Status LaunchDecoderAttentionKernel( reinterpret_cast(workspace_buffer), reinterpret_cast(output), reinterpret_cast(new_key_cache), - reinterpret_cast(new_value_cache)); + reinterpret_cast(new_value_cache), + use_tf32); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h index 9db9ccb45e330..f9667a613e648 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.h @@ -11,6 +11,7 @@ namespace cuda { Status LaunchDecoderAttentionKernel( const cudaDeviceProp& prop, // Device Properties + bool use_tf32, // Use TF32 Stream* stream, // ORT Stream cublasHandle_t& cublas, // Cublas handle const size_t element_size, // Element size of input tensor diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc index 72ede2e22b557..07a6fbd60e171 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_masked_self_attention.cc @@ -143,7 +143,7 @@ Status DecoderMaskedSelfAttention::ComputeInternal(OpKernelContext* cont cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); // Update the q, k, and v buffers parameters.q = gemm_buffer.get(); diff --git a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc index e556ae4a490e9..9c5d0e9834f6f 100644 --- 
a/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/longformer_attention.cc @@ -136,7 +136,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, weights_data, n, input_data, k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); } else { // q const CudaT* q_weight = weights_data; @@ -145,7 +145,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, q_weight, n, input_data, k, - &zero, q_data, n, device_prop)); + &zero, q_data, n, device_prop, UseTF32())); // k const CudaT* k_weight = q_weight + static_cast(hidden_size) * hidden_size; CudaT* k_data = q_data + static_cast(batch_size) * sequence_length * hidden_size; @@ -153,7 +153,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, k_weight, n, input_data, k, - &zero, k_data, n, device_prop)); + &zero, k_data, n, device_prop, UseTF32())); // v const CudaT* v_weight = k_weight + static_cast(hidden_size) * hidden_size; @@ -162,7 +162,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, v_weight, n, input_data, k, - &zero, v_data, n, device_prop)); + &zero, v_data, n, device_prop, UseTF32())); } // Wait for async copy of batch_global_num @@ -195,7 +195,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(global_weights->Data()), n, input_data, k, - &zero, global_gemm_buffer, n, device_prop)); + &zero, global_gemm_buffer, n, device_prop, UseTF32())); } else { // global q const CudaT* global_q_weight = global_weights_data; @@ -205,7 +205,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_q_weight, n, input_data, k, - &zero, global_q, n, device_prop)); + &zero, global_q, n, device_prop, UseTF32())); } else { CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( cublas, @@ -226,7 +226,8 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { hidden_size, // ldc static_cast(max_num_global) * hidden_size, // strideC batch_size, // batch count - device_prop)); + device_prop, + UseTF32())); } // global k const CudaT* global_k_weight = global_weights_data + static_cast(hidden_size) * hidden_size; @@ -235,7 +236,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_k_weight, n, input_data, k, - &zero, global_k, n, device_prop)); + &zero, global_k, n, device_prop, UseTF32())); // global v const CudaT* global_v_weight = global_k_weight + static_cast(hidden_size) * hidden_size; @@ -244,7 +245,7 @@ Status LongformerAttention::ComputeInternal(OpKernelContext* context) const { cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, global_v_weight, n, input_data, k, - &zero, global_v, n, device_prop)); + &zero, global_v, n, device_prop, UseTF32())); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index f978f50c6851f..2ef011cdd9a21 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ 
b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -94,6 +94,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { auto& device_prop = GetDeviceProp(); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckInputs(query, key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index ec8b1d051b3d9..55deed55dfd33 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -268,6 +268,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* relative_position_bias = context->Input(5); PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(input->Shape(), weights->Shape(), bias->Shape(), @@ -308,12 +309,12 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { cublasHandle_t cublas = this->GetCublasHandle(context); // Gemm, note that CUDA assumes col-major, so result(N, M) = 1 * weights x input + 1 x bias - // The bias part is not included here since we fuse bias, transpose and output 3 matrice into one cuda kernel. + // The bias part is not included here since we fuse bias, transpose and output 3 matrices into one cuda kernel. CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( cublas, CUBLAS_OP_N, CUBLAS_OP_N, n, m, k, &one, reinterpret_cast(weights->Data()), n, reinterpret_cast(input->Data()), k, - &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop)); + &zero, reinterpret_cast(gemm_buffer.get()), n, device_prop, UseTF32())); constexpr size_t element_size = sizeof(T); constexpr bool no_qkv_workspace = false; // need workspace to add bias diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 3b52320839403..ce7ac3796dbe1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -596,7 +596,7 @@ Status UnfusedScaledDotProductAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); DUMP_TENSOR_D("PackedAttention unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); @@ -624,7 +624,7 @@ Status UnfusedScaledDotProductAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output token_countxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 1b026e64778e3..b4a162989978c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -228,6 +228,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co const Tensor* relative_position_bias = context->Input(6); 
PackedAttentionParameters parameters; + parameters.use_tf32 = UseTF32(); ORT_RETURN_IF_ERROR(CheckInputs(query->Shape(), key, value, diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 83af018a97ea6..49029da12a308 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -775,7 +775,7 @@ Status UnfusedAttention( q, qk_head_size, sequence_length * qk_head_size, &zero, scaled_qk, sequence_length, sequence_length * sequence_length, - batches, device_prop)); + batches, device_prop, parameters.use_tf32)); // Q, K and V are ready now DUMP_TENSOR_INIT(); @@ -808,7 +808,7 @@ Status UnfusedAttention( v_head_size, sequence_length, sequence_length, &one, v, v_head_size, sequence_length * v_head_size, attention_score, sequence_length, sequence_length * sequence_length, - &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop)); + &zero, temp_output, v_head_size, sequence_length * v_head_size, batches, device_prop, parameters.use_tf32)); // Temp_output is BxNxSxH_v, transpose and remove padding to output TxNxH_v Status result = LaunchTransposeRemovePadding( diff --git a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc index 92ba808dd85c2..05f55d9106d0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc +++ b/onnxruntime/contrib_ops/cuda/bert/relative_attn_bias.cc @@ -200,7 +200,7 @@ Status GatedRelativePositionBias::ComputeInternal(OpKernelContext* context) c D, BNS, head_size, &one, reinterpret_cast(weight_tensor.template Data()), (int)D, reinterpret_cast(workspace.get()), (int)head_size, - &zero, gemm_output, ld_gemm_output, device_prop)); + &zero, gemm_output, ld_gemm_output, device_prop, UseTF32())); auto status = LaunchGatedRelativePositionBiasKernel( device_prop, stream, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index 705f2d49fe2bf..001b6070d5e1a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -106,6 +106,8 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_tensor = context->Input(8); AttentionParameters parameters; + parameters.use_tf32 = UseTF32(); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc index bbcb7de99781f..0534ed6dc7fc0 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -117,7 +117,8 @@ Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { &zero, reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 5b0e61e197014..015df70c8ec3c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -135,7 +135,8 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { &zero, 
reinterpret_cast(Y->MutableData()), helper.Ldc(), - GetDeviceProp())); + GetDeviceProp(), + UseTF32())); } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index d0bb2321edf0a..55f0b5570e0ee 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -78,6 +78,7 @@ class CUDAExecutionProvider : public IExecutionProvider { bool GetCudnnConv1dPadToNc1d() const { return info_.cudnn_conv1d_pad_to_nc1d; } bool IsSkipLayerNormInStrictMode() const { return info_.enable_skip_layer_norm_strict_mode; } bool IsNHWCPreferred() const { return info_.prefer_nhwc; } + bool UseTF32() const { return info_.use_tf32; } ProviderOptions GetProviderOptions() const override { return CUDAExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc index 81ddc38820914..c96381e3e68b1 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.cc @@ -33,6 +33,7 @@ constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_dur constexpr const char* kEnableSkipLayerNormStrictMode = "enable_skip_layer_norm_strict_mode"; constexpr const char* kPreferNHWCMode = "prefer_nhwc"; constexpr const char* kUseEPLevelUnifiedStream = "use_ep_level_unified_stream"; +constexpr const char* kUseTF32 = "use_tf32"; } // namespace provider_option_names } // namespace cuda @@ -115,6 +116,7 @@ CUDAExecutionProviderInfo CUDAExecutionProviderInfo::FromProviderOptions(const P .AddAssignmentToReference(cuda::provider_option_names::kEnableSkipLayerNormStrictMode, info.enable_skip_layer_norm_strict_mode) .AddAssignmentToReference(cuda::provider_option_names::kPreferNHWCMode, info.prefer_nhwc) .AddAssignmentToReference(cuda::provider_option_names::kUseEPLevelUnifiedStream, info.use_ep_level_unified_stream) + .AddAssignmentToReference(cuda::provider_option_names::kUseTF32, info.use_tf32) .AddValueParser( cuda::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -167,6 +169,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const CUDAExecution {cuda::provider_option_names::kEnableSkipLayerNormStrictMode, MakeStringWithClassicLocale(info.enable_skip_layer_norm_strict_mode)}, {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; @@ -188,6 +191,7 @@ ProviderOptions CUDAExecutionProviderInfo::ToProviderOptions(const OrtCUDAProvid {cuda::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op_max_tuning_duration_ms)}, {cuda::provider_option_names::kPreferNHWCMode, MakeStringWithClassicLocale(info.prefer_nhwc)}, {cuda::provider_option_names::kUseEPLevelUnifiedStream, MakeStringWithClassicLocale(info.use_ep_level_unified_stream)}, + {cuda::provider_option_names::kUseTF32, MakeStringWithClassicLocale(info.use_tf32)}, }; return options; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h index 04eea2f6c8e94..1cac3d1513698 
100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider_info.h @@ -76,6 +76,9 @@ struct CUDAExecutionProviderInfo { bool use_ep_level_unified_stream{false}; + // By default, enable TF32 to speed up float GEMM/MatMul or cuDNN convolution of float matrices. + bool use_tf32{true}; + static CUDAExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const CUDAExecutionProviderInfo& info); static ProviderOptions ToProviderOptions(const OrtCUDAProviderOptionsV2& info); @@ -100,7 +103,8 @@ struct std::hash<::onnxruntime::CUDAExecutionProviderInfo> { (static_cast(info.cudnn_conv1d_pad_to_nc1d) << 26) ^ (static_cast(info.enable_skip_layer_norm_strict_mode) << 27) ^ (static_cast(info.prefer_nhwc) << 28) ^ - (static_cast(info.use_ep_level_unified_stream) << 29); + (static_cast(info.use_ep_level_unified_stream) << 29) ^ + (static_cast(info.use_tf32) << 30); onnxruntime::HashCombine(data, value); onnxruntime::HashCombine(info.gpu_mem_limit, value); diff --git a/onnxruntime/core/providers/cuda/cuda_kernel.h b/onnxruntime/core/providers/cuda/cuda_kernel.h index e3106e41e77c8..288da23f35ec8 100644 --- a/onnxruntime/core/providers/cuda/cuda_kernel.h +++ b/onnxruntime/core/providers/cuda/cuda_kernel.h @@ -90,6 +90,10 @@ class CudaKernel : public OpKernel { return stream->cublas_handle_; } + bool UseTF32() const { + return provider_->UseTF32(); + } + tunable::CudaTuningContext* GetTuningContext() const { return static_cast(provider_->GetTuningContext()); } diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index 892e8d5329eba..103c79c93b2ca 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -225,6 +225,7 @@ struct CUDA_Provider : Provider { info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; info.enable_skip_layer_norm_strict_mode = params->enable_skip_layer_norm_strict_mode != 0; info.use_ep_level_unified_stream = params->use_ep_level_unified_stream != 0; + info.use_tf32 = params->use_tf32 != 0; return std::make_shared(info); } @@ -258,6 +259,7 @@ struct CUDA_Provider : Provider { cuda_options.enable_skip_layer_norm_strict_mode = internal_options.enable_skip_layer_norm_strict_mode; cuda_options.prefer_nhwc = internal_options.prefer_nhwc; cuda_options.use_ep_level_unified_stream = internal_options.use_ep_level_unified_stream; + cuda_options.use_tf32 = internal_options.use_tf32; } ProviderOptions GetProviderOptions(const void* provider_options) override { diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index 0a256394b7d99..3c0bf183362dd 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -212,6 +212,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::prefer_nhwc_t: return reinterpret_cast(ep_info_.prefer_nhwc); break; + case CudaResource::use_tf32_t: + return reinterpret_cast(ep_info_.use_tf32); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc index 3e50116eafd17..ee0334e552022 100644 --- 
a/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc +++ b/onnxruntime/core/providers/cuda/math/einsum_utils/einsum_auxiliary_ops.cc @@ -51,25 +51,27 @@ Status MatMul(const T* input_1_data, const T* input_2_data, T* output_data, CudaT one = cuda::ToCudaType::FromFloat(1.0f); CudaT zero = cuda::ToCudaType::FromFloat(0.0f); - CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(static_cast(einsum_cuda_assets)->cublas_handle_, - CUBLAS_OP_N, - CUBLAS_OP_N, - static_cast(N), - static_cast(M), - static_cast(K), - &one, - reinterpret_cast(input_2_data), - static_cast(N), - static_cast(right_stride), - reinterpret_cast(input_1_data), - static_cast(K), - static_cast(left_stride), - &zero, - reinterpret_cast(output_data), - static_cast(N), - static_cast(output_stride), - static_cast(num_batches), - static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp())); + CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper( + static_cast(einsum_cuda_assets)->cublas_handle_, + CUBLAS_OP_N, + CUBLAS_OP_N, + static_cast(N), + static_cast(M), + static_cast(K), + &one, + reinterpret_cast(input_2_data), + static_cast(N), + static_cast(right_stride), + reinterpret_cast(input_1_data), + static_cast(K), + static_cast(left_stride), + &zero, + reinterpret_cast(output_data), + static_cast(N), + static_cast(output_stride), + static_cast(num_batches), + static_cast(einsum_cuda_assets)->cuda_ep_->GetDeviceProp(), + static_cast(einsum_cuda_assets)->cuda_ep_->UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/gemm.cc b/onnxruntime/core/providers/cuda/math/gemm.cc index 8fe23c9a036cc..4e61e0c8c69c6 100644 --- a/onnxruntime/core/providers/cuda/math/gemm.cc +++ b/onnxruntime/core/providers/cuda/math/gemm.cc @@ -118,7 +118,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const b_data, N, GetConstOnes(M, Stream(ctx)), 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else if (b_shape.NumDimensions() == 2 && b_shape[1] == 1) { // B is (M, 1), broadcast using Y(N,M) = 1 * ones(N,1) x B(1,M) + 0 * Y CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( @@ -130,7 +130,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const GetConstOnes(N, Stream(ctx)), N, b_data, 1, /*beta*/ &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); } else { // B is (M, N), no broadcast needed. CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(out_data, b_data, static_cast(M) * N * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); @@ -153,7 +153,7 @@ Status Gemm::ComputeDefault(OpKernelContext* ctx, int M, int N, int K) const // ideally we need to set the output buffer contents to 0 if bias is missing, // but passing 0 for beta is cheaper and it will ignore any junk in the output buffer B != nullptr ? 
&beta : &zero, - out_data, N, device_prop)); + out_data, N, device_prop, UseTF32())); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/math/matmul.cc b/onnxruntime/core/providers/cuda/math/matmul.cc index e4c37c52a1780..6e126fbeadce8 100644 --- a/onnxruntime/core/providers/cuda/math/matmul.cc +++ b/onnxruntime/core/providers/cuda/math/matmul.cc @@ -173,7 +173,8 @@ Status FuncMatMul( &cuda_zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(A->Shape(), B->Shape(), trans_A, trans_B, trans_batch_B, trans_batch_B, stride_A, stride_B, stride_C, batch_count)) { @@ -195,7 +196,8 @@ Status FuncMatMul( ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + cuda_kernel->UseTF32())); return Status::OK(); } @@ -213,12 +215,12 @@ Status FuncMatMul( ORT_RETURN_IF_ERROR(Y_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, cuda_kernel->GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + cuda_kernel->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -238,7 +240,8 @@ Status FuncMatMul( Y_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } @@ -321,7 +324,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help &zero, reinterpret_cast(Y->MutableData()), ldc, - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } else if (CanUseStridedBatchedGemm(left_X->Shape(), right_X->Shape(), transa, transb, trans_batch_a_, trans_batch_b_, stride_A, stride_B, stride_C, batch_count)) { @@ -343,7 +347,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ldc, stride_C, static_cast(batch_count), - device_prop)); + device_prop, + UseTF32())); return Status::OK(); } @@ -361,12 +366,12 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help ORT_RETURN_IF_ERROR(output_arrays.CopyToGpu(ctx->GetComputeStream())); // TF32 provides a huge performance gain for training and inference while preserving FP32 levels of accuracy. - // It requires Ampere or newer GPU, and pointers of matrics shall be aligned (ideal alignment is 16-byte). + // It requires Ampere or newer GPU, and pointers of matrices shall be aligned (ideal alignment is 16-byte). // Assume that start memory of input/output tensor is aligned, we only check offsets of sub-matrix per batch here. - cublasMath_t mode = (std::is_same::value && device_prop.major >= 8 && helper.IsBatchedGemmAligned()) - ? 
CUBLAS_TF32_TENSOR_OP_MATH - : CUBLAS_DEFAULT_MATH; - CublasMathModeSetter math_mode_setter(device_prop, GetCublasHandle(ctx), mode); + bool use_tf32 = std::is_same::value && + this->UseTF32() && + device_prop.major >= 8 && + helper.IsBatchedGemmAligned(); // note that onnxruntime OrtValue is row major, while cublas is column major, // so swap left/right operands @@ -386,7 +391,8 @@ Status MatMul::ComputeDefault(OpKernelContext* ctx, MatMulComputeHelper& help output_arrays.GpuPtr(), ldc, static_cast(helper.OutputOffsets().size()), - device_prop)); + device_prop, + use_tf32)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h index 510cc5cfbb7dd..053c66ddcb34a 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/cuda/shared_inc/fpgeneric.h @@ -29,13 +29,15 @@ cublasGemmHelper(cublasHandle_t handle, const float* B, int ldb, const float* beta, float* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - // TF32 uses 10 bit mantissa which has sufficient margin of precision for most use cases. It gets 8x throughput than FP32 in A100. - // It can be overrided by setting environment variable NVIDIA_TF32_OVERRIDE = 0 to disable TF32 - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); + // To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 or set provider option use_tf32 = 0 + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemm(handle, @@ -58,7 +60,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const double* B, int ldb, const double* beta, double* C, int ldc, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemm(handle, transa, transb, @@ -79,7 +82,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const half* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -121,7 +125,8 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -155,10 +160,11 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, - int n, int k, const BFloat16* alpha, const BFloat16* A, int lda, - const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, + int n, int k, const BFloat16* alpha, const 
BFloat16* A, int lda, + const BFloat16* B, int ldb, const BFloat16* beta, BFloat16* C, int ldc, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -169,7 +175,7 @@ inline cublasStatus_t cublasGemmHelper(cublasHandle_t handle, cublasOperation_t #else inline cublasStatus_t cublasGemmHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*, int, const BFloat16*, int, const BFloat16*, - BFloat16*, int, const cudaDeviceProp&) { + BFloat16*, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -185,7 +191,17 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const float* beta, float* Carray[], int ldc, int batch_count, - const cudaDeviceProp&) { + const cudaDeviceProp& prop, + bool use_tf32) { +// The caller shall check memory alignments of the matrices when use_tf32 is true. +#if defined(USE_CUDA) + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); +#else + ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); +#endif + return cublasSgemmBatched(handle, transa, transb, @@ -208,7 +224,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const double* beta, double* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmBatched(handle, transa, transb, @@ -231,7 +248,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, const half* beta, half* Carray[], int ldc, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -266,11 +284,12 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], - int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, - BFloat16* Carray[], int ldc, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const BFloat16* alpha, const BFloat16* Aarray[], + int lda, const BFloat16* Barray[], int ldb, const BFloat16* beta, + BFloat16* Carray[], int ldc, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); @@ -282,7 +301,8 @@ inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t handle, cublasOpera #else inline cublasStatus_t cublasGemmBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, int, const BFloat16*, const BFloat16*[], int, const BFloat16*[], int, - const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&) { + const BFloat16*, BFloat16*[], int, int, const cudaDeviceProp&, + bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif @@ -301,15 +321,14 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, float* C, int ldc, long long int strideC, int batch_count, - 
const cudaDeviceProp& prop) { -#ifdef ENABLE_TRAINING_OPS + const cudaDeviceProp& prop, + bool use_tf32) { #if defined(USE_CUDA) - onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, CUBLAS_TF32_TENSOR_OP_MATH); -#else - ORT_UNUSED_PARAMETER(prop); -#endif + cublasMath_t mode = use_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; + onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, mode); #else ORT_UNUSED_PARAMETER(prop); + ORT_UNUSED_PARAMETER(use_tf32); #endif return cublasSgemmStridedBatched(handle, @@ -337,7 +356,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, double* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { + const cudaDeviceProp& /*prop*/, + bool /*use_tf32*/) { return cublasDgemmStridedBatched(handle, transa, transb, @@ -363,7 +383,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -411,7 +432,8 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, __half* C, int ldc, long long int strideC, int batch_count, - const cudaDeviceProp& prop) { + const cudaDeviceProp& prop, + bool /*use_tf32*/) { const HalfGemmOptions* half_options = HalfGemmOptions::GetInstance(); onnxruntime::cuda::CublasMathModeSetter math_mode_setter(prop, handle, half_options->GetMathMode()); if (half_options->IsCompute16F()) { @@ -447,49 +469,66 @@ inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, } #if defined(USE_CUDA) -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t handle, cublasOperation_t transa, - cublasOperation_t transb, int m, int n, int k, - const BFloat16* alpha, const BFloat16* A, int lda, - long long int strideA, const BFloat16* B, int ldb, - long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, - long long int strideC, int batch_count, - const cudaDeviceProp& /*prop*/) { +inline cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const BFloat16* alpha, const BFloat16* A, int lda, + long long int strideA, const BFloat16* B, int ldb, + long long int strideB, const BFloat16* beta, BFloat16* C, int ldc, + long long int strideC, int batch_count, + const cudaDeviceProp& /*prop*/, bool /*use_tf32*/) { float h_a = alpha->ToFloat(); float h_b = beta->ToFloat(); // accumulating in FP32 - return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, - ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, - CUBLAS_GEMM_DEFAULT); + return cublasGemmStridedBatchedEx( + handle, transa, transb, m, n, k, &h_a, A, CUDA_R_16BF, lda, strideA, B, CUDA_R_16BF, + ldb, strideB, &h_b, C, CUDA_R_16BF, ldc, strideC, batch_count, CUDA_R_32F, + CUBLAS_GEMM_DEFAULT); } #else -inline cublasStatus_t cublasGemmStridedBatchedHelper(cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, - int, const BFloat16*, const BFloat16*, int, long long int, - const BFloat16*, int, long long int, const BFloat16*, BFloat16*, - int, long long int, int, const cudaDeviceProp&) { +inline 
cublasStatus_t cublasGemmStridedBatchedHelper( + cublasHandle_t, cublasOperation_t, cublasOperation_t, int, int, + int, const BFloat16*, const BFloat16*, int, long long int, + const BFloat16*, int, long long int, const BFloat16*, BFloat16*, + int, long long int, int, const cudaDeviceProp&, bool /*use_tf32*/) { return CUBLAS_STATUS_NOT_SUPPORTED; } #endif // transpose using geam -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, float* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const float* alpha, const float* A, int lda, const float* beta, const float* B, int ldb, + float* C, int ldc) { return cublasSgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } -inline cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, double* C, int ldc) { +inline cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, const double* alpha, const double* A, int lda, const double* beta, const double* B, int ldb, + double* C, int ldc) { return cublasDgeam(handle, transa, transb, m, n, alpha, A, lda, beta, B, ldb, C, ldc); } bool CanUse_cublasTransposeHelper_MLFloat16(int m, int n); -cublasStatus_t cublasTransposeHelper(cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); + +cublasStatus_t cublasTransposeHelper( + cudaStream_t, cublasHandle_t, cublasOperation_t, cublasOperation_t, + int m, int n, const half*, const half* A, int, const half*, const half*, int, half* C, int); // copy -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const float* x, int incx, float* y, int incy) { return cublasScopy(handle, n, x, incx, y, incy); } -inline cublasStatus_t cublasCopyHelper(cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { +inline cublasStatus_t cublasCopyHelper( + cudaStream_t, cublasHandle_t handle, int n, const double* x, int incx, double* y, int incy) { return cublasDcopy(handle, n, x, incx, y, incy); } -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); -cublasStatus_t cublasCopyHelper(cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const half* x, int incx, half* y, int incy); + +cublasStatus_t cublasCopyHelper( + cudaStream_t stream, cublasHandle_t handle, int n, const BFloat16* x, int incx, BFloat16* y, int incy); diff --git a/onnxruntime/core/providers/rocm/rocm_kernel.h b/onnxruntime/core/providers/rocm/rocm_kernel.h index c0b7d4722d3e4..70bf08d65401a 100644 --- a/onnxruntime/core/providers/rocm/rocm_kernel.h +++ b/onnxruntime/core/providers/rocm/rocm_kernel.h @@ -101,6 +101,10 @@ class 
RocmKernel : public OpKernel { return static_cast(provider_->GetTuningContext()); } + bool UseTF32() const { + return false; + } + // To support hipMemcpyAsync, the cpu memory should be allocated in pinned memory // and it can only be released after the copy has finished template diff --git a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h index 7cbc37cb64c5a..d93f70785c093 100644 --- a/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h +++ b/onnxruntime/core/providers/rocm/shared_inc/fpgeneric.h @@ -115,7 +115,8 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, const half* B, int ldb, const float* beta, half* C, int ldc, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmHelper(handle, transa, transb, @@ -154,7 +155,7 @@ inline rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } -// Compatible for function call with the extra hipDeviceProp_t argument +// Compatible for function call with extra arguments (see cublasGemmHelper) template rocblas_status rocblasGemmHelper(rocblas_handle handle, rocblas_operation transa, @@ -165,7 +166,8 @@ rocblas_status rocblasGemmHelper(rocblas_handle handle, const Scalar* B, int ldb, const Scalar* beta, Scalar* C, int ldc, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmHelper(handle, transa, transb, @@ -404,7 +406,7 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_gemm_algo_standard, 0, 0); } -// Compatible for function call with the extra hipDeviceProp_t argument +// Compatible for function call with with extra arguments (see cublasGemmStridedBatchedHelper) template rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, rocblas_operation transa, @@ -419,7 +421,8 @@ rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, Scalar* C, int ldc, intmax_t strideC, int batchCount, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmStridedBatchedHelper(handle, transa, transb, @@ -445,7 +448,8 @@ inline rocblas_status rocblasGemmStridedBatchedHelper(rocblas_handle handle, __half* C, int ldc, intmax_t strideC, int batchCount, - const hipDeviceProp_t&) { + const hipDeviceProp_t&, + bool /*use_tf32*/) { return rocblasGemmStridedBatchedHelper(handle, transa, transb, diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 32ae15e71acc6..bb8732784945d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1555,6 +1555,7 @@ OrtCUDAProviderOptionsV2 OrtCUDAProviderOptionsToOrtCUDAProviderOptionsV2(const cuda_options_converted.cudnn_conv1d_pad_to_nc1d = 0; cuda_options_converted.enable_skip_layer_norm_strict_mode = 0; cuda_options_converted.use_ep_level_unified_stream = 0; + cuda_options_converted.use_tf32 = 1; return cuda_options_converted; } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu index fd9e9c4fd1612..8b05b96ec38a9 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/gemm.cu @@ -56,6 +56,9 @@ class GemmBenchmark : public IKernelExplorer { typedef typename ToCudaType::MappedType CudaT; CudaT one = ToCudaType::FromFloat(1.0f); CudaT zero = 
ToCudaType::FromFloat(0.0f); + + // TF32 is enable by default. To disable TF32, set environment variable NVIDIA_TF32_OVERRIDE = 0 + constexpr bool use_tf32 = true; CUBLAS_CALL_THROW(cublasGemmHelper( params_.cublas_handle, CUBLAS_OP_N, @@ -69,7 +72,8 @@ class GemmBenchmark : public IKernelExplorer { &zero, params_.output_, params_.n_, - device_prop_)); + device_prop_, + use_tf32)); } private: @@ -79,11 +83,11 @@ class GemmBenchmark : public IKernelExplorer { cudaDeviceProp device_prop_; }; -#define REGISTER_OP(name, type) \ - py::class_>(m, #name "_" #type) \ +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ .def(py::init()) \ - .def("SetRepeats", &name::SetRepeats) \ - .def("Profile", &name::Profile) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ .def("Run", &name::Run); KE_REGISTER(m) { diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index 09baf8def05f6..31ef62e69bb88 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -433,7 +433,8 @@ static void RunModelWithRandomInput( std::vector token_offset_dims{batch_size, sequence_length}; std::vector cum_seq_len_dims{batch_size + 1}; - float gpu_threshold = is_float16 ? 0.15f : 0.005f; + // TF32 in SM >= 80 is enabled by default, need larger threshold for float when TF32 is enabled. + float gpu_threshold = is_float16 ? 0.15f : (HasCudaEnvironment(800) ? 0.05f : 0.005f); gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); if (enable_cuda) { diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 5b41806b646af..91b6c71e735a8 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -428,6 +428,8 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("tunable_op_max_tuning_duration_ms", ["-1", "1"]) + test_get_and_set_option_with_values("use_tf32", ["1", "0"]) + option["gpu_external_alloc"] = "0" option["gpu_external_free"] = "0" option["gpu_external_empty_cache"] = "0" From df5c6718bd67bdc59fb77fbd4a6ccdcaf902145c Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Tue, 6 Feb 2024 14:54:06 -0800 Subject: [PATCH 049/207] Remove iOS simulator max runtime version limit. (#19396) --- tools/ci_build/github/apple/get_simulator_device_info.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/github/apple/get_simulator_device_info.py b/tools/ci_build/github/apple/get_simulator_device_info.py index 2a36418bac9cb..7de9aa13912e0 100755 --- a/tools/ci_build/github/apple/get_simulator_device_info.py +++ b/tools/ci_build/github/apple/get_simulator_device_info.py @@ -138,13 +138,11 @@ def runtime_id_and_device_pair_key(runtime_id_and_device_pair): def main(): parser = argparse.ArgumentParser(description="Gets simulator info from Xcode and prints it in JSON format.") - _ = parser.parse_args() # no args yet + parser.add_argument("--max-runtime-version", help="The maximum runtime version to allow.") + args = parser.parse_args() info = get_simulator_device_info( - # The macOS-13 hosted agent image has iOS 17 which is currently in beta. Limit it to 16.4 for now. 
- # See https://github.com/actions/runner-images/issues/8023 - # TODO Remove max_runtime_version limit. - max_runtime_version_str="16.4", + max_runtime_version_str=args.max_runtime_version, ) print(json.dumps(info, indent=2)) From 91b2e660fe4b5ad70a7b6fde181cbfc9336e38a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 7 Feb 2024 02:01:26 +0100 Subject: [PATCH 050/207] [Build] fix: missing nvcc flags when compiling with unittests (#19308) When configured using the following CMake ops Clion is not able to configure due to checking with `nvcc ... --dryrun tmp.cu`: ``` cmake -G Ninja -Donnxruntime_USE_TENSORRT="ON" -Donnxruntime_USE_CUDA="ON" -Donnxruntime_USE_CUDA_NHWC_OPS="ON" -DCMAKE_CUDA_ARCHITECTURES="native" -Donnxruntime_NVCC_THREADS=1 -Donnxruntime_ENABLE_NVTX_PROFILE="ON" -Donnxruntime_USE_TENSORRT_BUILTIN_PARSER="ON" -DCMAKE_CUDA_COMPILER_LAUNCHER="ccache" -Donnxruntime_BUILD_UNIT_TESTS="ON" -Donnxruntime_USE_TRITON_KERNEL=OFF -Donnxruntime_USE_FLASH_ATTENTION=OFF ``` Without building the unittests everything works fine. I believe my changes only follow the logic that is actually desired. If `NVCC_HAS_STRICT_ALIASING` is set to false it should not be possible to add this as a CUDA flag. Same is true for `HAS_NOERROR` as seen in `adjust_global_compile_flags.cmake` --- cmake/CMakeLists.txt | 5 ++++- cmake/onnxruntime_unittests.cmake | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 34e7687e91876..90fe8276ea9c7 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -985,9 +985,12 @@ function(onnxruntime_set_compile_flags target_name) foreach(FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target_name} PRIVATE "$<$:SHELL:--compiler-options ${FLAG}>") endforeach() - if ((NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") OR (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda")) + if (NVCC_HAS_STRICT_ALIASING AND "${target_name}" MATCHES "cuda") target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") endif() + if (HAS_STRICT_ALIASING AND NOT "${target_name}" MATCHES "cuda") + target_compile_options(${target_name} PRIVATE "$<$:-Wno-strict-aliasing>") + endif() endif() if (onnxruntime_USE_ROCM) # flags are detected with CXX language mode, some flags are not supported with hipclang diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 5b4a007d6b974..308caad296831 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -111,7 +111,9 @@ function(AddTest) target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") - target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") + if (${HAS_NOERROR}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) From 302d4be7d935a6679a43ced5b3524b02f0e02da1 Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Tue, 6 Feb 2024 17:10:55 -0800 Subject: [PATCH 051/207] [DML EP] Fix external data unpacking (#19415) ### Description This change https://github.com/microsoft/onnxruntime/commit/55a669409a93e1ed8d8e7e27b17048ce8a0cfebb didn't take into account external data when unpacking initializer, and therefore crashes when trying to unpack them. 
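In rough terms, the fix adds a branch for externally stored tensor data ahead of the existing raw-data path, so the bytes are first read from the external file (located relative to the model path) into an owned buffer before a raw pointer is taken. A condensed sketch of the new ordering, using the member names from the diff below; this is an illustration, not a drop-in implementation:

```
// Sketch of the unpacking order added to OnnxTensorWrapper's constructor (names as in the diff below).
if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) {
  // External data lives outside the protobuf; load it into an owned buffer first.
  THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor));
  m_dataPtr = reinterpret_cast<std::byte*>(m_unpackedExternalTensor.data());
  m_tensorByteSize = m_unpackedExternalTensor.size();
} else if (impl->has_raw_data()) {
  // In-memory raw data can be aliased directly, as before.
  m_dataPtr = reinterpret_cast<std::byte*>(impl->mutable_raw_data()->data());
  m_tensorByteSize = impl->raw_data().size();
}
```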
--- .../DmlExecutionProvider/src/GraphDescBuilder.cpp | 15 ++++++++++----- .../src/MLOperatorAuthorImpl.cpp | 12 +++++++++--- .../src/MLOperatorAuthorImpl.h | 1 + 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c6a15e76f4736..2456b396de3f6 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. uint32_t c_maxConstNodeDataSize = 8; - ComPtr constantInput = constantCpuGraphInputGetter(arg->Name()); auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; - if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { - // The tensor description's size should be no larger than the constant input unless it was rounded to + constantInput = constantCpuGraphInputGetter(arg->Name()); + } + + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to // the required alignment. assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index dbd06abf82f72..d524780de71b8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter } ORT_CATCH_RETURN } - + template HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept { @@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter m_requiredConstantCpuInputs.begin(), m_requiredConstantCpuInputs.end(), inputIndex) != m_requiredConstantCpuInputs.end(); - + // This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present. ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant); } @@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl) { // The tensor may be stored as raw data or in typed fields. 
- if (impl->has_raw_data()) + if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL) + { + THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor)); + m_dataPtr = reinterpret_cast(m_unpackedExternalTensor.data()); + m_tensorByteSize = m_unpackedExternalTensor.size(); + } + else if (impl->has_raw_data()) { m_dataPtr = reinterpret_cast(impl->mutable_raw_data()->data()); m_tensorByteSize = impl->raw_data().size(); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index 6530d89d895e7..59e253e88457a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base, public Closable private: size_t m_tensorByteSize = 0; std::unique_ptr m_unpackedTensor; + std::vector m_unpackedExternalTensor; std::byte* m_dataPtr = nullptr; // Lifetime is managed by the caller and guaranteed to outlive this class From 36d223676b419a8e6868d38f9363e236edbd0690 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 7 Feb 2024 14:01:51 +1000 Subject: [PATCH 052/207] Use GraphViewer.IsConstantInitializer in NNAPI EP. (#19401) ### Description An overridable initializer should not have a fixed value included in an NNAPI model as it could be changed at runtime. The current check doesn't include validating that the initializer is constant. I was updating GetClipMinMax as part of adding CoreML EP ML Program support, and in order to make both CoreML and NNAPI do the more correct thing of using IsConstantInitializer this set of changes was required. ### Motivation and Context Make NNAPI and CoreML EPs more correct. 
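The essence of the change is replacing a plain presence check against the initializer set with a constant-initializer lookup on the GraphViewer, so an initializer that can be overridden at runtime is no longer baked into the NNAPI model as a fixed value. A rough before/after sketch (the real signatures are in the diffs below):

```
// Before: any initializer with a matching name passed, even an overridable one.
if (!Contains(initializers, input_name)) return false;

// After: only constant initializers are accepted; overridable ones yield nullptr and are rejected.
if (!graph_viewer.GetConstantInitializer(input_name)) return false;
```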
--------- Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> --- include/onnxruntime/core/graph/graph.h | 1 - include/onnxruntime/core/graph/graph_viewer.h | 3 +- onnxruntime/core/graph/graph_viewer.cc | 2 + .../coreml/builders/impl/clip_op_builder.cc | 5 +- .../nnapi/nnapi_builtin/builders/helper.cc | 37 ++++---- .../nnapi/nnapi_builtin/builders/helper.h | 10 +-- .../builders/impl/LRN_op_builder.cc | 4 +- .../builders/impl/base_op_builder.cc | 32 ++++--- .../builders/impl/base_op_builder.h | 11 ++- .../builders/impl/batchnorm_op_builder.cc | 20 ++--- .../builders/impl/binary_op_builder.cc | 16 ++-- .../builders/impl/cast_op_builder.cc | 4 +- .../builders/impl/clip_op_builder.cc | 8 +- .../builders/impl/concat_op_builder.cc | 22 ++--- .../builders/impl/conv_op_builder.cc | 28 +++---- .../builders/impl/depthtospace_op_builder.cc | 4 +- .../impl/dequantizelinear_op_builder.cc | 6 +- .../builders/impl/flatten_op_builder.cc | 4 +- .../builders/impl/gather_op_builder.cc | 8 +- .../builders/impl/gemm_op_builder.cc | 32 ++++--- .../builders/impl/leakyrelu_op_builder.cc | 4 +- .../builders/impl/minmax_op_builder.cc | 4 +- .../builders/impl/pad_op_builder.cc | 17 ++-- .../builders/impl/pool_op_builder.cc | 29 ++++--- .../impl/quantizelinear_op_builder.cc | 6 +- .../builders/impl/reduction_op_builder.cc | 6 +- .../builders/impl/reshape_op_builder.cc | 30 +++---- .../builders/impl/resize_op_builder.cc | 68 ++++++++------- .../builders/impl/slice_op_builder.cc | 12 +-- .../builders/impl/softmax_op_builder.cc | 22 +++-- .../builders/impl/split_op_builder.cc | 15 ++-- .../builders/impl/squeeze_op_builder.cc | 8 +- .../builders/impl/transpose_op_builder.cc | 15 ++-- .../builders/impl/unary_op_builder.cc | 30 ++++--- .../builders/impl/unsqueeze_op_builder.cc | 8 +- .../nnapi_builtin/builders/model_builder.cc | 13 ++- .../nnapi/nnapi_builtin/builders/op_builder.h | 2 +- .../builders/op_builder_helpers.cc | 69 ++++++++------- .../builders/op_builder_helpers.h | 10 +-- .../core/providers/shared/utils/utils.cc | 84 ++++++++++++------- .../core/providers/shared/utils/utils.h | 14 +++- .../webnn/builders/impl/clip_op_builder.cc | 5 +- 42 files changed, 382 insertions(+), 346 deletions(-) diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 22827d43b200f..b9b8a25286b7b 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -753,7 +753,6 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi cannot be overridden at runtime. If the initializer is not found or is not constant, a nullptr is returned. @param check_outer_scope If true and the graph is a subgraph, check ancestor graph/s for 'name' if not found in 'graph'. 
- @remarks check_outer_scope of true is not supported in a minimal build */ const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; diff --git a/include/onnxruntime/core/graph/graph_viewer.h b/include/onnxruntime/core/graph/graph_viewer.h index 3cdbb07099cab..1023d50310181 100644 --- a/include/onnxruntime/core/graph/graph_viewer.h +++ b/include/onnxruntime/core/graph/graph_viewer.h @@ -165,7 +165,8 @@ class GraphViewer { if a const initializer is part of the underlying Graph but not part of this GraphViewer, it will still be returned instead of nullptr */ - const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, bool check_outer_scope) const; + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name, + bool check_outer_scope = true) const; /** Get the Node containing this Graph if IsSubgraph is true. Returns nullptr otherwise. */ const Node* ParentNode() const noexcept { return graph_->ParentNode(); } diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index cf78040ea5ac6..acf7b3a16541f 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -212,6 +212,8 @@ const std::string& GraphViewer::Description() const noexcept { bool GraphViewer::GetInitializedTensor(const std::string& tensor_name, const ONNX_NAMESPACE::TensorProto*& value) const { + value = nullptr; + // if we are using filtered subgraph, the initializer has to be part of the subgraph if (filter_info_ != nullptr && filtered_initializers_.find(tensor_name) == filtered_initializers_.cend()) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 3a3f89d24c7d8..a298a8d12c741 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -50,7 +50,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_name = node.InputDefs()[0]->Name(); const auto& output_name = node.OutputDefs()[0]->Name(); float min, max; - ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, min, max, logger), "GetClipMinMax failed"); + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); @@ -132,8 +132,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { float min, max; - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - return GetClipMinMax(initializers, node, min, max, logger); + return GetClipMinMax(input_params.graph_viewer, node, min, max, logger); } void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc index 3209ad734fa20..0b32508a5bb38 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.cc @@ -184,9 +184,8 @@ bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& 
node_unit) { return true; } -common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, - float& scale, int32_t& zero_point) { +common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, + const Path& model_path, float& scale, int32_t& zero_point) { scale = 0.0f; zero_point = 0; @@ -198,14 +197,24 @@ common::Status GetQuantizationScaleAndZeroPoint( const auto& quant_param = *io_def.quant_param; { // get the scale const auto& name = quant_param.scale.Name(); - Initializer unpacked_tensor(*initializers.at(name), model_path); + const auto* s = graph_viewer.GetConstantInitializer(name); + if (!s) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); + }; + + Initializer unpacked_tensor(*s, model_path); // The scale should be one or more floats scale = unpacked_tensor.DataAsSpan()[0]; } if (quant_param.zero_point) { // get the zero point if it's there const auto& name = quant_param.zero_point->Name(); - Initializer unpacked_tensor(*initializers.at(name), model_path); + const auto* zp = graph_viewer.GetConstantInitializer(name); + if (!zp) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, name, " is not a constant initializer"); + }; + + Initializer unpacked_tensor(*zp, model_path); // Onnx quantization uses uint8 [int8 not yet supported], need to cast to int32_t used by NNAPI zero_point = static_cast(unpacked_tensor.DataAsByteSpan()[0]); } @@ -213,13 +222,13 @@ common::Status GetQuantizationScaleAndZeroPoint( return Status::OK(); } -common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::string& name, - float& scale, int32_t& zero_point, ArgType arg_type) { +common::Status GetQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::string& name, float& scale, int32_t& zero_point, + ArgType arg_type) { const auto& io_defs = arg_type == ArgType::kInput ? 
node_unit.Inputs() : node_unit.Outputs(); for (const auto& io_def : io_defs) { if (io_def.node_arg.Name() == name) - return GetQuantizationScaleAndZeroPoint(initializers, io_def, node_unit.ModelPath(), + return GetQuantizationScaleAndZeroPoint(graph_viewer, io_def, node_unit.ModelPath(), scale, zero_point); } @@ -348,7 +357,7 @@ bool IsNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph_viewer, } const auto* op_builder = op_builder_it->second; - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node_unit, params); + return op_builder->IsOpSupported(graph_viewer, node_unit, params); } bool IsNodeSupportedInGroup(const NodeUnit& node_unit, const GraphViewer& graph_viewer, @@ -381,11 +390,11 @@ uint32_t ShapeSize(const Shape& shape, size_t begin_idx, size_t end_idx) { SafeInt{1}, std::multiplies>{}); } -bool CheckIsInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const std::string& input_name, const char* input_description) { - if (!Contains(initializers, input_name)) { +bool CheckIsConstantInitializer(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::string& input_name, const char* input_description) { + if (!graph_viewer.GetConstantInitializer(input_name)) { LOGS_DEFAULT(VERBOSE) << input_description << " of " << node_unit.Name() << "of type [" - << node_unit.OpType() << "] must be an initializer tensor"; + << node_unit.OpType() << "] must be a constant initializer"; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h index 766034b3decea..a606b8aceb63d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/helper.h @@ -132,11 +132,11 @@ bool IsQuantizedBinaryOp(QuantizedOpType quant_op_type); bool HasValidBinaryOpQuantizedInputTypes(const NodeUnit& node_unit); common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnitIODef& io_def, const Path& model_path, + const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const Path& model_path, float& scale, int32_t& zero_point); common::Status GetQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, const std::string& name, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const std::string& name, float& scale, int32_t& zero_point, ArgType arg_type = ArgType::kInput); // Get Shape/Type of a NodeArg @@ -167,11 +167,11 @@ inline uint32_t ShapeSize(const Shape& shape) { return ShapeSize(shape, 0, shape.size()); } -// Check the given input is an initializer tensor +// Check the given input is a constant initializer // input_name is the name of the initializer // input_description is the string describing the input in the output message (if any) -bool CheckIsInitializer(const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const std::string& input_name, const char* input_description); +bool CheckIsConstantInitializer(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::string& input_name, const char* input_description); // Convert ONNX int64 input to NNAPI int32 type input and optionally handle negative axis if needed // Mostly used in handling `axes` input for now diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc 
b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc index 00bca4001326c..91cad034d8854 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/LRN_op_builder.cc @@ -29,7 +29,7 @@ class LRNOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -91,7 +91,7 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No // Operator support related -bool LRNOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool LRNOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc index 7797e0a47caaf..adc79576272ab 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/graph/graph_viewer.h" #include "core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h" namespace onnxruntime { @@ -11,10 +12,11 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node const auto is_ext_initializer = [&](const NodeArg& node_arg) { const auto& input_name(node_arg.Name()); - if (!Contains(initializers, input_name)) + const auto initializer = initializers.find(input_name); + if (initializer == initializers.end()) return false; - const auto& tensor = *initializers.at(input_name); + const auto& tensor = *initializer->second; if (tensor.has_data_location() && tensor.data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { LOGS_DEFAULT(VERBOSE) << "Initializer [" << input_name @@ -51,8 +53,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeU model_builder.GetEffectiveFeatureLevel(), model_builder.UseNCHW(), }; - ORT_RETURN_IF_NOT(IsOpSupported(model_builder.GetInitializerTensors(), node_unit, params), - "Unsupported operator ", node_unit.OpType()); + + // We checked supported in IExecutionProvider::GetCapability. + // Checking again in AddToModelBuilder which is called in IExecutionProvider::Compile is redundant. 
+ // ORT_RETURN_IF_NOT(IsOpSupported(model_builder.GetGraphViewer(), node_unit, params), + // "Unsupported operator ", node_unit.OpType()); + #ifndef NDEBUG model_builder.SetDebugCurrentOnnxNodeIndex(node_unit.Index()); #endif @@ -64,7 +70,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const NodeU // Operator support related -bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool BaseOpBuilder::IsOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { int32_t required_feature_level = GetMinSupportedNNAPIFeatureLevel(node_unit, params); if (required_feature_level > params.android_feature_level) { @@ -77,20 +83,20 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons if (!IsNodeUnitTypeSupported(node_unit)) return false; - if (!HasSupportedInputOutputs(initializers, node_unit, params)) + if (!HasSupportedInputOutputs(graph_viewer, node_unit, params)) return false; // We do not support external initializers for now - if (HasExternalInitializer(initializers, node_unit)) + if (HasExternalInitializer(graph_viewer.GetAllInitializedTensors(), node_unit)) return false; if (!HasSupportedOpSet(node_unit)) return false; - return IsOpSupportedImpl(initializers, node_unit, params); + return IsOpSupportedImpl(graph_viewer, node_unit, params); } -bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool BaseOpBuilder::HasSupportedInputOutputs(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { // We do not support unknown(null) input shape auto has_supported_shape = [](const NodeArg& node_arg, const std::string& name, const std::string& op_type) { @@ -128,12 +134,12 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial return false; } } - return HasSupportedInputOutputsImpl(initializers, node_unit, params); + + return HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } -bool BaseOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const { +bool BaseOpBuilder::HasSupportedInputOutputsImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, + const OpSupportCheckParams& /* params */) const { // We only check the type of input 0 by default // specific op builder can override this const auto& input = node_unit.Inputs()[0].node_arg; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h index 339ccd67f33e3..6a54bf7bdb938 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.h @@ -52,11 +52,11 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related public: - bool IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; protected: - virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& /* node_unit */, + virtual bool IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& /* node_unit */, const 
OpSupportCheckParams& /* params */) const { return true; } @@ -68,9 +68,8 @@ class BaseOpBuilder : public IOpBuilder { return ANEURALNETWORKS_FEATURE_LEVEL_1; } - virtual bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const; + virtual bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const; virtual int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const { return 1; } virtual int GetMaxSupportedOpSet(const NodeUnit& /* node_unit */) const { return 19; } @@ -82,7 +81,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const NodeUnit& node_unit) const; - bool HasSupportedInputOutputs(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool HasSupportedInputOutputs(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const; }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc index 3add0ac26c0d4..75a66d3a14643 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/batchnorm_op_builder.cc @@ -33,7 +33,7 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // BatchNormalization opset 6- has unsupported attributes @@ -127,7 +127,7 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu // Operator support related -bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { if (node_unit.Outputs().size() != 1) { LOGS_DEFAULT(VERBOSE) << "Your onnx model may be in training mode, please export " @@ -158,20 +158,20 @@ bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& const auto& b_name = inputs[2].node_arg.Name(); const auto& mean_name = inputs[3].node_arg.Name(); const auto& var_name = inputs[4].node_arg.Name(); - if (!Contains(initializers, scale_name)) { - LOGS_DEFAULT(VERBOSE) << "Scale of BN must be known"; + if (!graph_viewer.GetConstantInitializer(scale_name)) { + LOGS_DEFAULT(VERBOSE) << "Scale of BN must be a constant initializer"; return false; } - if (!Contains(initializers, b_name)) { - LOGS_DEFAULT(VERBOSE) << "B of BN must be known"; + if (!graph_viewer.GetConstantInitializer(b_name)) { + LOGS_DEFAULT(VERBOSE) << "B of BN must be a constant initializer"; return false; } - if (!Contains(initializers, mean_name)) { - LOGS_DEFAULT(VERBOSE) << "Mean of BN must be known"; + if (!graph_viewer.GetConstantInitializer(mean_name)) { + LOGS_DEFAULT(VERBOSE) << "Mean of BN must be a constant initializer"; return false; } - if (!Contains(initializers, var_name)) { - LOGS_DEFAULT(VERBOSE) << "Var of BN must be known"; + if (!graph_viewer.GetConstantInitializer(var_name)) { + LOGS_DEFAULT(VERBOSE) << "Var of BN must be a constant 
initializer"; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc index dce1a7c8659bf..5599fbdc69bdd 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/binary_op_builder.cc @@ -34,10 +34,10 @@ class BinaryOpBuilder : public BaseOpBuilder { private: int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; @@ -95,7 +95,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const if (is_quant_op) { ORT_RETURN_IF_ERROR(GetBinaryOpQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit, + model_builder.GetGraphViewer(), node_unit, a_scale, b_scale, y_scale, a_zero_point, b_zero_point, y_zero_point)); } @@ -163,22 +163,22 @@ int BinaryOpBuilder::GetMinSupportedOpSet(const NodeUnit& node_unit) const { } bool BinaryOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { bool is_quantized_op = IsQuantizedOp(node_unit); bool is_pow = node_unit.OpType() == "Pow"; if (!is_quantized_op && !is_pow) - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); if (is_quantized_op) { // QLinearAdd/QDQAdd/QLinearMul/QDQMul if (!HasValidBinaryOpQuantizedInputTypes(node_unit)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0, 1}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } @@ -203,7 +203,7 @@ bool BinaryOpBuilder::HasSupportedInputOutputsImpl( return true; } -bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool BinaryOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& op_type(node_unit.OpType()); const auto& inputs = node_unit.Inputs(); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc index b31ee484dc5a2..9059de817e210 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/cast_op_builder.cc @@ -29,7 +29,7 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related 
private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -70,7 +70,7 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N return Status::OK(); } -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool CastOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { NodeAttrHelper helper(node_unit); const auto to = helper.Get("to", 0); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc index b3e294d2f0845..9821d9267c71f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/clip_op_builder.cc @@ -32,7 +32,7 @@ class ClipOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -64,7 +64,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } float min, max; - GetClipMinMax(model_builder.GetInitializerTensors(), node_unit.GetNode(), min, max, + GetClipMinMax(model_builder.GetGraphViewer(), node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger()); int32_t op_code; @@ -85,10 +85,10 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // Operator support related -bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ClipOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { float min, max; - if (!GetClipMinMax(initializers, node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger())) + if (!GetClipMinMax(graph_viewer, node_unit.GetNode(), min, max, logging::LoggingManager::DefaultLogger())) return false; // We only supoort relu6 or relu1 diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc index 2bf8f07e26fd4..a8394faec51be 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/concat_op_builder.cc @@ -32,11 +32,11 @@ class ConcatOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } 
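For illustration (this snippet is not part of the patch): the same substitution repeats across these op builders. A presence check against the InitializedTensorSet (Contains(initializers, name) followed by initializers.at(name)) becomes a single GraphViewer::GetConstantInitializer(name) call, which returns nullptr when the tensor is missing, is not a constant, or can be overridden at runtime, so every caller now handles the null case explicitly. A minimal sketch of the new pattern, assuming the usual NNAPI op-builder headers are in scope; the helper name, the choice of input index 1, and the log text are illustrative only and do not appear in this change:

// Sketch only: mirrors the constant-initializer pattern used by the op builders in this commit.
// Fills `axes` and returns true only when input 1 is a constant initializer.
static bool TryGetConstantAxes(const GraphViewer& graph_viewer, const NodeUnit& node_unit,
                               std::vector<int64_t>& axes) {
  const auto& inputs = node_unit.Inputs();
  if (inputs.size() < 2 || !inputs[1].node_arg.Exists())
    return false;

  // nullptr covers "missing", "not constant", and "overridable at runtime" in one check.
  const auto* axes_tensor = graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name());
  if (!axes_tensor) {
    LOGS_DEFAULT(VERBOSE) << "axes must be a constant initializer";
    return false;
  }

  Initializer unpacked_tensor(*axes_tensor);
  const auto data = unpacked_tensor.DataAsSpan<int64_t>();
  axes.assign(data.begin(), data.end());
  return true;
}

Rejecting anything that is not a constant initializer, rather than anything absent from GetAllInitializedTensors(), is the point of the migration: an initializer that could be overridden by a runtime input of the same name is no longer treated as a known, build-time value.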
@@ -113,7 +113,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const float scale = 0.0f; int32_t zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Inputs()[i], node_unit.ModelPath(), + model_builder.GetGraphViewer(), node_unit.Inputs()[i], node_unit.ModelPath(), scale, zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point)); @@ -128,7 +128,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const int32_t y_zero_point = operand_types.at(input0).operandType.zeroPoint; if (is_quant_op) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(), + model_builder.GetGraphViewer(), node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); } @@ -151,7 +151,7 @@ bool ConcatOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQConcat; } -bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool ConcatOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -168,7 +168,7 @@ bool ConcatOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ } bool ConcatOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { const auto& op_type = node_unit.OpType(); const auto& op_name = node_unit.Name(); @@ -188,11 +188,11 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( if (IsQuantizedOp(node_unit)) { std::vector input_indices(input_size); std::iota(input_indices.begin(), input_indices.end(), 0); - if (!IsQuantizedIOSupported(initializers, node_unit, input_indices, params, ArgType::kInput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, input_indices, params, ArgType::kInput)) { return false; } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) { return false; } @@ -203,7 +203,7 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( size_t input_idx = 0; auto status = GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[input_idx], node_unit.ModelPath(), + graph_viewer, node_unit.Inputs()[input_idx], node_unit.ModelPath(), input_scales[input_idx], input_zps[input_idx]); if (!status.IsOK()) { @@ -214,7 +214,7 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( } for (++input_idx; input_idx < input_size; ++input_idx) { - if (!HasRequiredScaleAndZeroPoint(initializers, + if (!HasRequiredScaleAndZeroPoint(graph_viewer, MakeString("Op [", op_type, "] name [", op_name, "] input ", input_idx), node_unit.Inputs()[input_idx], node_unit.ModelPath(), @@ -225,7 +225,7 @@ bool ConcatOpBuilder::HasSupportedInputOutputsImpl( } // NNAPI (28-) requires the output scale and zp be the same as the input 0 - if (!HasRequiredScaleAndZeroPoint(initializers, + if (!HasRequiredScaleAndZeroPoint(graph_viewer, MakeString("Op [", op_type, "] name [", op_name, "]'s output 0"), node_unit.Outputs()[0], node_unit.ModelPath(), input_scales[0] /* 
required_scale */, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc index 5b8bbd338a13d..5477cd16f9c01 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/conv_op_builder.cc @@ -33,7 +33,7 @@ class ConvOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -41,9 +41,8 @@ class ConvOpBuilder : public BaseOpBuilder { return params.use_nchw ? ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -279,19 +278,19 @@ bool ConvOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { } bool ConvOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); // QLinearConv only supports input of uint8 for now if (!HasValidBinaryOpQuantizedInputTypes(node_unit)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0, 1}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; return true; @@ -299,7 +298,7 @@ bool ConvOpBuilder::HasSupportedInputOutputsImpl( // Operator support related -bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ConvOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { const auto& op_type = node_unit.OpType(); bool is_quant_conv = IsQuantizedOp(node_unit); @@ -314,8 +313,9 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, NodeAttrHelper helper(node_unit); const auto group = helper.Get("group", 1); const auto weight_name = inputs[1].node_arg.Name(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); + const auto* weight = graph_viewer.GetConstantInitializer(weight_name); + if (weight) { + const auto& tensor = *weight; if (tensor.dims().size() != 4) { LOGS_DEFAULT(VERBOSE) << "Only conv 2d is supported."; return false; @@ -335,13 +335,13 @@ bool ConvOpBuilder::IsOpSupportedImpl(const 
InitializedTensorSet& initializers, } } } else { - LOGS_DEFAULT(VERBOSE) << "The weight of convolution must be known"; + LOGS_DEFAULT(VERBOSE) << "The weight of convolution must be a constant initializer"; return false; } if (is_quant_conv) { - if (inputs.size() > 2 && !Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Bias of QLinearConv must be known"; + if (inputs.size() > 2 && !graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "Bias of QLinearConv must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc index 649f1e1cff2b7..ef8709641e2d0 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/depthtospace_op_builder.cc @@ -29,7 +29,7 @@ class DepthToSpaceOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -66,7 +66,7 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related -bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { NodeAttrHelper helper(node_unit); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc index b2d89ffecdca4..7d0e04fbd7b0e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/dequantizelinear_op_builder.cc @@ -38,9 +38,9 @@ class DequantizeLinearOpBuilder : public BaseOpBuilder { } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override { - return IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput); + return IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput); } }; @@ -61,7 +61,7 @@ Status DequantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil float scale = 0.0; int32_t zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Inputs()[0], node_unit.ModelPath(), scale, zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), scale, zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, scale, zero_point)); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc index 065b9638bdf64..b5e9c011990ce 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc +++ 
b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/flatten_op_builder.cc @@ -44,7 +44,7 @@ class FlattenOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -70,7 +70,7 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons // Operator support related -bool FlattenOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool FlattenOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc index ac8970f19df06..d6da9181b5a3d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gather_op_builder.cc @@ -36,7 +36,7 @@ class GatherOpBuilder : public BaseOpBuilder { return ANEURALNETWORKS_FEATURE_LEVEL_3; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -133,7 +133,7 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool GatherOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); Shape input_shape; @@ -166,8 +166,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; if (indices_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) { - if (!Contains(initializers, indices_name)) { - LOGS_DEFAULT(VERBOSE) << "Indices of Gather must be known."; + if (!graph_viewer.GetConstantInitializer(indices_name)) { + LOGS_DEFAULT(VERBOSE) << "Indices of Gather must be a constant initializer."; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc index 9b3003d472b02..8488f7cc74a6e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/gemm_op_builder.cc @@ -69,11 +69,10 @@ class GemmOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const 
OpSupportCheckParams& params) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } @@ -261,21 +260,20 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // Operator support related -bool GemmOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const { +bool GemmOpBuilder::HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) { - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } // QLinearMatMul/QDQGemm/QDQMatMul if (!HasValidBinaryOpQuantizedInputTypes(node_unit)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0, 1}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0, 1}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; return true; @@ -295,7 +293,7 @@ bool GemmOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedGemm(GetQuantizedOpType(node_unit)); } -bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { // check batch matmul first, then fall back to checking single gemm/matmul { @@ -355,8 +353,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } - if (transB == 0 && !Contains(initializers, inputs[1].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "B of Gemm must be known if transB != 1"; + if (transB == 0 && !graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "B of Gemm must be a constant initializer if transB != 1"; return false; } @@ -380,8 +378,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } else if (op_type == "MatMul" || is_qlinear_matmul) { // Only support A*B B is an initializer - if (!Contains(initializers, inputs[1].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "B of MatMul must be known"; + if (!graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "B of MatMul must be a constant initializer"; return false; } } else { @@ -389,8 +387,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } if (is_quant_gemm) { - if (inputs.size() > 2 && !Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Bias of QDQ Gemm must be known"; + if (inputs.size() > 2 && !graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "Bias of QDQ Gemm must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc index 3db63a756ab1a..6a633c443c9e5 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc +++ 
b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/leakyrelu_op_builder.cc @@ -27,7 +27,7 @@ class LeakyReluOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // LeakyRelu opset 6- has unsupported attributes @@ -111,7 +111,7 @@ Status LeakyReluOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related -bool LeakyReluOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /*initializers*/, const NodeUnit& node_unit, +bool LeakyReluOpBuilder::IsOpSupportedImpl(const GraphViewer& /*graph_viewer*/, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc index 522f389ae62a0..aeadbd17053cf 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/minmax_op_builder.cc @@ -37,7 +37,7 @@ class MinMaxOpBuilder : public BaseOpBuilder { // Min/Max opset 5- uses consumed_inputs attribute which is not supported for now int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 6; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -53,7 +53,7 @@ Status MinMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool MinMaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool MinMaxOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { // TODO: support 2+ inputs for Min/Max op if (node_unit.Inputs().size() != 2) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc index 11d37f9036b11..b0404ebec0583 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pad_op_builder.cc @@ -45,7 +45,7 @@ class PadOpBuilder : public BaseOpBuilder { return 11; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -115,7 +115,7 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const No return model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type}); } -bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool PadOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); @@ -152,14 +152,13 @@ bool PadOpBuilder::IsOpSupportedImpl(const 
InitializedTensorSet& initializers, c // only support if `pads` input is known and does not contain negative values { - const auto pads_initializer_it = initializers.find(inputs[1].node_arg.Name()); - if (pads_initializer_it == initializers.end()) { - LOGS_DEFAULT(VERBOSE) << "pads must be known"; + const auto* pads_initializer = graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name()); + if (!pads_initializer) { + LOGS_DEFAULT(VERBOSE) << "pads must be a constant initializer"; return false; } - const ONNX_NAMESPACE::TensorProto& pads_initializer = *pads_initializer_it->second; - Initializer unpacked_tensor(pads_initializer); + Initializer unpacked_tensor(*pads_initializer); auto tensor_data = unpacked_tensor.DataAsSpan(); for (size_t i = 0; i < unpacked_tensor.size(); i++) { if (tensor_data[i] < 0) { @@ -173,8 +172,8 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c // only support if `constant_value` input is known // Note: Could add support for non-constant initializer later. Then we need to ensure it is a scalar (with shape []). if (inputs.size() > 2) { - if (!Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "constant_value must be known"; + if (!graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name())) { + LOGS_DEFAULT(VERBOSE) << "constant_value must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc index c14568aaccfa3..a2a4786b72ec7 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/pool_op_builder.cc @@ -32,7 +32,7 @@ class PoolOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -40,10 +40,9 @@ class PoolOpBuilder : public BaseOpBuilder { return params.use_nchw ? 
ANEURALNETWORKS_FEATURE_LEVEL_3 : ANEURALNETWORKS_FEATURE_LEVEL_2; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const override; - bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; + bool IsNodeUnitTypeSupported(const NodeUnit& node_unit) const override; bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -116,16 +115,16 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N float y_scale = input_operand_type.operandType.scale; int32_t y_zero_point = input_operand_type.operandType.zeroPoint; if (is_quant_pool) { - const auto& initializers = model_builder.GetInitializerTensors(); + const auto& graph_viewer = model_builder.GetGraphViewer(); float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); // Verify if the scale and zero point values from onnx input and nnapi input match ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); + graph_viewer, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); } InlinedVector input_indices; @@ -171,7 +170,7 @@ bool PoolOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return IsQuantizedPool(GetQuantizedOpType(node_unit)); } -bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool PoolOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& op_name = node_unit.Name(); const auto& op_type = node_unit.OpType(); @@ -236,7 +235,7 @@ bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, float input_scale = 0.0f; int32_t input_zp = 0; auto status = GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), input_scale, input_zp); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), input_scale, input_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name << "] GetQuantizationScaleAndZeroPoint for input_scale/zp failed, message: " @@ -247,7 +246,7 @@ bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, float output_scale = 0.0f; int32_t output_zp = 0; status = GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Outputs()[0], node_unit.ModelPath(), output_scale, output_zp); + graph_viewer, node_unit.Outputs()[0], node_unit.ModelPath(), output_scale, output_zp); if (!status.IsOK()) { LOGS_DEFAULT(ERROR) << "Op [" << op_type << "] name [" << op_name << "] GetQuantizationScaleAndZeroPoint for output_scale/zp failed, message: " @@ -274,7 +273,7 @@ bool PoolOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } bool PoolOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& 
params) const { const auto& op_type = node_unit.OpType(); bool is_quant_pool = IsQuantizedOp(node_unit); @@ -282,13 +281,13 @@ bool PoolOpBuilder::HasSupportedInputOutputsImpl( bool is_average_pool = op_type == "AveragePool" || op_type == "QLinearAveragePool"; bool is_quant_average_pool = is_quant_pool && is_average_pool; if (!is_max_pool && !is_quant_average_pool) - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); if (is_quant_average_pool) { - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc index 49ff01d27219a..d13b81c2a14b8 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/quantizelinear_op_builder.cc @@ -38,9 +38,9 @@ class QuantizeLinearOpBuilder : public BaseOpBuilder { } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override { - return IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput); + return IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput); } }; @@ -60,7 +60,7 @@ Status QuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde float scale = 0.0f; int32_t zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Outputs()[0], node_unit.ModelPath(), scale, zero_point)); + model_builder.GetGraphViewer(), node_unit.Outputs()[0], node_unit.ModelPath(), scale, zero_point)); Type output_type = Type::TENSOR_QUANT8_ASYMM; const OperandType output_operand_type(output_type, shaper[output], scale, zero_point); diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc index 8d0347673ba56..a6da290753b74 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc @@ -35,7 +35,7 @@ class ReductionOpBuilder : public BaseOpBuilder { private: int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -169,7 +169,7 @@ int32_t ReductionOpBuilder::GetMinSupportedNNAPIFeatureLevel( return ANEURALNETWORKS_FEATURE_LEVEL_3; } -bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ReductionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, 
const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); const auto& op(node_unit.OpType()); @@ -190,7 +190,7 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; if (inputs.size() > 1 && inputs[1].node_arg.Exists()) { const auto& axes_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, axes_name)) { + if (!graph_viewer.GetConstantInitializer(axes_name)) { LOGS_DEFAULT(VERBOSE) << "Axes of ReduceMean must be a constant initializer."; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc index 869883b98b22e..f2f9165d2f3cc 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reshape_op_builder.cc @@ -35,14 +35,13 @@ class ReshapeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // Reshape opset 4- uses attributes for new shape which we do not support for now int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 5; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -59,10 +58,10 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { auto& shaper(model_builder.GetShaper()); - const auto& initializers(model_builder.GetInitializerTensors()); + const auto& graph_viewer(model_builder.GetGraphViewer()); auto input = node_unit.Inputs()[0].node_arg.Name(); - const auto& shape_tensor = *initializers.at(node_unit.Inputs()[1].node_arg.Name()); + const auto& shape_tensor = *graph_viewer.GetConstantInitializer(node_unit.Inputs()[1].node_arg.Name()); Initializer unpacked_tensor(shape_tensor); auto raw_shape = unpacked_tensor.DataAsSpan(); const auto size = SafeInt(shape_tensor.dims()[0]); @@ -80,7 +79,7 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons int32_t x_zero_point = 0; if (IsQuantizedOp(node_unit)) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); } @@ -93,12 +92,13 @@ bool ReshapeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQReshape; } -bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool 
ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); const auto& perm_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, perm_name)) { - LOGS_DEFAULT(VERBOSE) << "New shape of reshape must be known"; + const auto* perm = graph_viewer.GetConstantInitializer(perm_name); + if (!perm) { + LOGS_DEFAULT(VERBOSE) << "New shape of reshape must be a constant initializer"; return false; } @@ -112,7 +112,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer return false; } - const auto& perm_tensor = *initializers.at(perm_name); + const auto& perm_tensor = *perm; Initializer unpacked_tensor(perm_tensor); auto raw_perm = unpacked_tensor.DataAsSpan(); const auto perm_size = SafeInt(perm_tensor.dims()[0]); @@ -138,17 +138,17 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer } bool ReshapeOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) { - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) { return false; } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) { return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index cdaa1c8fac76c..d75b9cc72ff4b 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -33,19 +33,18 @@ class ResizeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, - const OpSupportCheckParams& /* params */) const override; + int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing // We only support Resize opset 11+ here int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 11; } - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; }; @@ -74,7 +73,6 @@ Status 
ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& shaper(model_builder.GetShaper()); const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); - const auto& initializers(model_builder.GetInitializerTensors()); NodeAttrHelper helper(node_unit); const auto& inputs = node_unit.Inputs(); const auto android_feature_level = model_builder.GetEffectiveFeatureLevel(); @@ -92,7 +90,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); } @@ -147,7 +145,7 @@ bool ResizeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQResize; } -bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool ResizeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -228,32 +226,29 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } } - { // scales and sizes (if present) must be initializers + // scales or sizes must be constant initializers + { + // scales is input 3, sizes input 4, one must exist. only one is used. const auto inputs = node_unit.Inputs(); - if (inputs.size() < 3) { + bool using_scales = inputs.size() > 2 && inputs[2].node_arg.Exists(); + bool using_sizes = !using_scales && inputs.size() > 3 && inputs[3].node_arg.Exists(); + if (!using_scales && !using_sizes) { LOGS_DEFAULT(VERBOSE) << "Input scales or sizes of Resize must be known"; return false; } - // scales - bool using_scales = (inputs.size() > 2 && inputs[2].node_arg.Exists()); - if (using_scales && !Contains(initializers, inputs[2].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be known"; - return false; - } - - // sizes - bool using_sizes = inputs.size() > 3 && inputs[3].node_arg.Exists(); - if (using_sizes && !Contains(initializers, inputs[3].node_arg.Name())) { - LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be known"; - return false; - } - bool input_is_nchw = false; // haven't a good solution to check layout when scale is 1.0F // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (using_scales) { // we are using scales - const auto& scales_tensor = *initializers.at(inputs[2].node_arg.Name()); - Initializer const unpacked_tensor(scales_tensor); + bool input_is_nchw = false; + + if (using_scales) { + const auto* scales = graph_viewer.GetConstantInitializer(inputs[2].node_arg.Name()); + if (!scales) { + LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be a constant initializer"; + return false; + } + + const Initializer unpacked_tensor(*scales); auto scales_data = unpacked_tensor.DataAsSpan(); input_is_nchw = scales_data[1] == 1.0F; float const scale_n = scales_data[0]; @@ -265,10 +260,13 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } } else { - // we are using sizes - const auto& sizes_name = 
inputs[3].node_arg.Name(); - const auto& sizes_tensor = *initializers.at(sizes_name); - Initializer unpacked_tensor(sizes_tensor); + const auto* sizes = graph_viewer.GetConstantInitializer(inputs[3].node_arg.Name()); + if (!sizes) { + LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be a constant initializer"; + return false; + } + + Initializer unpacked_tensor(*sizes); auto sizes_data = unpacked_tensor.DataAsSpan(); input_is_nchw = sizes_data[1] == input_shape[1]; @@ -308,7 +306,7 @@ int32_t ResizeOpBuilder::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_u } bool ResizeOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) @@ -323,10 +321,10 @@ bool ResizeOpBuilder::HasSupportedInputOutputsImpl( } if (IsQuantizedOp(node_unit)) { - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc index 903469d34e67c..facdc7132dc00 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/slice_op_builder.cc @@ -40,7 +40,7 @@ class SliceOpBuilder : public BaseOpBuilder { // We only support slice from opset 10 int GetMinSupportedOpSet(const NodeUnit& /* node_unit */) const override { return 10; } - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -201,7 +201,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool SliceOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -219,19 +219,19 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[1].node_arg.Name(), "starts")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[1].node_arg.Name(), "starts")) { return false; } - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[2].node_arg.Name(), "ends")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[2].node_arg.Name(), "ends")) { return false; } const auto& inputs = node_unit.Inputs(); if (inputs.size() > 3) { - if (!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[3].node_arg.Name(), "axes")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[3].node_arg.Name(), "axes")) { return false; } if (inputs.size() > 4) { - if 
(!CheckIsInitializer(initializers, node_unit, node_unit.Inputs()[4].node_arg.Name(), "steps")) { + if (!CheckIsConstantInitializer(graph_viewer, node_unit, node_unit.Inputs()[4].node_arg.Name(), "steps")) { return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc index 1e420fec80827..a2a8b4512b028 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/softmax_op_builder.cc @@ -33,7 +33,7 @@ class SoftMaxOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -41,7 +41,7 @@ class SoftMaxOpBuilder : public BaseOpBuilder { return ANEURALNETWORKS_FEATURE_LEVEL_2; } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } @@ -77,8 +77,7 @@ Status SoftMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons int32_t y_zero_point = 0; if (IsQuantizedOp(node_unit)) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - model_builder.GetInitializerTensors(), node_unit.Inputs()[0], node_unit.ModelPath(), - x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); @@ -156,7 +155,7 @@ bool SoftMaxOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQSoftmax; } -bool SoftMaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool SoftMaxOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -197,24 +196,23 @@ bool SoftMaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftMaxOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const OpSupportCheckParams& params) const { +bool SoftMaxOpBuilder::HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const { if (!IsQuantizedOp(node_unit)) { - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) { return false; } - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) { + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) { return false; } // NNAPI requires the 
scale be 1.f/256 and zero point to be 0 if (!HasRequiredScaleAndZeroPoint( - initializers, + graph_viewer, MakeString("Op [", node_unit.OpType(), "] name [", node_unit.Name(), "]'s output 0 "), node_unit.Outputs()[0], node_unit.ModelPath(), 1.f / 256 /* required_scale */, 0 /* required_zp */)) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc index 68b63badb8f7e..b2225643b788e 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc @@ -35,7 +35,7 @@ class SplitOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; // Split opset 13- uses "split" as attribute. Currently it's not supported. @@ -85,7 +85,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -98,13 +98,13 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const auto split_dims_at_axis = input_shape[SafeInt(HandleNegativeAxis(axis, input_shape.size()))]; if (input_defs.size() > 1 && input_defs[1].node_arg.Exists()) { // if optional input `split` is provided - auto split_initializer_it = initializers.find(input_defs[1].node_arg.Name()); - if (split_initializer_it == initializers.end()) { - LOGS_DEFAULT(VERBOSE) << "Optional input 'split' must be initializer if provided."; + const auto* splits = graph_viewer.GetConstantInitializer(input_defs[1].node_arg.Name()); + if (!splits) { + LOGS_DEFAULT(VERBOSE) << "Optional input 'split' must be a constant initializer if provided."; return false; } - const auto& splits_tensor = *split_initializer_it->second; - Initializer unpacked_tensor(splits_tensor); + + Initializer unpacked_tensor(*splits); auto splits_span = unpacked_tensor.DataAsSpan(); uint32_t sum_of_splits = std::accumulate(splits_span.begin(), splits_span.end(), SafeInt(0)); if (sum_of_splits != split_dims_at_axis) { @@ -119,6 +119,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, auto it = std::adjacent_find(splits_span.begin(), splits_span.end(), [](const auto& a, const auto& b) { return a != b; }); + if (it != splits_span.end()) { LOGS_DEFAULT(VERBOSE) << "NNAPI only supports the case that number of splits evenly divides split axis size"; return false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc index a0fe744eaacc8..fb3ca5e6175fa 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/squeeze_op_builder.cc @@ -32,7 +32,7 @@ class SqueezeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool 
IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -59,7 +59,7 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons // Operator support related -bool SqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool SqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); Shape input_shape; @@ -76,8 +76,8 @@ bool SqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer // Squeeze opset 13 use input 1 as axes, if we have input 1 then it need to be an initializer if (node_unit.SinceVersion() > 12 && inputs.size() > 1) { const auto& axes_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, axes_name)) { - LOGS_DEFAULT(VERBOSE) << "Input axes of Squeeze must be known"; + if (!graph_viewer.GetConstantInitializer(axes_name)) { + LOGS_DEFAULT(VERBOSE) << "Input axes of Squeeze must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc index 4d243c730bf05..6fe5ca32fe044 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/transpose_op_builder.cc @@ -32,7 +32,7 @@ class TransposeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, @@ -41,7 +41,7 @@ class TransposeOpBuilder : public BaseOpBuilder { } bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; bool IsNodeUnitTypeSupported(const NodeUnit& /* node_unit */) const override { return true; } bool IsQuantizedOp(const NodeUnit& node_unit) const override; @@ -59,7 +59,6 @@ void TransposeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { auto& shaper(model_builder.GetShaper()); - const auto& initializers(model_builder.GetInitializerTensors()); const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); @@ -78,7 +77,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); } @@ -95,7 +94,7 @@ bool 
TransposeOpBuilder::IsQuantizedOp(const NodeUnit& node_unit) const { return GetQuantizedOpType(node_unit) == QuantizedOpType::QDQTranspose; } -bool TransposeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, +bool TransposeOpBuilder::IsOpSupportedImpl(const GraphViewer& /* graph_viewer */, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -112,7 +111,7 @@ bool TransposeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia } bool TransposeOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { int32_t input_type; if (!GetType(node_unit.Inputs()[0].node_arg, input_type)) @@ -127,10 +126,10 @@ bool TransposeOpBuilder::HasSupportedInputOutputsImpl( } if (IsQuantizedOp(node_unit)) { - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc index 796fd207fe428..dbd960ee5536c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unary_op_builder.cc @@ -32,19 +32,18 @@ class UnaryOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& /* node_unit */, + int32_t GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; - bool HasSupportedInputOutputsImpl( - const InitializedTensorSet& /* initializers */, const NodeUnit& node_unit, - const OpSupportCheckParams& /* params */) const override; + bool HasSupportedInputOutputsImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const OpSupportCheckParams& params) const override; int GetMinSupportedOpSet(const NodeUnit& node_unit) const override; - static bool IsQuantizedOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + static bool IsQuantizedOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params); }; @@ -117,11 +116,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const float y_scale = 0.0f; int32_t y_zero_point = 0; if (is_qlinear_sigmoid) { - const auto& initializers = model_builder.GetInitializerTensors(); float x_scale = 0.0f; int32_t x_zero_point = 0; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); + model_builder.GetGraphViewer(), node_unit.Inputs()[0], node_unit.ModelPath(), x_scale, x_zero_point)); // Verify if the scale and zero point values from onnx 
input and nnapi input match ORT_RETURN_IF_ERROR(IsValidInputQuantizedType(model_builder, input, x_scale, x_zero_point)); @@ -141,10 +139,10 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const // Operator support related -bool UnaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool UnaryOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { if (node_unit.OpType() == "QLinearSigmoid") { - return IsQuantizedOpSupported(initializers, node_unit, params); + return IsQuantizedOpSupported(graph_viewer, node_unit, params); } else if (node_unit.OpType() == "Sigmoid") { Shape input_shape; if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) @@ -178,16 +176,16 @@ int32_t UnaryOpBuilder::GetMinSupportedNNAPIFeatureLevel(const NodeUnit& node_un } bool UnaryOpBuilder::HasSupportedInputOutputsImpl( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const { // We only need to override input check for QLinearSigmoid if (node_unit.OpType() != "QLinearSigmoid") - return BaseOpBuilder::HasSupportedInputOutputsImpl(initializers, node_unit, params); + return BaseOpBuilder::HasSupportedInputOutputsImpl(graph_viewer, node_unit, params); - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kInput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kInput)) return false; - if (!IsQuantizedIOSupported(initializers, node_unit, {0}, params, ArgType::kOutput)) + if (!IsQuantizedIOSupported(graph_viewer, node_unit, {0}, params, ArgType::kOutput)) return false; return true; @@ -204,13 +202,13 @@ int UnaryOpBuilder::GetMinSupportedOpSet(const NodeUnit& node_unit) const { } /* static */ bool UnaryOpBuilder::IsQuantizedOpSupported( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) { + const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) { const auto& op_type = node_unit.OpType(); ORT_ENFORCE(op_type == "QLinearSigmoid"); // NNAPI requires the scale be 1.f/256 and zero point to be 0 // See https://android.googlesource.com/platform/frameworks/ml/+/refs/heads/android10-c2f2-release/nn/common/operations/Activation.cpp#180 - if (!HasRequiredScaleAndZeroPoint(initializers, + if (!HasRequiredScaleAndZeroPoint(graph_viewer, MakeString("Op [", op_type, "] name [", node_unit.Name(), "]'s output 0 "), node_unit.Outputs()[0], node_unit.ModelPath(), 1.f / 256 /* required_scale */, 0 /* required_zp */)) { diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc index a9bece7d42364..95cd813800c9a 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/unsqueeze_op_builder.cc @@ -32,7 +32,7 @@ class UnsqueezeOpBuilder : public BaseOpBuilder { // Operator support related private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + bool IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const override; }; @@ -74,7 +74,7 @@ Status 
UnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co // Operator support related -bool UnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool UnsqueezeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& /* params */) const { const auto& inputs = node_unit.Inputs(); Shape input_shape; @@ -93,8 +93,8 @@ bool UnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ // Unsqueeze opset 13 uses input 1 as axes, if we have input 1 then it needs to be an initializer if (node_unit.SinceVersion() > 12 && inputs.size() > 1) { const auto& axes_name = inputs[1].node_arg.Name(); - if (!Contains(initializers, axes_name)) { - LOGS_DEFAULT(VERBOSE) << "Input axes of Unsqueeze must be known"; + if (!graph_viewer.GetConstantInitializer(axes_name)) { + LOGS_DEFAULT(VERBOSE) << "Input axes of Unsqueeze must be a constant initializer"; return false; } } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc index b75e78cbfe7cc..6962a7be94bb6 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/model_builder.cc @@ -100,7 +100,7 @@ void ModelBuilder::PreprocessActivations() { activation_node_units_.emplace(node_unit.get(), ANEURALNETWORKS_FUSED_RELU); } else if (op_type == "Clip") { // Relu1 or Relu6 float min, max; - if (!GetClipMinMax(GetInitializerTensors(), node, min, max, logging::LoggingManager::DefaultLogger())) + if (!GetClipMinMax(graph_viewer_, node, min, max, logging::LoggingManager::DefaultLogger())) continue; if (min == -1.0f && max == 1.0f) { @@ -151,7 +151,7 @@ void ModelBuilder::GetAllQuantizedOpInputs() { } static Status GetInputDataType( - const InitializedTensorSet& initializers, + const GraphViewer& graph_viewer, const std::unordered_map>& all_quantized_op_inputs, const std::string& name, int32_t data_type, const Shape& shape, OperandType& operand_type) { @@ -177,7 +177,7 @@ static Status GetInputDataType( // TODO, verify the scale and zero point match if there are multiple op using same input const auto* node_unit = all_quantized_op_inputs.at(name)[0]; ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, *node_unit, name, scale, zero_point, ArgType::kInput)); + graph_viewer, *node_unit, name, scale, zero_point, ArgType::kInput)); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT32: @@ -226,9 +226,8 @@ Status ModelBuilder::RegisterInitializers() { } OperandType operand_type(Type::TENSOR_FLOAT32, shape); - ORT_RETURN_IF_ERROR( - GetInputDataType(GetInitializerTensors(), all_quantized_op_inputs_, - name, tensor.data_type(), shape, operand_type)); + ORT_RETURN_IF_ERROR(GetInputDataType(graph_viewer_, all_quantized_op_inputs_, name, tensor.data_type(), shape, + operand_type)); shaper_.AddShape(name, operand_type.dimensions); uint32_t index = 0; @@ -304,7 +303,7 @@ Status ModelBuilder::RegisterModelInputs() { "The input of graph doesn't have elem_type: ", input_name); } else { ORT_RETURN_IF_ERROR( - GetInputDataType(GetInitializerTensors(), all_quantized_op_inputs_, + GetInputDataType(graph_viewer_, all_quantized_op_inputs_, input_name, type_proto->tensor_type().elem_type(), shape, operand_type)); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h 
b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h index c565af491ff90..f6db4022fb8f4 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder.h @@ -56,7 +56,7 @@ class IOpBuilder { // Operator support check related // Check if an operator is supported - virtual bool IsOpSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, + virtual bool IsOpSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const OpSupportCheckParams& params) const = 0; }; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index 26db7c8e7afea..a066c64dac67d 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -679,16 +679,15 @@ Status HandleAutoPad(const Shape& input_shape, return Status::OK(); } -Status GetBinaryOpQuantizationScaleAndZeroPoint( - const InitializedTensorSet& initializers, const NodeUnit& node_unit, - float& a_scale, float& b_scale, float& y_scale, - int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) { +Status GetBinaryOpQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + float& a_scale, float& b_scale, float& y_scale, + int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point) { ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[0], node_unit.ModelPath(), a_scale, a_zero_point)); + graph_viewer, node_unit.Inputs()[0], node_unit.ModelPath(), a_scale, a_zero_point)); ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Inputs()[1], node_unit.ModelPath(), b_scale, b_zero_point)); + graph_viewer, node_unit.Inputs()[1], node_unit.ModelPath(), b_scale, b_zero_point)); ORT_RETURN_IF_ERROR(GetQuantizationScaleAndZeroPoint( - initializers, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); + graph_viewer, node_unit.Outputs()[0], node_unit.ModelPath(), y_scale, y_zero_point)); return Status::OK(); } @@ -699,16 +698,18 @@ Status GetConvMatMulOpQuantizationScaleAndZeroPoint( int32_t& a_zero_point, int32_t& w_zero_point, int32_t& y_zero_point, std::optional>& w_scales, bool& is_per_tensor_u8s8) { is_per_tensor_u8s8 = false; - const auto& initializers(model_builder.GetInitializerTensors()); + const auto& graph_viewer(model_builder.GetGraphViewer()); + // Get scale and zero points // We will handle per-channel weight scale and zero point later ORT_RETURN_IF_ERROR( - GetBinaryOpQuantizationScaleAndZeroPoint(initializers, node_unit, + GetBinaryOpQuantizationScaleAndZeroPoint(graph_viewer, node_unit, a_scale, w_scale, y_scale, a_zero_point, w_zero_point, y_zero_point)); const auto& inputs = node_unit.Inputs(); - const auto& weight_tensor = *initializers.at(inputs[1].node_arg.Name()); + // all these were checked to be constant in GemmOpBuilder::IsOpSupportedImpl + const auto& weight_tensor = *graph_viewer.GetConstantInitializer(inputs[1].node_arg.Name()); // We are done here if this is u8u8 QLinearConv if (weight_tensor.data_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) @@ -719,7 +720,7 @@ Status GetConvMatMulOpQuantizationScaleAndZeroPoint( // For this case we will need to convert the int8 weight tensor to uint8 // And have same scale and 128 as zero point // The 
conversion of the weight tensor itself will be done in the OpBuilder - const auto& scale_tensor = *initializers.at(inputs[1].quant_param->scale.Name()); + const auto& scale_tensor = *graph_viewer.GetConstantInitializer(inputs[1].quant_param->scale.Name()); int64_t scale_dim = scale_tensor.dims().empty() ? 1 : scale_tensor.dims()[0]; if (scale_dim == 1) { w_zero_point = 128; @@ -1072,20 +1073,20 @@ Status AddReshapeOperator(ModelBuilder& model_builder, return Status::OK(); } -bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, +bool IsQuantizationScaleSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const OpSupportCheckParams& params, const std::string& op_type, bool is_quant_matmul, bool is_conv_matmul_u8s8_weight) { const auto scale_name = io_def.quant_param->scale.Name(); - auto it = initializers.find(scale_name); - if (it == initializers.cend()) { - LOGS_DEFAULT(VERBOSE) << "The scale of " << op_type << " must be an initializer tensor"; + const auto* scale = graph_viewer.GetConstantInitializer(scale_name); + if (!scale) { + LOGS_DEFAULT(VERBOSE) << "The scale of " << op_type << " must be a constant initializer"; return false; } - const auto& scale_tensor = *it->second; + const auto& scale_tensor = *scale; int64_t scales_dim = scale_tensor.dims().empty() ? 1 : scale_tensor.dims()[0]; if (!is_conv_matmul_u8s8_weight) { if (scales_dim != 1) { @@ -1123,7 +1124,7 @@ bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, return true; } -bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, +bool IsQuantizationZeroPointSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::string& op_type, const Path& model_path, @@ -1134,12 +1135,13 @@ bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, return true; const auto& zero_point_name = io_def.quant_param->zero_point->Name(); - if (!Contains(initializers, zero_point_name)) { - LOGS_DEFAULT(VERBOSE) << "The zero point of " << op_type << " must be an initializer tensor"; + const auto* zero_point = graph_viewer.GetConstantInitializer(zero_point_name); + if (!zero_point) { + LOGS_DEFAULT(VERBOSE) << "The zero point of " << op_type << " must be a constant initializer"; return false; } - const auto& zero_tensor = *initializers.at(zero_point_name); + const auto& zero_tensor = *zero_point; int64_t zero_dim = zero_tensor.dims().empty() ? 
1 : zero_tensor.dims()[0]; if (!is_conv_matmul_u8s8_weight) { @@ -1194,8 +1196,9 @@ bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, return true; } -bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, - const std::vector& indices, const OpSupportCheckParams& params, ArgType arg_type) { +bool IsQuantizedIOSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, + const std::vector& indices, const OpSupportCheckParams& params, + ArgType arg_type) { const auto& op_type = node_unit.OpType(); auto quant_op_type = GetQuantizedOpType(node_unit); @@ -1247,12 +1250,12 @@ bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const Node } // Check scale and zero point - if (!IsQuantizationScaleSupported(initializers, io_def, params, op_type, + if (!IsQuantizationScaleSupported(graph_viewer, io_def, params, op_type, is_quant_matmul, is_conv_matmul_u8s8_weight)) { return false; } - if (!IsQuantizationZeroPointSupported(initializers, io_def, op_type, node_unit.ModelPath(), + if (!IsQuantizationZeroPointSupported(graph_viewer, io_def, op_type, node_unit.ModelPath(), is_quant_matmul, is_conv_matmul_u8s8_weight)) { return false; } @@ -1261,33 +1264,27 @@ bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const Node return true; } -bool HasRequiredScaleAndZeroPoint(const InitializedTensorSet& initializers, +bool HasRequiredScaleAndZeroPoint(const GraphViewer& graph_viewer, const std::string& op_desc, const NodeUnitIODef& io_def, const Path& path, float required_scale, int32_t required_zp) { float scale = 0.0f; int32_t zp = 0; - auto status = GetQuantizationScaleAndZeroPoint(initializers, io_def, path, - scale, zp); + auto status = GetQuantizationScaleAndZeroPoint(graph_viewer, io_def, path, scale, zp); if (!status.IsOK()) { - LOGS_DEFAULT(ERROR) << op_desc - << " GetQuantizationScaleAndZeroPoint failed, message: " - << status.ErrorMessage(); + LOGS_DEFAULT(ERROR) << op_desc << " GetQuantizationScaleAndZeroPoint failed, message: " << status.ErrorMessage(); return false; } if (scale != required_scale) { - LOGS_DEFAULT(VERBOSE) << op_desc - << " scale can only be [" << required_scale - << "], actual scale: " << scale; + LOGS_DEFAULT(VERBOSE) << op_desc << " scale can only be [" << required_scale << "], actual scale: " << scale; return false; } if (zp != required_zp) { - LOGS_DEFAULT(VERBOSE) << op_desc - << "] zero point can only be [" << required_zp - << "], actual zero point: " << scale; + LOGS_DEFAULT(VERBOSE) << op_desc << "] zero point can only be [" << required_zp << "], actual zero point: " + << zp; return false; } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 0cc442890ab6e..7ccf4c1ef7555 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -118,7 +118,7 @@ Status HandleAutoPad(const Shape& input_shape, // Get scales and zero points for the qlinear binary ops (which has 2 input and 1 output) // QLinearConv, QLinearMatmul, QLinearAdd, QLinearMul // a, b are inputs, and y is output -Status GetBinaryOpQuantizationScaleAndZeroPoint(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +Status GetBinaryOpQuantizationScaleAndZeroPoint(const GraphViewer& graph_viewer, const NodeUnit& node_unit, float& a_scale, float& b_scale, 
float& y_scale, int32_t& a_zero_point, int32_t& b_zero_point, int32_t& y_zero_point); @@ -193,14 +193,14 @@ inline bool IsNodeLayoutNHWC(const NodeUnit& node_unit) { return node_unit.Domain() == kMSInternalNHWCDomain; } -bool IsQuantizationScaleSupported(const InitializedTensorSet& initializers, +bool IsQuantizationScaleSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const OpSupportCheckParams& params, const std::string& op_type, bool is_quant_matmul, bool is_conv_matmul_u8s8_weight); -bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, +bool IsQuantizationZeroPointSupported(const GraphViewer& graph_viewer, const NodeUnitIODef& io_def, const std::string& op_type, const Path& model_path, @@ -208,13 +208,13 @@ bool IsQuantizationZeroPointSupported(const InitializedTensorSet& initializers, bool is_conv_matmul_u8s8_weight); // Check if the given quantized input(s) or output(s) is supported -bool IsQuantizedIOSupported(const InitializedTensorSet& initializers, const NodeUnit& node_unit, +bool IsQuantizedIOSupported(const GraphViewer& graph_viewer, const NodeUnit& node_unit, const std::vector& indices, const OpSupportCheckParams& params, ArgType arg_type); // Some Quantized NNAPI operations have required output scale and zero point // e.g. Softmax (uint8) requires output scale be 1.f/256 and zp be 0 // This helper function checks if the given io_def has required scale and zp -bool HasRequiredScaleAndZeroPoint(const InitializedTensorSet& initializers, +bool HasRequiredScaleAndZeroPoint(const GraphViewer& graph_viewer, const std::string& op_desc, const NodeUnitIODef& io_def, const Path& path, diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 39ea4dd8412bb..37ad14ac2e9b1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -25,12 +25,14 @@ bool GetType(const NodeArg& node_arg, int32_t& type, const logging::Logger& logg return true; } -bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, - float& min, float& max, const logging::Logger& logger) { +namespace { +bool GetClipMinMaxImpl(std::function get_const_initializer, + const Node& node, float& min, float& max, const logging::Logger& logger) { const auto& node_name = node.Name(); int32_t input_type; - if (!GetType(*node.InputDefs()[0], input_type, logger)) + if (!GetType(*node.InputDefs()[0], input_type, logger)) { return false; + } min = std::numeric_limits::lowest(); max = std::numeric_limits::max(); @@ -41,49 +43,73 @@ bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, min = helper.Get("min", std::numeric_limits::lowest()); max = helper.Get("max", std::numeric_limits::max()); } else { - if (node.InputDefs().size() > 1) { - // we have input min - const auto& min_name = node.InputDefs()[1]->Name(); - if (!Contains(initializers, min_name)) { - LOGS(logger, VERBOSE) << "Input min of Clip must be known"; + auto get_value = + [&](const ONNX_NAMESPACE::TensorProto* initializer, std::string_view type, float& value) -> bool { + if (!initializer) { + LOGS(logger, VERBOSE) << type << " input of Clip must be a constant initializer"; return false; } - Initializer unpacked_tensor_min(*initializers.at(min_name)); + + Initializer unpacked_tensor_min(*initializer); switch (input_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - min = unpacked_tensor_min.DataAsSpan()[0]; + value = unpacked_tensor_min.DataAsSpan()[0]; 
break; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - min = (unpacked_tensor_min.DataAsSpan()[0]).ToFloat(); + value = unpacked_tensor_min.DataAsSpan()[0].ToFloat(); break; default: - LOGS(logger, VERBOSE) << "GetClipMinMax() only support Clip node with float inputs for now. " - << "The node [" << node_name << "] has input 0 type: " << input_type; + LOGS(logger, VERBOSE) << "GetClipMinMax() only supports float and float16 as min and max inputs for now." + << " The node [" << node_name << "] has input type: " << input_type; return false; } - if (node.InputDefs().size() > 2) { - // we have input max - const auto& max_name = node.InputDefs()[2]->Name(); - if (!Contains(initializers, max_name)) { - LOGS(logger, VERBOSE) << "Input max of Clip must be known"; - return false; - } - Initializer unpacked_tensor_max(*initializers.at(max_name)); - switch (input_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - max = unpacked_tensor_max.DataAsSpan()[0]; - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - max = (unpacked_tensor_max.DataAsSpan()[0]).ToFloat(); - break; - } + return true; + }; + + // min and max are both optional. could have neither, one or both. + if (node.InputDefs().size() > 1 && node.InputDefs()[1]->Exists()) { + // we have input min + const auto& min_name = node.InputDefs()[1]->Name(); + const auto* min_value = get_const_initializer(min_name); + if (!get_value(min_value, "Min", min)) { + return false; + } + } + + if (node.InputDefs().size() > 2 && node.InputDefs()[2]->Exists()) { + // we have input max + const auto& max_name = node.InputDefs()[2]->Name(); + const auto* max_value = get_const_initializer(max_name); + if (!get_value(max_value, "Max", max)) { + return false; } } } return true; } +} // namespace + +bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, float& min, float& max, + const logging::Logger& logger) { + return GetClipMinMaxImpl( + [&graph_viewer](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* { + return graph_viewer.GetConstantInitializer(name); + }, + node, min, max, logger); +} + +// deprecated version that is not able to check if the initializer is constant +bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, float& min, float& max, + const logging::Logger& logger) { + return GetClipMinMaxImpl( + [&initializers](const std::string& name) -> const ONNX_NAMESPACE::TensorProto* { + auto entry = initializers.find(name); + return entry == initializers.end() ? nullptr : entry->second; + }, + node, min, max, logger); +} NodeAttrHelper::NodeAttrHelper(const onnxruntime::Node& node) : node_attributes_(node.GetAttributes()) {} diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 1e93f040711df..31b1aba2e1a63 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -16,14 +16,20 @@ namespace logging { class Logger; } +class GraphViewer; class Node; class NodeArg; class NodeUnit; -// Get the min/max of a Clip operator. -// If min/max are not known initializer tensors, will return false -// For now we only support getting float min/max, -// since in most cases, Clip(0,6)[Relu6] will be fused by quantization tool +// Get the min/max of a Clip operator. Reads values from attributes for opset < 11 and inputs after that. +// For opset 11+, if min/max are not constant initializers, will return false. +// For now we only support getting float min/max. 
+bool GetClipMinMax(const GraphViewer& graph_viewer, const Node& node, + float& min, float& max, const logging::Logger& logger); + +/// GraphViewer GetConstantInitializer/IsConstantInitializer should be used to ensure the initializer is +/// constant. Low risk for Clip min/max but in general the infrastructure to check if an operator is supported needs +/// to be updated to not use InitializedTensorSet which may contain non-constant initializers. bool GetClipMinMax(const InitializedTensorSet& initializers, const Node& node, float& min, float& max, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index 9de5b889808fc..0d6001bcba89f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -47,7 +47,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& output_name = node.OutputDefs()[0]->Name(); emscripten::val options = emscripten::val::object(); float minValue, maxValue; - ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetInitializerTensors(), node, minValue, maxValue, logger), + ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, minValue, maxValue, logger), "GetClipMinMax failed"); options.set("minValue", minValue); options.set("maxValue", maxValue); @@ -70,6 +70,9 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { + // TODO: Update IsOpSupportedImpl to pass GraphViewer instead of InitializedTensorSet so the implementations + // can ensure initializers are constant. See #19401 for details of how this update was made to the NNAPI EP. + // GetClipMinMax(graph_viewer, node, minValue, maxValue, logger) float min, max; return GetClipMinMax(initializers, node, min, max, logger); } From 75f06319d61fd6c93387aa0f49b2ab7b2de647bf Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Wed, 7 Feb 2024 12:51:02 -0500 Subject: [PATCH 053/207] Change binet to bin (#19424) ### Description This pull request includes a small change to the `Dockerfile.manylinux2_28_cuda` file in the `tools/ci_build/github/linux/docker` directory. The change corrects the `PREPEND_PATH` argument from `/usr/local/cuda/binet` to `/usr/local/cuda/bin`, ensuring the correct path to CUDA binaries is set. 
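
For illustration, a minimal sketch of how this argument is normally consumed (the `ENV` line below is an assumption added for clarity; only the corrected `ARG` value is taken from this patch):

```dockerfile
# Corrected value from this patch: the CUDA toolchain lives under /usr/local/cuda/bin.
ARG PREPEND_PATH=/usr/local/cuda/bin

# Assumed wiring, shown only to illustrate the effect of the fix: the argument is
# prepended to PATH so that nvcc and the other CUDA tools resolve in later build
# steps. With the previous ".../binet" typo this prefix would point at a
# directory that does not exist.
ENV PATH=${PREPEND_PATH}:${PATH}
```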
--- .../ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 0c95083d614ed..fafc47b6e9de6 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -7,7 +7,7 @@ ARG PLATFORM=x86_64 ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 -ARG PREPEND_PATH=/usr/local/cuda/binet +ARG PREPEND_PATH=/usr/local/cuda/bin ARG TRT_VERSION=8.6.1.6-1.cuda11.8 #Build manylinux docker image begin From 0d10c7f3c1111cfff064e7990aa897ac9fd05c82 Mon Sep 17 00:00:00 2001 From: luoyu-intel Date: Wed, 7 Feb 2024 21:04:37 +0000 Subject: [PATCH 054/207] Revert NeuralSpeed code for x64 MatMulNBits (#19382) ### Description Revert PR#19016 https://github.com/microsoft/onnxruntime/pull/19016 Revert PR#17669 https://github.com/microsoft/onnxruntime/pull/17669 --- cgmanifests/generated/cgmanifest.json | 10 - cmake/CMakeLists.txt | 12 - cmake/deps.txt | 1 - cmake/external/neural_speed.cmake | 15 - cmake/onnxruntime_providers_cpu.cmake | 15 - .../cpu/quantization/matmul_nbits.cc | 144 ------ .../cpu/quantization/neural_speed_defs.h | 45 -- .../cpu/quantization/neural_speed_gemm.cc | 438 ------------------ .../cpu/quantization/neural_speed_gemm.h | 129 ------ .../cpu/quantization/neural_speed_wrapper.h | 39 -- .../test/contrib_ops/matmul_4bits_test.cc | 175 ------- 11 files changed, 1023 deletions(-) delete mode 100644 cmake/external/neural_speed.cmake delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h delete mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index efd901787fdb7..fc4ea25603152 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -202,16 +202,6 @@ "comments": "mp11" } }, - { - "component": { - "type": "git", - "git": { - "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", - "repositoryUrl": "https://github.com/intel/neural-speed.git" - }, - "comments": "neural_speed" - } - }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 90fe8276ea9c7..0ccd874cee3c9 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,7 +88,6 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) -option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -902,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() - if(USE_NEURAL_SPEED) - target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) - endif() - set_target_properties(${target_name} PROPERTIES 
COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span @@ -1193,13 +1188,6 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() -if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) - include(neural_speed) - if (USE_NEURAL_SPEED) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) - endif() -endif() - # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) diff --git a/cmake/deps.txt b/cmake/deps.txt index cb431f8c77397..17c3cbf9a6c43 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -35,7 +35,6 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake deleted file mode 100644 index ed711351403a7..0000000000000 --- a/cmake/external/neural_speed.cmake +++ /dev/null @@ -1,15 +0,0 @@ -if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") - set(USE_NEURAL_SPEED TRUE) -elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") - set(USE_NEURAL_SPEED TRUE) -endif() - -if(USE_NEURAL_SPEED) - FetchContent_Declare( - neural_speed - URL ${DEP_URL_neural_speed} - URL_HASH SHA1=${DEP_SHA1_neural_speed} - ) - set(BTLA_USE_OPENMP OFF) - onnxruntime_fetchcontent_makeavailable(neural_speed) -endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index b81a5c79ac0cc..f60faa4d39116 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,15 +60,6 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() - set(onnxruntime_cpu_neural_speed_srcs - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" - "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" - ) - if(NOT USE_NEURAL_SPEED) - list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) - endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -153,12 +144,6 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() -if(NOT 
onnxruntime_DISABLE_CONTRIB_OPS) - if(USE_NEURAL_SPEED) - onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) - endif() -endif() - if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 166f5c8f52f54..e8d8bbca66fe7 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -10,10 +10,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" -#ifdef ORT_NEURAL_SPEED -#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" -#endif - namespace onnxruntime { namespace contrib { @@ -23,16 +19,6 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level static_cast(CompMostAccurate), static_cast(CompLeastAccurate)); -#if defined(ORT_NEURAL_SPEED) - - ORT_UNUSED_PARAMETER(nbits); - ORT_UNUSED_PARAMETER(block_size); - - // Neural Speed APIs already expect a minimum accuracy level so just use the given value. - return accuracy_level; - -#else // defined(ORT_NEURAL_SPEED) - // Find a supported accuracy level that is not less accurate than the one given. // CompMostAccurate is always supported with the fallback implementation. // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. @@ -45,8 +31,6 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } return effective_accuracy_level; - -#endif // defined(ORT_NEURAL_SPEED) } } // namespace @@ -61,17 +45,6 @@ class MatMulNBits final : public OpKernel { accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); -#ifdef ORT_NEURAL_SPEED - const Tensor* tensor_B = nullptr; - const Tensor* tensor_scale = nullptr; - const Tensor* tensor_zero_point = nullptr; - bool B_constant = info.TryGetConstantInput(1, &tensor_B); - bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); - bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); - is_asym_ = info.GetInputCount() >= 4; - all_constant_ = B_constant && scale_constant; - all_constant_ = is_asym_ ? 
all_constant_ && zero_point_constant : all_constant_; -#endif } Status Compute(OpKernelContext* context) const override; @@ -92,13 +65,6 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; - -#if defined(ORT_NEURAL_SPEED) - - bool is_asym_{false}; - bool all_constant_{false}; - -#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, @@ -106,54 +72,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; -#if defined(ORT_NEURAL_SPEED) - - if (!all_constant_) { - return Status::OK(); - } - MLAS_THREADPOOL* pool = NULL; - if (nbits_ != 4) { - return Status::OK(); - } - auto comp_type = static_cast(accuracy_level_); - auto nbits = static_cast(nbits_); - if (input_idx == 1) { - packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); - if (packed_b_size_ == 0) return Status::OK(); - auto qptr = tensor.Data(); - packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - std::memset(packed_b_.get(), 0, packed_b_size_); - NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, - comp_type, pool); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } - is_packed = true; - } - if (input_idx == 2 && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, - comp_type, pool); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } - is_packed = true; - } - if (input_idx == 3 && packed_b_ != nullptr) { - auto zptr = tensor.Data(); - NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, - comp_type, pool); - if (prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size_); - } - is_packed = true; - } - -#else // defined(ORT_NEURAL_SPEED) - if (input_idx == 1) { const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { @@ -173,8 +91,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } -#endif // defined(ORT_NEURAL_SPEED) - return Status::OK(); } @@ -182,31 +98,11 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; -#if defined(ORT_NEURAL_SPEED) - - // Pack three tensors into one buffer - if (input_idx == 1) { - used_shared_buffers = true; - packed_b_ = std::move(prepacked_buffers[0]); - } - if (input_idx == 2) { - used_shared_buffers = true; - packed_b_ = std::move(prepacked_buffers[0]); - } - if (input_idx == 3) { - used_shared_buffers = true; - packed_b_ = std::move(prepacked_buffers[0]); - } - -#else // defined(ORT_NEURAL_SPEED) - if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } -#endif // defined(ORT_NEURAL_SPEED) - return Status::OK(); } @@ -216,46 +112,6 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data 
= a->Data(); -#if defined(ORT_NEURAL_SPEED) - - if (packed_b_) { - TensorShape b_shape({static_cast(N_), static_cast(K_)}); - - MatMulComputeHelper helper; - ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); - - Tensor* y = ctx->Output(0, helper.OutputShape()); - - // Bail out early if the output is going to be empty - if (y->Shape().Size() == 0) return Status::OK(); - - auto* y_data = y->MutableData(); - - const size_t max_len = helper.OutputOffsets().size(); - const size_t M = static_cast(helper.M()); - const size_t N = static_cast(helper.N()); - const size_t K = static_cast(helper.K()); - const size_t lda = helper.Lda(false); - std::vector gemm_params(max_len); - AllocatorPtr allocator; - auto status = ctx->GetTempSpaceAllocator(&allocator); - ORT_RETURN_IF_ERROR(status); - for (size_t i = 0; i < max_len; i++) { - gemm_params[i].A = a_data + helper.LeftOffsets()[i]; - gemm_params[i].lda = lda; - gemm_params[i].B = packed_b_.get(); - gemm_params[i].C = y_data + helper.OutputOffsets()[i]; - gemm_params[i].ldc = N; - } - auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); - // workspace for activation process(dynamic quantization and others) - auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); - NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); - return Status::OK(); - } - -#endif // defined(ORT_NEURAL_SPEED) - const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const auto* scales_data = scales->Data(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h deleted file mode 100644 index 864abffd131fe..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h +++ /dev/null @@ -1,45 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. 
- ---*/ - -#pragma once - -#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" - -namespace bestla { - -using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; -using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; -using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; -using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; -using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; -using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; -using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; -using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; -using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; -using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; -using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; -using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; - -template -using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; -template -using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; - -class ORTThreading : public parallel::IThreading { - public: - explicit ORTThreading(void* tp); - void parallel_for(const parallel::thread_func& func) const override; - void set_threads(int nthreads) override { - (void)(nthreads); - assert(0); - } - void sync() const override { assert(0); } - void* mTp; -}; - -} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc deleted file mode 100644 index 73aaa4ae61a6e..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc +++ /dev/null @@ -1,438 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - neural_speed_gemm.cpp - -Abstract: - - GEMM template combinations of neural_speed. ---*/ - -#include "contrib_ops/cpu/quantization/neural_speed_defs.h" -#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" -#include "core/platform/threadpool.h" - -using ThreadPool = onnxruntime::concurrency::ThreadPool; - -namespace bestla { - -ORTThreading::ORTThreading(void* tp) - : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} - -void ORTThreading::parallel_for(const parallel::thread_func& func) const { - ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, - [&](ptrdiff_t tid) { func(static_cast(tid)); }); -} - -template -static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, - parallel::IThreading* th) { - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); - if (M <= 16) { - using Parallel = parallel::gemm::SchedulerKBlock; - using Launcher = - wrapper::gemm::LauncherKBlock; - static Launcher kernel; - auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); - if (B->IsAsym()) { - reduceA.assign(WorkSpace); - ORTThreading single(nullptr); - kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); - } - typename Launcher::Param args{gp, - {A, lda_, &reduceA}, - {B}, - {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), - reduceA.template RPtr(), reduceA.lda}, - {C, ldc_, nullptr}}; - parallel::GemmRun(kernel, args, th); - } else { - using Parallel = parallel::gemm::SchedulerBase; - using Launcher = - wrapper::gemm::LauncherBase; - static Launcher kernel; - typename Launcher::Param 
args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; - parallel::GemmRun(kernel, args, th); - } -} - -template -static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, - parallel::IThreading* th) { - using Parallel = parallel::gemm::SchedulerKBlockS; - using Launcher = - wrapper::gemm::LauncherIntKBlock; - auto M_ = static_cast(M); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto lda_ = static_cast(lda); - auto ldc_ = static_cast(ldc); - static Launcher kernel; - auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); - quanA.assign(WorkSpace); - if (M <= 16) { - ORTThreading single(nullptr); - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); - } else { - kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); - } - utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); - typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; - parallel::GemmRun(kernel, args, th); -} - -template -static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { - auto M_ = static_cast(M); - auto K_ = static_cast(K); - (void)(A); - (void)(N); - (void)(C); - (void)(lda); - (void)(ldc); - if (M <= 16) { - using ProA = prologue_a::gemm::ActivationKBlockBaseF32; - static ProA proA; - if (B->IsAsym()) { - auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); - return reduceA.mSize; - } - return 0; - } else { - // using ProA = prologue_a::gemm::ActivationBase; - return 0; - } -} - -template -static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, - storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { - (void)(N); - (void)(lda); - (void)(ldc); - (void)(A); - (void)(C); - using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; - static ProA proA; - auto quanA = - proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); - return quanA.mSize; -} - -} // namespace bestla - -using namespace bestla; - -static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, - void* ThreadPool) { - GetCPUDevice(); - bestla::ORTThreading orth(ThreadPool); - bool processed = true; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto kptr = reinterpret_cast(ptr); - auto BlkSize = kptr->mBlockSize; - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); - } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, - 
DataParams[i].ldc, WorkSpace, &orth); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && - BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, - &orth); - } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && - BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, - &orth); - } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && - BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); - } - } - } - } else { - processed = false; - break; - } - } - return processed; -} - -static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { - GetCPUDevice(); - size_t size = 0; - for (size_t i = 0; i < BatchN; i++) { - auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); - auto uptr = std::unique_ptr(ptr); - if (ptr) { - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto kptr = reinterpret_cast(ptr); - auto NTile = - gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - auto BlkSize = kptr->mBlockSize; - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc), - size); - } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, - DataParams[i].C, DataParams[i].ldc), - size); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - size = std::max(NSSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), - size); - } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && - BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - size = std::max(NSSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), - size); - } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - size = std::max(NSSQ4GemmCompInt8WorkspaceSize( - M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), - size); - } - } - } - } - } - return size; -} - -template -static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { - static T proB; - auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), - BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); - // TODO(Yu) support more scale dtype - return stor.mSize; -} - -static bool NSQ4GemmUnPackB(float* FpData, const void* 
PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { - auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); - auto uptr = std::unique_ptr(ptr); - ORTThreading orth(ThreadPool); - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto ldb_ = static_cast(ldb); - GetCPUDevice(); - if (ptr) { - auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); - auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); - auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); - auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); - if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { - auto wptr = reinterpret_cast(ptr); - auto BlkSize = wptr->mBlockSize; - if (btype == gemm::CompType::tFP32 && PackRow == 1) { - if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } - } - if (btype == gemm::CompType::tS8 && PackRow == 4) { - if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && - BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - static tWeiNInt proB; - proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); - } - } - } - return true; - } - return false; -} - -template -static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, - const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, - void* ThreadPool) { - static T proB; - auto N_ = static_cast(N); - auto K_ = static_cast(K); - auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, - BTLA_DTYPE::BF16, IsAsym); - stor.assign(reinterpret_cast(PackedBuf)); - ORTThreading orth(ThreadPool); - proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); - if (lastCall) { - proB.reduceWeight(&stor, &orth); - } -} - -static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { - GetCPUDevice(); - if (K % BlkSize != 0) { - return 0; - } - // from low precision to high precision - switch (CompType) { - case NSCompInt8: - if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
- if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - } - [[fallthrough]]; - case NSCompBf16: - case NSCompFp16: - case NSCompFp32: - case NSCompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - return NSQ4BuSize>(BlkSize, N, K, isAsym); - } - [[fallthrough]]; - default: - return 0; - } -} - -static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, - size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, - NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { - GetCPUDevice(); - // explicit statement fall through. - switch (CompType) { - case NSCompInt8: - if (!isAsym) { // asym int8 is not optimized, so fall through to others. - if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { - NSQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); - return true; - } - if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { - NSQ4GemmPackBImpl>( - PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); - return true; - } - if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { - NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, - K, isAsym, lastCall, ldb, ThreadPool); - return true; - } - } - [[fallthrough]]; - case NSCompBf16: - case NSCompFp16: - case NSCompFp32: - case NSCompUndef: - if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { - NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, - lastCall, ldb, ThreadPool); - return true; - } - if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { - NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, - ldb, ThreadPool); - return true; - } - [[fallthrough]]; - default: - return false; - } -} - -size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, - NS_SQNBIT_COMPUTE_TYPE CompType) { - if (nbits == 4) { - auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); - if (jsize) { - return jsize; - } - } - return 0; -} - -void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, - size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, - NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { - if (nbits == 4) { - if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { - return; - } - } -} - -void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { - // only nbits=4 can be packed, so not necessary to check the nbits in DataParams - if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { - return; - } -} - -size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { - // only nbits=4 can be packed, so not necessary to check the nbits in DataParams - return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); -} - -void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, 
const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, - void* ThreadPool) { - // only nbits=4 can be packed, so not necessary to check the nbits in DataParams - if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { - // PackedWeight is created by bestla - return; - } -} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h deleted file mode 100644 index ebcb3027a209f..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h +++ /dev/null @@ -1,129 +0,0 @@ -/*++ - -Copyright (c) Microsoft Corporation. All rights reserved. - -Licensed under the MIT License. - -Module Name: - - neural_speed_gemm.h - -Abstract: - - Prepack-weight GEMM APIs of neural_speed. ---*/ - -#pragma once - -#include -#include - -/** - * @brief Define compute types of block quantization - */ -enum NS_SQNBIT_COMPUTE_TYPE { - NSCompUndef = 0, /*!< undef */ - NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ - NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ - NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ - NSCompInt8 = 4 /*!< input int8, accumulator int32 */ -}; - -/** - * @brief Data parameters for NBits GEMM routine - * C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * All except C are [in] parameters - */ -struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { - const float* A = nullptr; /**< address of A (float32 matrix)*/ - const void* B = nullptr; /**< address of B (packed nbits blob)*/ - float* C = nullptr; /**< address of result matrix */ - size_t lda = 0; /**< leading dimension of A */ - size_t ldc = 0; /**< leading dimension of C*/ -}; - -/** - * @brief Compute the byte size of the parameter combination - * - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @return size of the packing buffer, 0 if the operation is not yet supported. - */ -size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, - NS_SQNBIT_COMPUTE_TYPE comp_type); - -/** - * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. - * - * @param PackedBuf packed data buffer - * @param QData quantized data buffer - * @param Scale scale pointer - * @param Zp zero point pointer - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param block_size size of the block to quantize, elements from the same block share the same - * scale and zero point - * @param nbits number of bits used for weight quantization (default 4) - * @param is_asym flag for asymmetric quantization - * @param comp_type specify input data type and accumulator data type - * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor - * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where - * they can share the common attributes like: block_size. 
Meanwhile, kernel has some pre-computations to speed up - * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale - * (is_asym is false) and Zp(is_asym is true). - * @param thread_pool - */ -void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, - size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, - NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); - -/** - * @brief Unpack and dequantize to fp32 - * - * @param FpData unpacked float32 data - * @param PackedBuf quantized and packed data - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B - * @param thread_pool - */ -void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); - -/** - * @brief Get the workspace size required by computation. - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @return Workspace size in bytes - */ -size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); - -/** - * @brief Batched GEMM: C = A * B - * A, C must be a float32 matrix - * B must be a packed nbits blob - * - * @param[in] M row size of matrix A and C - * @param[in] N column size of matrix B and C - * @param[in] K column size of matrix A and row size of matrix B - * @param[in] BatchN number of batches - * @param[inout] DataParams An array (size BatchN) of parameter blocks - * @param[in] WorkSpace temporary buffer - * @param[in] ThreadPool - * @return - */ -void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, - const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, - void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h deleted file mode 100644 index d3902f9bd68c7..0000000000000 --- a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h +++ /dev/null @@ -1,39 +0,0 @@ -//----------------------------------------------------------------------------- -// -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
-// -//----------------------------------------------------------------------------- -#pragma once -#if defined(__GNUC__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wunused-parameter" -#pragma GCC diagnostic ignored "-Wsign-compare" -#pragma GCC diagnostic ignored "-Wmissing-field-initializers" -#pragma GCC diagnostic ignored "-Wunused-variable" -#pragma GCC diagnostic ignored "-Wunused-value" -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" -#pragma GCC diagnostic ignored "-Wunused-function" -#pragma GCC diagnostic ignored "-Wuninitialized" -#pragma GCC diagnostic ignored "-Wclass-memaccess" -#pragma GCC diagnostic ignored "-Wunused-but-set-variable" -#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" - -#elif defined(_MSC_VER) -#pragma warning(push) -#pragma warning(disable : 4457) -#pragma warning(disable : 4189) -#pragma warning(disable : 4100) -#pragma warning(disable : 4244) -#pragma warning(disable : 4267) -#pragma warning(disable : 4702) -#endif - -#include "bestla/bestla_prologue_a.h" -#include "bestla/bestla_wrapper.h" - -#if defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..d22da2a3da87f 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -149,17 +149,10 @@ TEST(MatMulNBits, Float32) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { -#ifdef ORT_NEURAL_SPEED - for (auto accuracy_level : {0, 1, 4}) { - RunTest(M, N, K, block_size, accuracy_level, false, false); - RunTest(M, N, K, block_size, accuracy_level, true, false); - } -#else for (auto accuracy_level : {0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); } -#endif } } } @@ -192,174 +185,6 @@ TEST(MatMulNBits, Float16Large) { #endif -void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_size, bool is_asym, - int64_t acc_lvl) { - // (M x K) X (K x N) - - OpTester test("MatMulNBits", 1, kMSDomain); - test.AddAttribute("accuracy_level", acc_lvl); - test.AddAttribute("block_size", int64_t(block_size)); - test.AddAttribute("bits", QBits); - test.AddAttribute("N", N); - test.AddAttribute("K", K); - - std::vector input0_vals(M * K); - float fv = -135.f; - for (auto& f : input0_vals) { - f = fv / 127; - fv++; - if (fv > 135.f) { - fv = -135.f; - } - } - - size_t kblks = K / block_size; - std::vector input1_vals(N * K / 2); - for (size_t i = 0; i < input1_vals.size(); i++) { - input1_vals[i] = uint8_t(i); - } - std::vector input2_vals(N * kblks, 0.002f); - for (size_t i = 0; i < N * kblks; i++) { - input2_vals[i] += (i % 100) * 0.00003f; - } - std::vector input3_vals(N * kblks / 2, static_cast(0x88)); - - std::vector input1_f_vals(N * K); - if (is_asym) { - for (size_t i = 0; i < N * kblks; i += 2) { - input3_vals[i / 2] = static_cast(i + 1); - } - for (int64_t i = 0; i < K; i += 2) { - for (int64_t j = 0; j < N; j++) { - auto srcv = input1_vals[j * K / 2 + i / 2]; - auto koff = i % (block_size * 2); - auto zpv = input3_vals[j * kblks / 2 + i / block_size / 2]; - auto zp0 = koff < block_size ? 
(zpv & 0xf) - 8 : ((zpv & 0xf0) >> 4) - 8; - auto src0 = (srcv & 0xf) - 8; - auto src1 = ((srcv & 0xf0) >> 4) - 8; - auto scale0 = input2_vals[j * kblks + i / block_size]; - auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; - input1_f_vals[i * N + j] = (static_cast(src0) - zp0) * scale0; - input1_f_vals[(i + 1) * N + j] = (static_cast(src1) - zp0) * scale1; - } - } - } else { - for (int64_t i = 0; i < K; i += 2) { - for (int64_t j = 0; j < N; j++) { - auto srcv = input1_vals[j * K / 2 + i / 2]; - auto src0 = (srcv & 0xf) - 8; - auto src1 = ((srcv & 0xf0) >> 4) - 8; - auto scale0 = input2_vals[j * kblks + i / block_size]; - auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; - input1_f_vals[i * N + j] = static_cast(src0) * scale0; - input1_f_vals[(i + 1) * N + j] = static_cast(src1) * scale1; - } - } - } - - std::vector expected_vals(M * N); - for (int64_t m = 0; m < M; m++) { - for (int64_t n = 0; n < N; n++) { - float sum = 0.0f; - for (int64_t k = 0; k < K; k++) { - sum += input0_vals[m * K + k] * input1_f_vals[k * N + n]; - } - expected_vals[m * N + n] = sum; - } - } - - test.AddInput("A", {M, K}, input0_vals, false); - - test.AddInput("B", {N, static_cast(kblks), static_cast(block_size / 2)}, input1_vals, - true); - test.AddInput("scales", {N, static_cast(kblks)}, input2_vals, true); - if (is_asym) { - test.AddInput("zero_points", {N, static_cast(kblks / 2)}, input3_vals, true); - } - test.AddOutput("Y", {M, N}, expected_vals, false); - if (acc_lvl == 4) { - test.SetOutputAbsErr("Y", 0.1f); - } - - OrtValue b, scale, zp; - Tensor::InitOrtValue(DataTypeImpl::GetType(), - TensorShape({N, static_cast(kblks), static_cast(block_size / 2)}), - input1_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); - - Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks)}), - input2_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), scale); - if (is_asym) { - Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks / 2)}), - input3_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), zp); - } - SessionOptions so; - // Set up B as a shared initializer to be shared between sessions - ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); - ASSERT_EQ(so.AddInitializer("scales", &scale), Status::OK()); - if (is_asym) { - ASSERT_EQ(so.AddInitializer("zero_points", &zp), Status::OK()); - } - - // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable - test.EnableSharingOfPrePackedWeightsAcrossSessions(); - - // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP - // and we want to ensure that it is available in this build - auto cpu_ep = []() -> std::vector> { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - return execution_providers; - }; - - size_t number_of_pre_packed_weights_counter_session_1 = 0; - size_t number_of_shared_pre_packed_weights_counter = 0; - - // Session 1 - { - auto ep_vec = cpu_ep(); - test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, - &number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); - // Assert that no pre-packed weights have been shared thus far - ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); - } - - auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); - // Assert that the 
number of elements in the shared container - // is the same as the number of weights that have been pre-packed - ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); - - // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements - // that have been pre-packed will be zero in which case we do not continue with the testing - // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. - if (number_of_pre_packed_weights_counter_session_1 == 0) return; - - // Session 2 - { - size_t number_of_pre_packed_weights_counter_session_2 = 0; - auto ep_vec = cpu_ep(); - test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, - &number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); - - // Assert that the same number of weights were pre-packed in both sessions - ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); - - // Assert that the number of pre-packed weights that were shared equals - // the number of pre-packed weights in the second session - ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, - static_cast(number_of_shared_pre_packed_weights_counter)); - } -} - -#ifdef ORT_NEURAL_SPEED -TEST(MatMulNBits, SharedPrepackedWeights) { - RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, true, 1); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, false, 1); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 1); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 4); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 1024, false, 4); - RunSharedPrepackedWeightsTest(2, 4096, 4096, 4096, false, 4); -} -#endif } // namespace test } // namespace onnxruntime
From 19952c5b3521934b8a34b6ef496cf3faa0c4eb59 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Thu, 8 Feb 2024 01:03:06 +0000
Subject: [PATCH 055/207] Add script to convert phi2 to ort-vllm compatible (#19429)

### Description
1. add an option to export ONNX models compatible with ort_vllm. This makes sure that the ONNX model only leverages paged attention from vLLM. It is intended for internal use, so it is not mentioned in the README.
2.
add details in ORT installation(https://github.com/microsoft/onnxruntime/pull/19338#discussion_r1476906190) ### Motivation and Context --------- Co-authored-by: wejoncy --- .../python/tools/symbolic_shape_infer.py | 5 + .../tools/transformers/dynamo_onnx_helper.py | 26 +++- .../tools/transformers/fusion_options.py | 1 + .../tools/transformers/models/phi2/README.md | 1 + .../models/phi2/convert_to_onnx.py | 64 +++++++-- .../tools/transformers/onnx_model_phi.py | 129 ++++++++++++++---- 6 files changed, 183 insertions(+), 43 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index e7b7074783162..66c78f80f7910 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -213,6 +213,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "NhwcConv": self._infer_NhwcConv, "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, + "PagedAttention": self._infer_PagedAttention, "PythonOp": self._infer_PythonOp, "QuantizeLinear": self._infer_QuantizeLinear, "QuickGelu": self._infer_FastGelu, @@ -470,6 +471,7 @@ def _onnx_infer_single_node(self, node): "SkipLayerNormalization", "SkipSimplifiedLayerNormalization", "PackedAttention", + "PagedAttention", "PythonOp", "MultiHeadAttention", "GroupNorm", @@ -2412,6 +2414,9 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_PagedAttention(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + def _infer_GroupQueryAttention(self, node): # noqa: N802 output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type diff --git a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py index bca5ace916082..9a66afe3ad4f9 100644 --- a/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py +++ b/onnxruntime/python/tools/transformers/dynamo_onnx_helper.py @@ -73,20 +73,32 @@ def unroll_function(self, func_name: str) -> None: return self.update_edges(edge_mapping) - def remove_dropout_layer(self) -> None: + def remove_function(self, func_name: str, input_id: int, output_id: int) -> None: """ - Removes the dropout layer in the model. + Removes the function in the model. """ - logging.info("Removing dropout layer...") edge_mapping = {} nodes_to_remove = [] for node in self.model.graph.node: - if node.op_type.find("Dropout") != -1: - assert len(node.input) == 1 - assert len(node.output) == 1 - edge_mapping[node.output[0]] = node.input[0] + if node.op_type.find(func_name) != -1: + edge_mapping[node.input[input_id]] = node.output[output_id] nodes_to_remove.append(node) for node in nodes_to_remove: self.model.graph.node.remove(node) self.update_edges(edge_mapping) + + def remove_dropout_layer(self) -> None: + """ + Removes the dropout layer in the model. + """ + logging.info("Removing dropout layer...") + self.remove_function("Dropout", 0, 0) + + def remove_lm_head_layer(self) -> None: + """ + Removes the LM head layer in the model. 
+ """ + logging.info("Removing LM head layer...") + # bugbug: need to copy the right vi over + self.remove_function("Linear_lm_head", 2, 0) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index c65464a3069c5..4c43e4487bfb1 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -24,6 +24,7 @@ class AttentionOpType(Enum): Attention = "Attention" MultiHeadAttention = "MultiHeadAttention" GroupQueryAttention = "GroupQueryAttention" + PagedAttention = "PagedAttention" def __str__(self): return self.value diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md index 526fdc3dd7863..da62bba0f02fb 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/README.md +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -11,6 +11,7 @@ To export ONNX, PyTorch version 2.2.0 or higher is required. The [official websi **There are two options to run the conversion script:**\ _From source:_ ```bash +# Default onnxruntime package is built with CUDA 11.8. For CUDA 12.x, refer to https://onnxruntime.ai/docs/install/#python-installs pip install onnxruntime-gpu==1.17.0 # or onnxruntime==1.17.0 if using cpu git clone git@github.com:microsoft/onnxruntime.git cd onnxruntime/onnxruntime/python/tools/transformers diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index ac3ca40e41be0..b7881d064067d 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -136,14 +136,18 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): self.precision == Precision.FLOAT16 or self.precision == Precision.INT4 ) and self.attn_op_type != AttentionOpType.MultiHeadAttention: # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. 
- node_block_list = [ - "GroupQueryAttention_29", - "GroupQueryAttention_30", - "GroupQueryAttention_31", - "Attention_29", - "Attention_30", - "Attention_31", - ] + node_block_list = ( + [ + "GroupQueryAttention_29", + "GroupQueryAttention_30", + "GroupQueryAttention_31", + "Attention_29", + "Attention_30", + "Attention_31", + ] + if self.attn_op_type != AttentionOpType.PagedAttention + else [] + ) # TODO: temp setting for paged attention logging.info("Converting onnx model to float16/bfloat16...") optimizer.convert_float_to_float16( keep_io_types=False, @@ -220,6 +224,20 @@ def parse_arguments(): help="Generate int4 ONNX model for Nvidia GPUs with CUDA architecture SM=80~89", ) + parser.add_argument( + "--fp16_vllm", + required=False, + action="store_true", + help="Generate fp16 ONNX model for ORT VLLM", + ) + + parser.add_argument( + "--int4_vllm", + required=False, + action="store_true", + help="Generate int4 ONNX model for ORT VLLM", + ) + parser.add_argument( "--overwrite", required=False, @@ -336,6 +354,16 @@ def main(): Precision.INT4, os.path.join(output_dir, "phi2_decoder_int4_gpu_sm8x.onnx"), ), + "fp16_vllm": ( + AttentionOpType.PagedAttention, + Precision.FLOAT16, + os.path.join(output_dir, "phi2_decoder_fp16_vllm.onnx"), + ), + "int4_vllm": ( + AttentionOpType.PagedAttention, + Precision.INT4, + os.path.join(output_dir, "phi2_decoder_int4_vllm.onnx"), + ), } if not args.skip_export: @@ -403,6 +431,22 @@ def run_optimize_phi2_onnx( ) ) + if args.fp16_vllm: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["fp16_vllm"]), + ) + ) + + if args.int4_vllm: + processes.append( + Process( + target=run_optimize_phi2_onnx, + args=(converter, original_onnx_path, *model_type_to_args["int4_vllm"]), + ) + ) + [p.start() for p in processes] [p.join() for p in processes] @@ -450,8 +494,8 @@ def run_optimize_phi2_onnx( device_id=args.device_id, packed_kv=True, ) - if args.fp32_cpu or args.int4_cpu: - raise NotImplementedError("CPU inference example is not implemented yet.") + if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm: + raise NotImplementedError("CPU/vllm inference example is not implemented yet.") if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py index df8830b0d0495..e68c3120e3f09 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_phi.py +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -255,6 +255,30 @@ def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num ) return [node] + def paged_attn( + self, + inputs: List[str], + outputs: List[str], + prefix: str = "", + num_heads=32, + head_size=80, + scale=0.11180339753627777, + ): + assert len(inputs) == 6 + assert len(outputs) == 1 + node = helper.make_node( + "PagedAttention", + inputs=inputs, + outputs=outputs, + name=prefix + "PagedAttention", + domain="vllm.ort.ext", + num_heads=num_heads, + num_kv_heads=num_heads, + head_size=head_size, + scale=scale, + ) + return [node] + class Phi2PreProcessor(DynamoOnnxHelper): def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): @@ -288,32 +312,46 @@ def simplify_phi2_op_type(self): def process_graph_io(self, attn_op_type: AttentionOpType): self.use_attn = attn_op_type == AttentionOpType.Attention + self.use_vllm = attn_op_type == AttentionOpType.PagedAttention graph = self.model.graph new_inputs = [] for vi in graph.input: if 
"input_ids" in vi.name: vi_iid = helper.make_tensor_value_info( vi.name, - elem_type=TensorProto.INT32, + elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64, shape=["batch_size", "seq_len"], ) - vi_pid = helper.make_tensor_value_info( + vi_step = helper.make_tensor_value_info( "step", elem_type=TensorProto.INT64, shape=[1], ) + vi_pid = helper.make_tensor_value_info( + "position_ids", + elem_type=TensorProto.INT64, + shape=["batch_size", "seq_len"], + ) vi_mask = helper.make_tensor_value_info( "attention_mask", elem_type=TensorProto.INT32, shape=["batch_size", "seq_len"], ) - new_inputs.extend([vi_iid, vi_pid, vi_mask]) - if not self.use_attn: - if "past_key" in vi.name or "past_value" in vi.name: + vi_meta = helper.make_tensor_value_info( + "input_metadata", + elem_type=TensorProto.INT64, + shape=[1], + ) + new_inputs.extend([vi_iid, vi_step, vi_mask]) if not self.use_vllm else new_inputs.extend( + [vi_iid, vi_pid, vi_meta] + ) + if self.use_attn: + if "past_key" in vi.name: vi_cache = helper.make_tensor_value_info( - vi.name, + vi.name.replace("past_key", "past"), elem_type=vi.type.tensor_type.elem_type, shape=[ + 2, "batch_size", self.num_attention_heads, "past_seq_len", @@ -321,13 +359,32 @@ def process_graph_io(self, attn_op_type: AttentionOpType): ], ) new_inputs.extend([vi_cache]) - else: + elif self.use_vllm: if "past_key" in vi.name: vi_cache = helper.make_tensor_value_info( - vi.name.replace("past_key", "past"), + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"], + ) + new_inputs.extend([vi_cache]) + if "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "num_blocks", + "num_heads", + "head_size", + "block_size", + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, elem_type=vi.type.tensor_type.elem_type, shape=[ - 2, "batch_size", self.num_attention_heads, "past_seq_len", @@ -344,19 +401,7 @@ def process_graph_io(self, attn_op_type: AttentionOpType): if i == 0: new_outputs.extend([vi]) else: - if not self.use_attn: - vi_cache = helper.make_tensor_value_info( - vi.name, - elem_type=vi.type.tensor_type.elem_type, - shape=[ - "batch_size", - self.num_attention_heads, - "total_seq_len", - self.hidden_size // self.num_attention_heads, - ], - ) - new_outputs.extend([vi_cache]) - else: + if self.use_attn: if "present_key" in vi.name: vi_cache = helper.make_tensor_value_info( vi.name.replace("present_key", "present"), @@ -370,6 +415,20 @@ def process_graph_io(self, attn_op_type: AttentionOpType): ], ) new_outputs.extend([vi_cache]) + elif self.use_vllm: + pass + else: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + self.num_attention_heads, + "total_seq_len", + self.hidden_size // self.num_attention_heads, + ], + ) + new_outputs.extend([vi_cache]) graph.ClearField("output") graph.output.extend(new_outputs) @@ -385,6 +444,8 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType): self.update_edges(self.phi2_edge_dict) self.simplify_phi2_op_type() self.remove_dropout_layer() + if attn_op_type == AttentionOpType.PagedAttention: + self.remove_lm_head_layer() self.process_graph_io(attn_op_type) @@ -694,7 +755,9 @@ def fuse( layer_known_edges_names.extend( [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, 
mlp_fc2_weight, mlp_fc2_bias] ) - layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"]) + layer_known_edges_names.extend( + ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"] + ) subgraph_nodes = [] subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"])) @@ -708,8 +771,9 @@ def fuse( subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) - subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_")) - subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_")) + pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step" + subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_")) + subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_")) if self.attn_op_type == AttentionOpType.MultiHeadAttention: subgraph_nodes.extend( self.mha( @@ -740,6 +804,13 @@ def fuse( self.model.add_initializer( numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name ) + elif self.attn_op_type == AttentionOpType.PagedAttention: + subgraph_nodes.extend( + self.paged_attn( + ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"], + ["attn_out"], + ) + ) else: past_name = f"past_{layer_id}" present_name = f"present_{layer_id}" @@ -798,6 +869,7 @@ def get_fused_operator_statistics(self): "Attention", "MultiHeadAttention", "GroupQueryAttention", + "PagedAttention", "Gelu", "BiasGelu", "FastGelu", @@ -821,7 +893,12 @@ def is_fully_optimized(self, fused_op_count=None): def op_count(op_name: str): return fused_op_count.get(op_name) or 0 - attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("GroupQueryAttention") + attention = ( + op_count("Attention") + + op_count("MultiHeadAttention") + + op_count("GroupQueryAttention") + + op_count("PagedAttention") + ) gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu") layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization") From 3b1b18347ceb74dec8878fcf689e298ad9bc9d95 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 9 Feb 2024 03:08:41 +1000 Subject: [PATCH 056/207] Check for invalid combination of python + minimal build in build.py (#19463) ### Description Python bindings aren't supported in a minimal build. Check in build.py so user gets a better error message. 
### Motivation and Context
#19422
--- tools/ci_build/build.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index b2040b24ffaa2..8567d595b7429 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2536,11 +2536,15 @@ def main(): if args.build_nuget and cross_compiling: raise BuildError("Currently nuget package creation is not supported while cross-compiling") - if args.enable_pybind and args.disable_rtti: - raise BuildError("Python bindings use typeid so you can't disable RTTI") + if args.enable_pybind: + if args.disable_rtti: + raise BuildError("Python bindings use typeid so you can't disable RTTI") - if args.enable_pybind and args.disable_exceptions: - raise BuildError("Python bindings require exceptions to be enabled.") + if args.disable_exceptions: + raise BuildError("Python bindings require exceptions to be enabled.") + + if args.minimal_build is not None: + raise BuildError("Python bindings are not supported in a minimal build.") if args.nnapi_min_api: if not args.use_nnapi:
From 148f54c6ea8b90dfc1dbf82197ac79c2102054b4 Mon Sep 17 00:00:00 2001
From: ivberg
Date: Thu, 8 Feb 2024 11:28:05 -0800
Subject: [PATCH 057/207] Add capturestate / rundown ETW support logging for session and provider options (#19397)

### Description
Add capturestate / rundown ETW support logging for session and provider options.

### Motivation and Context
Follow-up to #16259 and #18882
This is very useful when you have longer-running ONNX sessions, which will be the case for a lot of AI workloads. That means ETW tracing may start minutes or hours after a process and session have been established. When a trace is captured, you want to know the state of ONNX at that time. For ONNX, that state is the session and provider options, so this change logs them so that they show up in the trace.
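The sketch below is illustrative only and is not part of this change; it shows how a component could use the `WindowsTelemetry::RegisterInternalCallback` hook added in this patch to re-emit its state when an ETW consumer requests a capture-state rundown. `RegisterCaptureStateRundown` and `LogCurrentConfiguration` are hypothetical names, and the keyword check mirrors the one used in the patch.

```cpp
// Illustrative sketch (not part of this patch): re-log component state when an
// ETW consumer requests a capture-state (rundown) for the Session keyword.
#include <Windows.h>
#include <evntrace.h>  // EVENT_CONTROL_CODE_CAPTURE_STATE

#include "core/common/logging/logging.h"      // ORTTraceLoggingKeyword
#include "core/platform/windows/telemetry.h"  // WindowsTelemetry

void LogCurrentConfiguration();  // hypothetical: emits session/provider options

void RegisterCaptureStateRundown() {
  onnxruntime::WindowsTelemetry::RegisterInternalCallback(
      [](LPCGUID /*SourceId*/, ULONG IsEnabled, UCHAR /*Level*/,
         ULONGLONG MatchAnyKeyword, ULONGLONG /*MatchAllKeyword*/,
         PEVENT_FILTER_DESCRIPTOR /*FilterData*/, PVOID /*CallbackContext*/) {
        // IsEnabled carries the ETW control code; capture state is a rundown request.
        const bool capture_state = (IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE);
        const bool session_keyword =
            (MatchAnyKeyword &
             static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0;
        if (capture_state && session_keyword) {
          LogCurrentConfiguration();  // rundown: put current state into the new trace
        }
      });
}
```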
Tested with xperf and ORT xperf -start ort -on 3a26b1ff-7484-7484-7484-15261f42614d xperf -capturestate ort 3a26b1ff-7484-7484-7484-15261f42614d <--- Run this after session has been up for some time xperf -stop ort -d .\ort.etl <- Trace will now also have rundown events Also these will show if you use WPR [CaptureStateOnSave ](https://learn.microsoft.com/en-us/windows-hardware/test/wpt/capturestateonsave) --- .../core/framework/execution_providers.h | 55 ++++++++++++-- .../core/platform/windows/telemetry.cc | 24 +++++-- onnxruntime/core/platform/windows/telemetry.h | 15 +++- onnxruntime/core/session/inference_session.cc | 72 ++++++++++++++++--- onnxruntime/core/session/inference_session.h | 14 +++- .../core/session/provider_registration.cc | 4 ++ .../logging/HowToValidateEtwSinkOutput.md | 6 +- 7 files changed, 164 insertions(+), 26 deletions(-) diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 61147e4367876..dc45cad692b6e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -14,7 +13,9 @@ #include "core/common/logging/logging.h" #ifdef _WIN32 #include +#include #include "core/platform/tracing.h" +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -44,6 +45,49 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, @@ -52,14 +96,11 @@ class ExecutionProviders { TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/platform/windows/telemetry.cc 
b/onnxruntime/core/platform/windows/telemetry.cc index a9849873fd060..654281d526e4d 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/platform/windows/telemetry.h" +#include "core/platform/ort_mutex.h" #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; +OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { std::lock_guard lock(mutex_); @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const { // return etw_status_; // } +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + callbacks_.push_back(callback); +} + void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - (void)SourceId; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock(callbacks_mutex_); + for (const auto& callback : callbacks_) { + callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } void WindowsTelemetry::EnableTelemetryEvents() const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index c3798943d491d..cdb186e9ed703 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,12 +2,14 @@ // Licensed under the MIT License. 
#pragma once +#include +#include + #include "core/platform/telemetry.h" #include #include #include "core/platform/ort_mutex.h" #include "core/platform/windows/TraceLoggingConfig.h" -#include namespace onnxruntime { @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; + static std::vector callbacks_; + static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + static void NTAPI ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cae714954f72f..b045f30a59797 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -46,10 +46,11 @@ #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/platform/Barrier.h" -#include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" #ifdef _WIN32 #include "core/platform/tracing.h" +#include +#include "core/platform/windows/telemetry.h" #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling( } // namespace std::atomic InferenceSession::global_session_id_{1}; +std::map InferenceSession::active_sessions_; +#ifdef _WIN32 +OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +#endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, const ONNX_NAMESPACE::ModelProto& model_proto, @@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options, void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); - // a monotonically increasing session id for use in telemetry - session_id_ = global_session_id_.fetch_add(1); ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. 
Error Message: ", status.ErrorMessage()); + // a monotonically increasing session id for use in telemetry + session_id_ = global_session_id_.fetch_add(1); + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[global_session_id_++] = this; + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + LogAllSessions(); + } + }); +#endif + SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. - TraceSessionOptions(session_options); + TraceSessionOptions(session_options, false); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { + (void)captureState; // Otherwise Linux build error + LOGS(*session_logger_, INFO) << session_options; #ifdef _WIN32 @@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.graph_optimization_level), "graph_optimization_level"), TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"), TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"), - TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute")); + TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"), + TraceLoggingBoolean(captureState, "isCaptureState")); TraceLoggingWrite( telemetry_provider_handle, @@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"), TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"), TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? 
session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"), - TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero")); + TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"), + TraceLoggingBoolean(captureState, "isCaptureState")); for (const auto& config_pair : session_options.config_options.configurations) { TraceLoggingWrite( @@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBoolean(captureState, "isCaptureState")); } #endif } @@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() { } } + // Unregister the session +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); +#endif + active_sessions_.erase(global_session_id_); + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); @@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() { return binding_.get(); } +#ifdef _WIN32 +void InferenceSession::LogAllSessions() { + std::lock_guard lock(active_sessions_mutex_); + for (const auto& session_pair : active_sessions_) { + InferenceSession* session = session_pair.second; + TraceSessionOptions(session->session_options_, true); + } +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 96db49aabdaf6..f8211bfd2dd4e 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -21,11 +22,12 @@ #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" #include "core/framework/framework_provider_common.h" +#include "core/framework/session_options.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" -#include "core/framework/session_options.h" +#include "core/platform/ort_mutex.h" #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif @@ -119,6 +121,10 @@ class InferenceSession { }; using InputOutputDefMetaMap = InlinedHashMap; + static std::map active_sessions_; +#ifdef _WIN32 + static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ +#endif public: #if !defined(ORT_MINIMAL_BUILD) @@ -642,7 +648,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options); + void TraceSessionOptions(const SessionOptions& session_options, bool captureState); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -679,6 +685,10 @@ class InferenceSession { */ void ShrinkMemoryArenas(gsl::span arenas_to_shrink); +#ifdef _WIN32 + void LogAllSessions(); +#endif + #if !defined(ORT_MINIMAL_BUILD) virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, diff --git 
a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ade1d96d617fb..17a955ba8ce1a 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; + for (const auto& config_pair : provider_options) { + ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str())); + } + if (strcmp(provider_name, "DML") == 0) { #if defined(USE_DML) options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); diff --git a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md index 59fe946b929f2..309b474c016c9 100644 --- a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md +++ b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md @@ -3,13 +3,13 @@ The ETW Sink (ONNXRuntimeTraceLoggingProvider) allows ONNX semi-structured printf style logs to be output via ETW. ETW makes it easy and useful to only enable and listen for events with great performance, and when you need them instead of only at compile time. -Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](docs/FAQ.md?plain=1#L7). +Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](/docs/FAQ.md?plain=1#L7). However, when the provider is enabled a new ETW logger sink will also be added and the severity separately controlled via ETW dynamically. - Provider GUID: 929DD115-1ECB-4CB5-B060-EBD4983C421D -- Keyword: Logs (0x2) keyword per [logging.h](include\onnxruntime\core\common\logging\logging.h) -- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](onnxruntime\core\platform\windows\logging\etw_sink.cc) to [ONNX severity](include\onnxruntime\core\common\logging\severity.h) in an intuitive manner +- Keyword: Logs (0x2) keyword per [logging.h](/include/onnxruntime/core/common/logging/logging.h) +- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](/onnxruntime/core/platform/windows/logging/etw_sink.cc) to [ONNX severity](/include/onnxruntime/core/common/logging/severity.h) in an intuitive manner Notes: - The ETW provider must be enabled prior to session creation, as that as when internal logging setup is complete From 03be65e064c6cc2ae5a2169b13dcd675c1dc7cf8 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Thu, 8 Feb 2024 15:56:48 -0800 Subject: [PATCH 058/207] [js/web] fix types exports in package.json (#19458) ### Description Since TypeScript v4.7, types need to specify inside "exports" field when it is available. This PR appends types just before each "default" (which is required by spec to be the last item). Fixes #19403. 
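For illustration, a minimal sketch of the resulting conditional `exports` layout (the module paths below are placeholders, not the actual ort bundles): each condition block lists `"types"` ahead of its `"default"` entry, since `"default"` is required to stay last.

```json
{
  "name": "example-package",
  "types": "./types.d.ts",
  "exports": {
    ".": {
      "import": "./dist/example.mjs",
      "require": "./dist/example.cjs",
      "types": "./types.d.ts",
      "default": "./dist/example.js"
    }
  }
}
```

With TypeScript 4.7+ module resolution modes that honor `exports`, declaration lookup goes through these conditions, so without a `"types"` condition consumers may fail to find the package's declarations even though the top-level `types` field is present.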
--- js/web/package.json | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/js/web/package.json b/js/web/package.json index a502c2b6b032d..55c3a3238bafc 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -69,11 +69,14 @@ "exports": { ".": { "node": "./dist/ort.node.min.js", + "types": "./types.d.ts", "default": { "import": "./dist/esm/ort.min.js", "require": "./dist/cjs/ort.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.js", + "types": "./types.d.ts", "default": "./dist/ort.min.js" } } @@ -81,34 +84,41 @@ "./experimental": { "import": "./dist/esm/ort.all.min.js", "require": "./dist/cjs/ort.all.min.js", + "types": "./types.d.ts", "default": { "development": "./dist/ort.all.js", + "types": "./types.d.ts", "default": "./dist/ort.all.min.js" } }, "./wasm": { "import": "./dist/esm/ort.wasm.min.js", "require": "./dist/cjs/ort.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm.min.js" }, "./wasm-core": { "import": "./dist/esm/ort.wasm-core.min.js", "require": "./dist/cjs/ort.wasm-core.min.js", + "types": "./types.d.ts", "default": "./dist/ort.wasm-core.min.js" }, "./webgl": { "import": "./dist/esm/ort.webgl.min.js", "require": "./dist/cjs/ort.webgl.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgl.min.js" }, "./webgpu": { "import": "./dist/esm/ort.webgpu.min.js", "require": "./dist/cjs/ort.webgpu.min.js", + "types": "./types.d.ts", "default": "./dist/ort.webgpu.min.js" }, "./training": { "import": "./dist/esm/ort.training.wasm.min.js", "require": "./dist/cjs/ort.training.wasm.min.js", + "types": "./types.d.ts", "default": "./dist/ort.training.wasm.min.js" } }, From 3d2ddf96e314ad97ee4377a730c465ffff4c8723 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 8 Feb 2024 16:08:27 -0800 Subject: [PATCH 059/207] Bump ruff linter to 0.2.1 (#19471) ### Motivation and Context Include new lint rules --- docs/python/conf.py | 8 ++--- .../onnxruntime_inference_collection.py | 2 +- .../python/tools/qnn/add_trans_cast.py | 8 ++--- .../python/tools/quantization/calibrate.py | 26 +++++++-------- .../tools/quantization/fusions/fusion.py | 4 +-- .../tools/quantization/onnx_quantizer.py | 10 ++---- .../tools/quantization/qdq_quantizer.py | 22 +++---------- .../python/tools/symbolic_shape_infer.py | 32 +++++++++++-------- .../python/tools/tensorrt/perf/benchmark.py | 16 ++++------ .../python/tools/transformers/benchmark.py | 8 ++--- .../tools/transformers/benchmark_helper.py | 6 +--- .../tools/transformers/bert_test_data.py | 4 +-- .../transformers/fusion_skip_group_norm.py | 2 +- .../transformers/models/bert/eval_squad.py | 2 +- .../transformers/models/gpt2/gpt2_parity.py | 4 +-- .../transformers/models/gpt2/gpt2_tester.py | 4 +-- .../models/longformer/benchmark_longformer.py | 4 +-- .../models/stable_diffusion/engine_builder.py | 2 +- .../models/whisper/whisper_helper.py | 8 ++--- .../tools/transformers/onnx_exporter.py | 9 ++---- .../python/tools/transformers/onnx_model.py | 6 ++-- .../python/tools/transformers/profiler.py | 8 ++--- onnxruntime/test/providers/cpu/rnn/GRU.py | 4 +-- onnxruntime/test/providers/cpu/rnn/LSTM.py | 16 +++++----- .../adamw_test/adamw_test_data_generator.py | 8 ++--- .../sgd_test/sgd_test_data_generator.py | 4 +-- .../python/training/ort_triton/_common.py | 2 +- .../_custom_autograd_function_exporter.py | 4 +-- .../ortmodule/_custom_gradient_registry.py | 2 +- .../_hierarchical_ortmodule.py | 4 +-- .../orttraining_test_model_transform.py | 4 +-- .../python/orttraining_test_ortmodule_api.py | 2 +- 
...rttraining_test_ortmodule_autograd_dist.py | 2 -- .../tools/scripts/gpt2_model_transform.py | 22 ++++--------- orttraining/tools/scripts/model_transform.py | 18 ++++------- requirements-lintrunner.txt | 2 +- .../github/linux/ort_minimal/readelf_utils.py | 4 +-- tools/ci_build/op_registration_validator.py | 2 +- tools/python/fix_long_lines.py | 5 ++- tools/python/gen_opkernel_doc.py | 9 ++---- tools/python/ort_test_dir_utils.py | 5 +-- .../check_model_can_use_ort_mobile_pkg.py | 2 +- 42 files changed, 118 insertions(+), 198 deletions(-) diff --git a/docs/python/conf.py b/docs/python/conf.py index 7ab2d42aa15e1..438c21570eaac 100644 --- a/docs/python/conf.py +++ b/docs/python/conf.py @@ -2,12 +2,10 @@ # Licensed under the MIT License. # pylint: disable=C0103 -# -*- coding: utf-8 -*- -# -# Configuration file for the Sphinx documentation builder. +"""Configuration file for the Sphinx documentation builder.""" import os -import shutil # noqa: F401 +import shutil import sys sys.path.append(os.path.join(os.path.dirname(__file__), "..", "_common")) @@ -127,7 +125,5 @@ def setup(app): urllib.request.urlretrieve(url, dest) loc = os.path.split(dest)[-1] if not os.path.exists(loc): - import shutil # noqa: F811 - shutil.copy(dest, loc) return app diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 4106943e8facc..2fbd118a43ed1 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -413,7 +413,7 @@ def __init__( self._read_config_from_model = os.environ.get("ORT_LOAD_CONFIG_FROM_MODEL") == "1" # internal parameters that we don't expect to be used in general so aren't documented - disabled_optimizers = kwargs["disabled_optimizers"] if "disabled_optimizers" in kwargs else None + disabled_optimizers = kwargs.get("disabled_optimizers") try: self._create_inference_session(providers, provider_options, disabled_optimizers) diff --git a/onnxruntime/python/tools/qnn/add_trans_cast.py b/onnxruntime/python/tools/qnn/add_trans_cast.py index bd6b8701f8fb8..ced3e3519ad42 100644 --- a/onnxruntime/python/tools/qnn/add_trans_cast.py +++ b/onnxruntime/python/tools/qnn/add_trans_cast.py @@ -270,19 +270,15 @@ def main(): raise AssertionError("Error: Onnx model output: " + graph_output.name + " not exist from QNN model output.") for node in model.graph.node: - node_input_index = 0 - for node_input in node.input: + for node_input_index, node_input in enumerate(node.input): # update consumer node for graph inputs to connect to inserted node if node_input in graph_input_output_name_dic: node.input[node_input_index] = graph_input_output_name_dic[node_input] - node_input_index += 1 - node_output_index = 0 - for node_output in node.output: + for node_output_index, node_output in enumerate(node.output): # update producer node for graph outputs to connect to inserted node if node_output in graph_input_output_name_dic: node.output[node_output_index] = graph_input_output_name_dic[node_output] - node_output_index += 1 model.graph.node.extend(nodes_to_add) graph_topological_sort(model.graph) diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 77b3dce9fb004..624049b244580 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -1100,12 +1100,10 @@ def create_calibrator( calibrator = None if calibrate_method == CalibrationMethod.MinMax: # default 
settings for min-max algorithm - symmetric = False if "symmetric" not in extra_options else extra_options["symmetric"] - moving_average = False if "moving_average" not in extra_options else extra_options["moving_average"] - averaging_constant = 0.01 if "averaging_constant" not in extra_options else extra_options["averaging_constant"] - max_intermediate_outputs = ( - None if "max_intermediate_outputs" not in extra_options else extra_options["max_intermediate_outputs"] - ) + symmetric = extra_options.get("symmetric", False) + moving_average = extra_options.get("moving_average", False) + averaging_constant = extra_options.get("averaging_constant", 0.01) + max_intermediate_outputs = extra_options.get("max_intermediate_outputs", None) calibrator = MinMaxCalibrater( model, op_types_to_calibrate, @@ -1118,9 +1116,9 @@ def create_calibrator( ) elif calibrate_method == CalibrationMethod.Entropy: # default settings for entropy algorithm - num_bins = 128 if "num_bins" not in extra_options else extra_options["num_bins"] - num_quantized_bins = 128 if "num_quantized_bins" not in extra_options else extra_options["num_quantized_bins"] - symmetric = False if "symmetric" not in extra_options else extra_options["symmetric"] + num_bins = extra_options.get("num_bins", 128) + num_quantized_bins = extra_options.get("num_quantized_bins", 128) + symmetric = extra_options.get("symmetric", False) calibrator = EntropyCalibrater( model, op_types_to_calibrate, @@ -1132,9 +1130,9 @@ def create_calibrator( ) elif calibrate_method == CalibrationMethod.Percentile: # default settings for percentile algorithm - num_bins = 2048 if "num_bins" not in extra_options else extra_options["num_bins"] - percentile = 99.999 if "percentile" not in extra_options else extra_options["percentile"] - symmetric = True if "symmetric" not in extra_options else extra_options["symmetric"] + num_bins = extra_options.get("num_bins", 2048) + percentile = extra_options.get("percentile", 99.999) + symmetric = extra_options.get("symmetric", True) calibrator = PercentileCalibrater( model, op_types_to_calibrate, @@ -1147,8 +1145,8 @@ def create_calibrator( elif calibrate_method == CalibrationMethod.Distribution: # default settings for percentile algorithm - num_bins = 2048 if "num_bins" not in extra_options else extra_options["num_bins"] - scenario = "same" if "scenario" not in extra_options else extra_options["scenario"] + num_bins = extra_options.get("num_bins", 2048) + scenario = extra_options.get("scenario", "same") calibrator = DistributionCalibrater( model, diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py index 456a75eec2f8c..b54b421226f1a 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -86,11 +86,9 @@ def get_node_attribute(node: onnx.NodeProto, attribute_name: str): @staticmethod def input_index(node_output: str, child_node: onnx.NodeProto) -> int: - index = 0 - for input_name in child_node.input: + for index, input_name in enumerate(child_node.input): if input_name == node_output: return index - index += 1 return -1 @staticmethod diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 898a5f70ac45e..a72d21c03a8a6 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -113,14 +113,10 @@ def __init__( "ForceQuantizeNoInputCheck" in 
self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"] ) self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] - self.is_weight_symmetric = ( - weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) - if "WeightSymmetric" not in self.extra_options - else self.extra_options["WeightSymmetric"] - ) - self.is_activation_symmetric = ( - False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] + self.is_weight_symmetric = self.extra_options.get( + "WeightSymmetric", weight_qType in (QuantType.QInt8, QuantType.QInt16, QuantType.QFLOAT8E4M3FN) ) + self.is_activation_symmetric = self.extra_options.get("ActivationSymmetric", False) self.min_real_range = self.extra_options.get("MinimumRealRange") self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 123cfe913d6e2..775a3e8b8b588 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -87,40 +87,28 @@ def __init__( # because those ops may be followed by nodes that require high resolution inputs. # Adding QDQ for those ops' output may end up with worse accuracy. # So, we don't recommend to add QDQ to node's output under such condition. - self.op_types_to_exclude_output_quantization = ( - [] - if "OpTypesToExcludeOutputQuantization" not in extra_options - else extra_options["OpTypesToExcludeOutputQuantization"] - ) + self.op_types_to_exclude_output_quantization = extra_options.get("OpTypesToExcludeOutputQuantization", []) # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization. # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair. # Therefore, we need to disable this optimization and add qdq pair to weight. - self.add_qdq_pair_to_weight = ( - False if "AddQDQPairToWeight" not in extra_options else extra_options["AddQDQPairToWeight"] - ) + self.add_qdq_pair_to_weight = extra_options.get("AddQDQPairToWeight", False) # Some scenarios do not need the bias quantized. For example, in the case of Quantization Aware Training, # quantizing the bias is not needed. This is because in QAT, all model parameters are expected to be in # floating point format. To that end, we can use the FakeQuant operator for weights and activations that # can always have QDQ pairs (by using AddQDQPairToWeight). But for biases in a quantized model, we can't use # FakeQuant because it only ever appears before a DQ (since it is quantized as int32). - self.quantize_bias = True if "QuantizeBias" not in extra_options else extra_options["QuantizeBias"] + self.quantize_bias = extra_options.get("QuantizeBias", True) # The default behavior is that multiple nodes can share a QDQ pair as their inputs. # In TRT, QDQ pair can`t be shared between nodes, so it will create dedicated QDQ pairs for each node. - self.dedicated_qdq_pair = ( - False if "DedicatedQDQPair" not in extra_options else extra_options["DedicatedQDQPair"] - ) + self.dedicated_qdq_pair = extra_options.get("DedicatedQDQPair", False) if self.dedicated_qdq_pair: self.tensor_to_its_receiving_nodes = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. 
- self.qdq_op_type_per_channel_support_to_axis = ( - {} - if "QDQOpTypePerChannelSupportToAxis" not in extra_options - else extra_options["QDQOpTypePerChannelSupportToAxis"] - ) + self.qdq_op_type_per_channel_support_to_axis = extra_options.get("QDQOpTypePerChannelSupportToAxis", {}) self.qdq_op_domain = ms_domain if extra_options.get("UseQDQContribOps", False) else None diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 66c78f80f7910..4b56bc1e8d828 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -350,7 +350,7 @@ def _merge_symbols(self, dims): return None if all([d == dims[0] for d in dims]): return dims[0] - merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims] + merged = [self.suggested_merge_.get(d, d) for d in dims] if all([d == merged[0] for d in merged]): assert merged[0] in self.symbolic_dims_ return merged[0] @@ -824,17 +824,21 @@ def _infer_ArrayFeatureExtractor(self, node): # noqa: N802 def _infer_symbolic_compute_ops(self, node): funcs = { "Add": lambda l: l[0] + l[1], # noqa: E741 - "Div": lambda l: int(l[0] // l[1]) # noqa: E741 - if isinstance(l[0] // l[1], float) - else l[0] // l[1], # integer div in sympy + "Div": lambda l: ( # noqa: E741 + int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1] + ), # integer div in sympy "Equal": lambda l: l[0] == l[1], # noqa: E741 "Floor": lambda l: sympy.floor(l[0]), # noqa: E741 - "Max": lambda l: l[1] # noqa: E741 - if is_literal(l[0]) and int(l[0]) < -self.int_max_ - else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), - "Min": lambda l: l[1] # noqa: E741 - if is_literal(l[0]) and int(l[0]) > self.int_max_ - else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), + "Max": lambda l: ( # noqa: E741 + l[1] + if is_literal(l[0]) and int(l[0]) < -self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])) + ), + "Min": lambda l: ( # noqa: E741 + l[1] + if is_literal(l[0]) and int(l[0]) > self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])) + ), "Mul": lambda l: int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1], # noqa: E741 "Sub": lambda l: l[0] - l[1], # noqa: E741 "Where": lambda l: l[1] if l[0] else l[2], # noqa: E741 @@ -1476,9 +1480,11 @@ def _infer_aten_group_norm(self, node): output_dtype, [ N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0)), - as_scalar(group) - if group is not None - else str(self._new_symbolic_dim_from_output(node, i, 1)), + ( + as_scalar(group) + if group is not None + else str(self._new_symbolic_dim_from_output(node, i, 1)) + ), ], ) ) diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index b33491b356e86..20bb8a71dc35f 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -1575,15 +1575,13 @@ def output_metrics(model_to_metrics, csv_filename): for value in results: row = [ value["model_name"], - value["ratio_of_ops_in_cuda_not_fallback_cpu"] - if "ratio_of_ops_in_cuda_not_fallback_cpu" in value - else " ", - value["total_ops_in_trt"] if "total_ops_in_trt" in value else " ", - value["total_ops"] if "total_ops" in value else " ", - value["ratio_of_ops_in_trt"] if 
"ratio_of_ops_in_trt" in value else " ", - value["total_trt_execution_time"] if "total_trt_execution_time" in value else " ", - value["total_execution_time"] if "total_execution_time" in value else " ", - value["ratio_of_execution_time_in_trt"] if "ratio_of_execution_time_in_trt" in value else " ", + value.get("ratio_of_ops_in_cuda_not_fallback_cpu", " "), + value.get("total_ops_in_trt", " "), + value.get("total_ops", " "), + value.get("ratio_of_ops_in_trt", " "), + value.get("total_trt_execution_time", " "), + value.get("total_execution_time", " "), + value.get("ratio_of_execution_time_in_trt", " "), ] csv_writer.writerow(row) diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index f506516442b1e..22f2cfb8a01ca 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -344,9 +344,7 @@ def run_pytorch( else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) logger.debug(f"Model {model}") logger.debug(f"Number of parameters {model.num_parameters()}") @@ -498,9 +496,7 @@ def run_tensorflow( tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) for batch_size in batch_sizes: if batch_size <= 0: diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index b6f7a44450c62..c5edae42590fd 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -341,11 +341,7 @@ def inference_ort_with_io_binding( # Bind inputs to device for name in ort_inputs: np_input = torch.from_numpy(ort_inputs[name]).to(device) - input_type = ( - IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] - if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP - else data_type - ) + input_type = IO_BINDING_DATA_TYPE_MAP.get(str(ort_inputs[name].dtype), data_type) io_binding.bind_input( name, np_input.device.type, diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index 84ecae1907cd3..aa82e047df328 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -174,12 +174,10 @@ def output_test_data(directory: str, inputs: Dict[str, np.ndarray]): else: print("Warning: directory %s existed. Files will be overwritten." 
% directory) - index = 0 - for name, data in inputs.items(): + for index, (name, data) in enumerate(inputs.items()): tensor = numpy_helper.from_array(data, name) with open(os.path.join(directory, f"input_{index}.pb"), "wb") as file: file.write(tensor.SerializeToString()) - index += 1 def fake_test_data( diff --git a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py index df80acbd97807..676052f747967 100644 --- a/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py +++ b/onnxruntime/python/tools/transformers/fusion_skip_group_norm.py @@ -147,7 +147,7 @@ def match_bias_path(self, node, input_name_to_nodes, output_name_to_node): def match_transpose_from_nhwc(self, output_name, input_name_to_nodes, output_name_to_node): """Match whether an output is from a Transpose(perm=[0,3,1,2]) node.""" - parent = output_name_to_node[output_name] if output_name in output_name_to_node else None + parent = output_name_to_node.get(output_name, None) if parent is not None and parent.op_type == "Transpose": permutation = OnnxModel.get_node_attribute(parent, "perm") if permutation == [0, 3, 1, 2]: diff --git a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py index 6089c960e47ee..8797fd9c2cfaf 100644 --- a/onnxruntime/python/tools/transformers/models/bert/eval_squad.py +++ b/onnxruntime/python/tools/transformers/models/bert/eval_squad.py @@ -193,7 +193,7 @@ def output_summary(results: List[Dict[str, Any]], csv_filename: str, metric_name if row: for key in key_names: - row[key] = values[key] if key in values else "" + row[key] = values.get(key, "") csv_writer.writerow(row) csv_file.flush() diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py index a1e6d3125e7fb..4823f0d5874dd 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py @@ -171,12 +171,10 @@ def print_wins(wins, rows, test_name): rank = 0 previous_value = -1 - count = 0 - for key, value in sorted_wins.items(): + for count, (key, value) in enumerate(sorted_wins.items()): if value != previous_value: rank = count previous_value = value - count += 1 for row in rows: if row["run_id"] == key: diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py index 12700f00ad0c2..f4705bef6a988 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py @@ -387,8 +387,8 @@ def test_generation( if i % 10 == 0: print(f"{i}") input_ids = inputs["input_ids"] - position_ids = inputs["position_ids"] if "position_ids" in inputs else None - attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None + position_ids = inputs.get("position_ids", None) + attention_mask = inputs.get("attention_mask", None) onnx_runner = Gpt2Tester( input_ids, diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index c9a679c4eac8a..51a967cf22608 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -289,9 
+289,7 @@ def inference(): def load_torch_model(model_name, device): - torch_model_name_or_dir = ( - PRETRAINED_LONGFORMER_MODELS[model_name] if model_name in PRETRAINED_LONGFORMER_MODELS else model_name - ) + torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS.get(model_name, model_name) model = LongformerModel.from_pretrained(torch_model_name_or_dir) model.to(device) return model diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py index c03c6f0b21cd3..26b9a2792e9e1 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -92,7 +92,7 @@ def get_diffusers_module_name(self, model_name): "unetxl": "unet", "vae": "vae_decoder", } - return name_mapping[model_name] if model_name in name_mapping else model_name + return name_mapping.get(model_name, model_name) def get_cached_model_name(self, model_name): model_name = self.get_diffusers_module_name(model_name) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index a4bef1f06b4fe..f65060f4f93d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -22,9 +22,9 @@ from onnxruntime import InferenceSession sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff # noqa: E402 -from onnx_model import OnnxModel # noqa: E402 -from optimizer import optimize_model # noqa: E402 +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model logger = logging.getLogger(__name__) @@ -290,7 +290,7 @@ def verify_onnx( logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.") os.system(install_cmd) - from datasets import load_dataset # noqa: F811 + from datasets import load_dataset ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 4e064fa53bfc6..3967a7875f3a7 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -492,10 +492,7 @@ def export_onnx_model_from_pt( example_inputs = image_processor(data, return_tensors="pt") else: tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) - + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") example_inputs = filter_inputs(example_inputs, input_names) @@ -599,9 +596,7 @@ def export_onnx_model_from_tf( # Fix "Using pad_token, but it is not set yet" error. 
if tokenizer.pad_token is None: tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - max_input_size = ( - tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 - ) + max_input_size = tokenizer.max_model_input_sizes.get(model_name, 1024) config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier) model.resize_token_embeddings(len(tokenizer)) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 0e20b1f871645..e9029f4f620f8 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -838,11 +838,9 @@ def get_graph_inputs(self, current_node, recursive=False): @staticmethod def input_index(node_output, child_node): - index = 0 - for input in child_node.input: + for index, input in enumerate(child_node.input): if input == node_output: return index - index += 1 return -1 def remove_unused_constant(self): @@ -908,7 +906,7 @@ def get_first_output(node): num_nodes_removed = 0 for node in self.model.graph.node: first_output = get_first_output(node) - kept_node = output_to_node[first_output] if first_output in output_to_node else None + kept_node = output_to_node.get(first_output) # Need double check the node since fused node might reuse output name of some nodes to be removed. # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. diff --git a/onnxruntime/python/tools/transformers/profiler.py b/onnxruntime/python/tools/transformers/profiler.py index 8e45b149eaf03..2306b579f92fe 100644 --- a/onnxruntime/python/tools/transformers/profiler.py +++ b/onnxruntime/python/tools/transformers/profiler.py @@ -329,7 +329,7 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): calls = node_freq[node_name] avg_time = duration / float(calls) percentage = (duration / total) * 100.0 - provider = node_provider[node_name] if node_name in node_provider else "" + provider = node_provider.get(node_name, "") before_percentage += percentage lines.append( f"{duration:10d}\t{percentage:5.2f}\t{before_percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}" @@ -347,7 +347,7 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): calls = node_freq[node_name] avg_time = duration / float(calls) percentage = (duration / total) * 100.0 - provider = node_provider[node_name] if node_name in node_provider else "" + provider = node_provider.get(node_name, "") lines.append(f"{duration:10d}\t{percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}") return lines @@ -393,7 +393,7 @@ def group_node_results(sess_time, kernel_time_only, use_gpu): total_fence_time += item["dur"] continue - provider = item["args"]["provider"] if "provider" in item["args"] else "" + provider = item["args"].get("provider", "") if provider in provider_counter: provider_counter[provider] += 1 else: @@ -425,7 +425,7 @@ def group_node_results(sess_time, kernel_time_only, use_gpu): lines.append("-" * 64) lines.append("Total(μs)\tTime%\tKernel(μs)\tKernel%\tCalls\tAvgKernel(μs)\tFence(μs)\tOperator") for op_name, kernel_time in sorted(op_kernel_time.items(), key=lambda x: x[1], reverse=True): - fence_time = op_fence_time[op_name] if op_name in op_fence_time else 0 + fence_time = op_fence_time.get(op_name, 0) kernel_time_ratio = kernel_time / total_kernel_time total_time = kernel_time + fence_time time_ratio = total_time / 
(total_kernel_time + total_fence_time) diff --git a/onnxruntime/test/providers/cpu/rnn/GRU.py b/onnxruntime/test/providers/cpu/rnn/GRU.py index 846fc3d06b9a9..144acaf14db61 100644 --- a/onnxruntime/test/providers/cpu/rnn/GRU.py +++ b/onnxruntime/test/providers/cpu/rnn/GRU.py @@ -47,8 +47,8 @@ def __init__(self, **params): if "initial_h" in params else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) ) - LBR = params["linear_before_reset"] if "linear_before_reset" in params else 0 # noqa: N806 - self.direction = params["direction"] if "direction" in params else "forward" + LBR = params.get("linear_before_reset", 0) # noqa: N806 + self.direction = params.get("direction", "forward") if num_directions == 1: if self.direction == "forward": diff --git a/onnxruntime/test/providers/cpu/rnn/LSTM.py b/onnxruntime/test/providers/cpu/rnn/LSTM.py index 74299ea2c75a3..116ec3671bf01 100644 --- a/onnxruntime/test/providers/cpu/rnn/LSTM.py +++ b/onnxruntime/test/providers/cpu/rnn/LSTM.py @@ -65,13 +65,13 @@ def __init__(self, **params): # type: (*Any) -> None else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) ) - f = params["f"] if "f" in params else ActivationFuncs.sigmoid - g = params["g"] if "g" in params else ActivationFuncs.tanh - h = params["h"] if "h" in params else ActivationFuncs.tanh - input_forget = params["input_forget"] if "input_forget" in params else False - clip = params["clip"] if "clip" in params else 9999.0 + f = params.get("f", ActivationFuncs.sigmoid) + g = params.get("g", ActivationFuncs.tanh) + h = params.get("h", ActivationFuncs.tanh) + input_forget = params.get("input_forget", False) + clip = params.get("clip", 9999.0) - self.direction = params["direction"] if "direction" in params else "forward" + self.direction = params.get("direction", "forward") if num_directions == 1: if self.direction == "forward": @@ -266,8 +266,8 @@ def SimpleWeightsNoBiasTwoRows(direction): # type: () -> None # noqa: N802 R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32) # noqa: N806 if direction == "bidirectional": - W = W = np.tile(W, (2, 1)).reshape(2, number_of_gates * hidden_size, input_size) # noqa: N806 - R = R = np.tile(R, (2, 1)).reshape(2, number_of_gates * hidden_size, hidden_size) # noqa: N806 + W = np.tile(W, (2, 1)).reshape(2, number_of_gates * hidden_size, input_size) # noqa: N806 + R = np.tile(R, (2, 1)).reshape(2, number_of_gates * hidden_size, hidden_size) # noqa: N806 lstm = LSTM_Helper(X=input, W=W, R=R, direction=direction) diff --git a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py index 79d41e41d696c..4c1e3a70de1c7 100644 --- a/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py +++ b/onnxruntime/test/testdata/test_data_generation/adamw_test/adamw_test_data_generator.py @@ -58,10 +58,8 @@ def _torch_tensor_to_str(torch_tensor): def _build_param_index_to_name_mapping(model, map_result): """Build index to name mapping, which is used to retrieve data from optimizer group.""" - index = 0 - for param in model.named_parameters(): + for index, param in enumerate(model.named_parameters()): map_result[index] = param[0] - index += 1 torch.manual_seed(seed) @@ -119,8 +117,7 @@ def _build_param_index_to_name_mapping(model, map_result): _sync_stream() for group in 
adamw_optimizer.param_groups: - p_index = 0 - for param in group["params"]: + for p_index, param in enumerate(group["params"]): state = adamw_optimizer.state[param] name = param_index_to_name_mapping[p_index] # Collect flattened optimizer state data. @@ -130,7 +127,6 @@ def _build_param_index_to_name_mapping(model, map_result): else: m1_dict[name].append(_torch_tensor_to_str(state["exp_avg"].view(-1))) m2_dict[name].append(_torch_tensor_to_str(state["exp_avg_sq"].view(-1))) - p_index += 1 adamw_optimizer.step() adamw_optimizer.zero_grad() diff --git a/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py b/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py index a3d7946d63214..173225a21a52f 100644 --- a/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py +++ b/onnxruntime/test/testdata/test_data_generation/sgd_test/sgd_test_data_generator.py @@ -58,10 +58,8 @@ def _torch_tensor_to_str(torch_tensor): def _build_param_index_to_name_mapping(model, map_result): """Build index to name mapping, which is used to retrieve data from optimizer group.""" - index = 0 - for param in model.named_parameters(): + for index, param in enumerate(model.named_parameters()): map_result[index] = param[0] - index += 1 torch.manual_seed(seed) diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index b7e55bc733ede..a1c3d7d7e1d4f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -30,7 +30,7 @@ def get_variable_name(self, name: str) -> str: # For some operators such as data load/store, we need an internal variable name inside the kernel function. def get_internal_variable_name(self, name: str) -> str: var_name = self._var_map[name] - var_name = self._var_map[var_name] if var_name in self._var_map else var_name + var_name = self._var_map.get(var_name, var_name) return f'float("{var_name}")' if var_name in _SPECIAL_FLOATS else var_name diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index f10416a9bb0f4..af5f3c9ceb565 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -376,8 +376,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. 
- index = 0 - for node in exported_model.graph.node: + for index, node in enumerate(exported_model.graph.node): op_name_prefix = node.op_type if node.domain == "com.microsoft" and node.op_type == "PythonOp": output_names = list(node.output) @@ -391,7 +390,6 @@ def post_process_enabling_autograd_function(exported_model: ModelProto) -> Model break node.name = f"{op_name_prefix}_id_{index}" - index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 4883075112dcb..75512cb8e8c88 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -48,7 +48,7 @@ def _to_gradient_definition(gradient): attr_def.name = key attr_def.value_json = json.dumps(value["value"]) attr_def.dtype = value["dtype"] - attr_def.is_tensor = value["is_tensor"] if "is_tensor" in value else False + attr_def.is_tensor = value.get("is_tensor", False) attributes.append(attr_def) node_def.attributes = attributes node_defs.append(node_def) diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py index dcaa202d46fd8..905eb62768a92 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py @@ -214,8 +214,7 @@ def recursive_wrap(module, save_onnx=False, onnx_prefix=""): if isinstance(sub_module, torch.nn.ModuleList): # We encounter a list of sub-modules. # Let's wrap them one-by-one. - idx = 0 - for item_name, sub_module_item in sub_module._modules.items(): + for idx, (item_name, sub_module_item) in enumerate(sub_module._modules.items()): # Avoid saving too many graphs. new_save_onnx = save_onnx and idx == 0 sub_new_prefix = new_prefix + "_" + item_name @@ -237,7 +236,6 @@ def recursive_wrap(module, save_onnx=False, onnx_prefix=""): ) else: recursive_wrap(sub_module_item, new_save_onnx, sub_new_prefix) - idx += 1 else: if is_supported(sub_module): # Just wrap it as ORTModule when possible. diff --git a/orttraining/orttraining/test/python/orttraining_test_model_transform.py b/orttraining/orttraining/test/python/orttraining_test_model_transform.py index 3b07aa1f4daf0..095830cd54ab8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_model_transform.py +++ b/orttraining/orttraining/test/python/orttraining_test_model_transform.py @@ -2,10 +2,8 @@ def add_name(model): - i = 0 - for node in model.graph.node: + for i, node in enumerate(model.graph.node): node.name = "%s_%d" % (node.op_type, i) - i += 1 def find_single_output_node(model, arg): diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 6a6832e06330a..51aa1564cbfbe 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6428,7 +6428,7 @@ def run_step(model, x): reason="This test fail because bert forward loss is nan in updated transformers lib, disable for now." 
) def test_bert_result_with_layerwise_recompute(): - original_val = os.environ["ORTMODULE_MEMORY_OPT_LEVEL"] if "ORTMODULE_MEMORY_OPT_LEVEL" in os.environ else None + original_val = os.environ.get("ORTMODULE_MEMORY_OPT_LEVEL", None) # Create PyTorch model with dropout disabled. pt_model = _get_bert_for_sequence_classification_model( "cuda", is_training=True, hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0 diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py index 50016515a69e1..043c70263d31e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py @@ -125,8 +125,6 @@ def run_with_ort_on_gpu(model, args, rank, device): try: mp.spawn(test_Distributed_ReduceWithMarkDirtyModel, nprocs=size, args=(size,)) except Exception: - import sys # noqa: F811 - sys.stdout.flush() sys.stderr.flush() raise diff --git a/orttraining/tools/scripts/gpt2_model_transform.py b/orttraining/tools/scripts/gpt2_model_transform.py index 06f03e06632b4..294af13fe69b7 100644 --- a/orttraining/tools/scripts/gpt2_model_transform.py +++ b/orttraining/tools/scripts/gpt2_model_transform.py @@ -17,10 +17,8 @@ def add_name(model): - i = 0 - for node in model.graph.node: + for i, node in enumerate(model.graph.node): node.name = "%s_%d" % (node.op_type, i) - i += 1 def find_input_node(model, arg): @@ -139,11 +137,9 @@ def process_concat(model): delete_nodes.append(get_node_index(model, n)) # insert new shape to reshape - index = 0 - for reshape_node_index in new_nodes: + for index, reshape_node_index in enumerate(new_nodes): shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64)) const_node = add_const(model, "concat_shape_node_%d" % index, "concat_shape_%d" % index, shape_tensor) - index += 1 reshape_node = model.graph.node[reshape_node_index] reshape_node.input[1] = const_node.output[0] # delete nodes @@ -154,28 +150,22 @@ def process_concat(model): def replace_input_arg(model, arg, new_arg): for node in model.graph.node: - i = 0 - while i < len(node.input): - if node.input[i] == arg: + for i, input_name in enumerate(node.input): + if input_name == arg: node.input[i] = new_arg - i += 1 def find_weight_index(model, name): - index = 0 - for w in model.graph.initializer: + for index, w in enumerate(model.graph.initializer): if w.name == name: return index - index += 1 return None def find_input_index(model, name): - index = 0 - for w in model.graph.input: + for index, w in enumerate(model.graph.input): if w.name == name: return index - index += 1 return None diff --git a/orttraining/tools/scripts/model_transform.py b/orttraining/tools/scripts/model_transform.py index 81e9f7b16be14..f0cf53990eac3 100644 --- a/orttraining/tools/scripts/model_transform.py +++ b/orttraining/tools/scripts/model_transform.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import sys import numpy as np import onnx -from onnx import TensorProto, helper, numpy_helper, shape_inference # noqa: F401 +from onnx import numpy_helper if len(sys.argv) < 2: print("Please give model path...") @@ -15,10 +17,8 @@ def add_name(model): - i = 0 - for node in model.graph.node: + for i, node in enumerate(model.graph.node): node.name = "%s_%d" % (node.op_type, i) - i += 1 def find_input_node(model, arg): @@ -118,11 +118,9 @@ def process_concat(model): for n in fuse_nodes: 
delete_nodes.append(get_node_index(model, n)) # insert new shape to reshape - index = 0 - for reshape_node_index in new_nodes: + for index, reshape_node_index in enumerate(new_nodes): shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64)) const_node = add_const(model, "concat_shape_node_%d" % index, "concat_shape_%d" % index, shape_tensor) - index += 1 reshape_node = model.graph.node[reshape_node_index] reshape_node.input[1] = const_node.output[0] # delete nodes @@ -199,12 +197,10 @@ def replace_input_arg(model, arg, new_arg): i += 1 -def find_weight_index(model, name): - index = 0 - for w in model.graph.initializer: +def find_weight_index(model, name: str) -> int | None: + for index, w in enumerate(model.graph.initializer): if w.name == name: return index - index += 1 return None diff --git a/requirements-lintrunner.txt b/requirements-lintrunner.txt index 25454ce40c263..6836d5df69324 100644 --- a/requirements-lintrunner.txt +++ b/requirements-lintrunner.txt @@ -1,7 +1,7 @@ # This file is auto updated by dependabot lintrunner-adapters>=0.11.0 # RUFF -ruff==0.1.4 +ruff==0.2.1 # BLACK-ISORT black==23.10.1 isort==5.12.0 diff --git a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py index dec070e3f5c75..2264742079d15 100644 --- a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py +++ b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py @@ -66,8 +66,8 @@ def diff_sections_total_size(base_binary_path, binary_path, readelf_path="readel results = collections.OrderedDict() for section in sorted(merged_keys): - base_size = base_section_sizes[section] if section in base_section_sizes else 0 - size = section_sizes[section] if section in section_sizes else 0 + base_size = base_section_sizes.get(section, 0) + size = section_sizes.get(section, 0) base_total += base_size total += size diff --git a/tools/ci_build/op_registration_validator.py b/tools/ci_build/op_registration_validator.py index 8222437f7b42e..5c7edfa88a48b 100644 --- a/tools/ci_build/op_registration_validator.py +++ b/tools/ci_build/op_registration_validator.py @@ -165,7 +165,7 @@ def _validate_last_registration(self, last_r: RegistrationInfo) -> bool: # domain that have newer registrations in a non-contrib op file differently. They should only be considered # deprecated as contrib ops. 
domain_and_op_str = last_r.domain_and_op_str() - deprecation_version = deprecated_ops.get(domain_and_op_str, None) + deprecation_version = deprecated_ops.get(domain_and_op_str) allow_missing_unversioned_registration = ( deprecation_version is not None and last_r.end_version == deprecation_version - 1 diff --git a/tools/python/fix_long_lines.py b/tools/python/fix_long_lines.py index 383fdc9623551..8a3c249ef672a 100644 --- a/tools/python/fix_long_lines.py +++ b/tools/python/fix_long_lines.py @@ -20,9 +20,8 @@ def _process_files(filenames, clang_exe, tmpdir): bad_lines = [] with open(path, encoding="UTF8") as f: - line_num = 0 - for line in f: - line_num += 1 # clang-format line numbers start at 1 + for i, line in enumerate(f): + line_num = i + 1 # clang-format line numbers start at 1 if len(line) > 120: bad_lines.append(line_num) diff --git a/tools/python/gen_opkernel_doc.py b/tools/python/gen_opkernel_doc.py index 1075ed8192fdd..f6f9f21396859 100644 --- a/tools/python/gen_opkernel_doc.py +++ b/tools/python/gen_opkernel_doc.py @@ -22,11 +22,9 @@ def format_version_range(v): def format_type_constraints(tc): - counter = 0 tcstr = "" firsttcitem = True for tcitem in tc: - counter += 1 if firsttcitem: firsttcitem = False else: @@ -98,7 +96,7 @@ def main(output_path: pathlib.Path, provider_filter: [str]): paramstr += f"*out* {outp.name}:**{outp.typeStr}**" paramstr += "" - paramset = paramdict.get(fullname, None) + paramset = paramdict.get(fullname) if paramset is None: paramdict[fullname] = set() @@ -145,9 +143,8 @@ def main(output_path: pathlib.Path, provider_filter: [str]): else: fout.write("|||") fout.write(format_version_range(version_range) + "|") - tnameindex = 0 - for tname, tcset in sorted(typemap.items()): - tnameindex += 1 + for i, (tname, tcset) in enumerate(sorted(typemap.items())): + tnameindex = i + 1 tclist = [] for tc in sorted(tcset): tclist.append(tc) diff --git a/tools/python/ort_test_dir_utils.py b/tools/python/ort_test_dir_utils.py index cd1f5022af526..3af407b2aeee6 100644 --- a/tools/python/ort_test_dir_utils.py +++ b/tools/python/ort_test_dir_utils.py @@ -115,8 +115,7 @@ def create_test_dir( model_outputs = model.graph.output def save_data(prefix, name_data_map, model_info): - idx = 0 - for name, data in name_data_map.items(): + for idx, (name, data) in enumerate(name_data_map.items()): if isinstance(data, dict): # ignore. map from traditional ML ops pass @@ -130,8 +129,6 @@ def save_data(prefix, name_data_map, model_info): with open(filename, "wb") as f: f.write(tensor.SerializeToString()) - idx += 1 - if not name_input_map: name_input_map = {} diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py index 9eccb7c36455f..f8cc34e04afa0 100644 --- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py +++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py @@ -105,7 +105,7 @@ def _node_output_is_supported(name): # some models don't have complete imports. use 1 as a default as that's valid for custom domains and should # result in an error for any others. not sure why ONNX or ORT validation allows this though. 
- opset = opsets[domain] if domain in opsets else 1 + opset = opsets.get(domain, 1) if ( domain not in required_ops or opset not in required_ops[domain] From 1007d8f3d1904ff2efc4e0647e795939fe049464 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 9 Feb 2024 09:24:54 -0800 Subject: [PATCH 060/207] Revert "Revert NeuralSpeed code for x64 MatMulNBits (#19382)" (#19474) This reverts commit 0d10c7f3c1111cfff064e7990aa897ac9fd05c82. --- cgmanifests/generated/cgmanifest.json | 10 + cmake/CMakeLists.txt | 12 + cmake/deps.txt | 1 + cmake/external/neural_speed.cmake | 15 + cmake/onnxruntime_providers_cpu.cmake | 15 + .../cpu/quantization/matmul_nbits.cc | 144 ++++++ .../cpu/quantization/neural_speed_defs.h | 45 ++ .../cpu/quantization/neural_speed_gemm.cc | 438 ++++++++++++++++++ .../cpu/quantization/neural_speed_gemm.h | 129 ++++++ .../cpu/quantization/neural_speed_wrapper.h | 39 ++ .../test/contrib_ops/matmul_4bits_test.cc | 175 +++++++ 11 files changed, 1023 insertions(+) create mode 100644 cmake/external/neural_speed.cmake create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h create mode 100644 onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index fc4ea25603152..efd901787fdb7 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -202,6 +202,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", + "repositoryUrl": "https://github.com/intel/neural-speed.git" + }, + "comments": "neural_speed" + } + }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0ccd874cee3c9..90fe8276ea9c7 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -88,6 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF) option(onnxruntime_USE_SNPE "Build with SNPE support" OFF) option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF) option(onnxruntime_USE_DNNL "Build with DNNL support" OFF) +option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF) option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) @@ -901,6 +902,10 @@ function(onnxruntime_set_compile_flags target_name) target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN) endif() + if(USE_NEURAL_SPEED) + target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED) + endif() + set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON) if (onnxruntime_USE_CUDA) # Suppress a "conversion_function_not_usable" warning in gsl/span @@ -1188,6 +1193,13 @@ if (onnxruntime_USE_DNNL) add_compile_definitions(DNNL_OPENMP) endif() +if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD) + include(neural_speed) + if (USE_NEURAL_SPEED) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla) + endif() +endif() + # TVM EP if (onnxruntime_USE_TVM) if (NOT TARGET tvm) diff --git a/cmake/deps.txt b/cmake/deps.txt index 17c3cbf9a6c43..cb431f8c77397 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -35,6 +35,7 @@ 
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake new file mode 100644 index 0000000000000..ed711351403a7 --- /dev/null +++ b/cmake/external/neural_speed.cmake @@ -0,0 +1,15 @@ +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64") + set(USE_NEURAL_SPEED TRUE) +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64") + set(USE_NEURAL_SPEED TRUE) +endif() + +if(USE_NEURAL_SPEED) + FetchContent_Declare( + neural_speed + URL ${DEP_URL_neural_speed} + URL_HASH SHA1=${DEP_SHA1_neural_speed} + ) + set(BTLA_USE_OPENMP OFF) + onnxruntime_fetchcontent_makeavailable(neural_speed) +endif() diff --git a/cmake/onnxruntime_providers_cpu.cmake b/cmake/onnxruntime_providers_cpu.cmake index f60faa4d39116..b81a5c79ac0cc 100644 --- a/cmake/onnxruntime_providers_cpu.cmake +++ b/cmake/onnxruntime_providers_cpu.cmake @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS) "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc" ) endif() + set(onnxruntime_cpu_neural_speed_srcs + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc" + "${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h" + ) + if(NOT USE_NEURAL_SPEED) + list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs}) + endif() # add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs}) list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs}) @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL) target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical") endif() +if(NOT onnxruntime_DISABLE_CONTRIB_OPS) + if(USE_NEURAL_SPEED) + onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla) + endif() +endif() + if (MSVC) target_compile_options(onnxruntime_providers PRIVATE "/bigobj") # if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index e8d8bbca66fe7..166f5c8f52f54 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -10,6 +10,10 @@ #include "core/providers/cpu/math/matmul_helper.h" #include 
"core/providers/common.h" +#ifdef ORT_NEURAL_SPEED +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#endif + namespace onnxruntime { namespace contrib { @@ -19,6 +23,16 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level static_cast(CompMostAccurate), static_cast(CompLeastAccurate)); +#if defined(ORT_NEURAL_SPEED) + + ORT_UNUSED_PARAMETER(nbits); + ORT_UNUSED_PARAMETER(block_size); + + // Neural Speed APIs already expect a minimum accuracy level so just use the given value. + return accuracy_level; + +#else // defined(ORT_NEURAL_SPEED) + // Find a supported accuracy level that is not less accurate than the one given. // CompMostAccurate is always supported with the fallback implementation. // Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed. @@ -31,6 +45,8 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } return effective_accuracy_level; + +#endif // defined(ORT_NEURAL_SPEED) } } // namespace @@ -45,6 +61,17 @@ class MatMulNBits final : public OpKernel { accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); +#ifdef ORT_NEURAL_SPEED + const Tensor* tensor_B = nullptr; + const Tensor* tensor_scale = nullptr; + const Tensor* tensor_zero_point = nullptr; + bool B_constant = info.TryGetConstantInput(1, &tensor_B); + bool scale_constant = info.TryGetConstantInput(2, &tensor_scale); + bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point); + is_asym_ = info.GetInputCount() >= 4; + all_constant_ = B_constant && scale_constant; + all_constant_ = is_asym_ ? 
all_constant_ && zero_point_constant : all_constant_; +#endif } Status Compute(OpKernelContext* context) const override; @@ -65,6 +92,13 @@ class MatMulNBits final : public OpKernel { const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; size_t packed_b_size_{0}; + +#if defined(ORT_NEURAL_SPEED) + + bool is_asym_{false}; + bool all_constant_{false}; + +#endif // defined(ORT_NEURAL_SPEED) }; Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, @@ -72,6 +106,54 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; +#if defined(ORT_NEURAL_SPEED) + + if (!all_constant_) { + return Status::OK(); + } + MLAS_THREADPOOL* pool = NULL; + if (nbits_ != 4) { + return Status::OK(); + } + auto comp_type = static_cast(accuracy_level_); + auto nbits = static_cast(nbits_); + if (input_idx == 1) { + packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type); + if (packed_b_size_ == 0) return Status::OK(); + auto qptr = tensor.Data(); + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + std::memset(packed_b_.get(), 0, packed_b_size_); + NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false, + comp_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 2 && packed_b_ != nullptr) { + auto sptr = tensor.Data(); + NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_, + comp_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + if (input_idx == 3 && packed_b_ != nullptr) { + auto zptr = tensor.Data(); + NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_, + comp_type, pool); + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + is_packed = true; + } + +#else // defined(ORT_NEURAL_SPEED) + if (input_idx == 1) { const auto compute_type = static_cast(accuracy_level_); if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { @@ -91,6 +173,8 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat is_packed = true; } +#endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } @@ -98,11 +182,31 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; +#if defined(ORT_NEURAL_SPEED) + + // Pack three tensors into one buffer + if (input_idx == 1) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 2) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + if (input_idx == 3) { + used_shared_buffers = true; + packed_b_ = std::move(prepacked_buffers[0]); + } + +#else // defined(ORT_NEURAL_SPEED) + if (input_idx == 1) { used_shared_buffers = true; packed_b_ = std::move(prepacked_buffers[0]); } +#endif // defined(ORT_NEURAL_SPEED) + return Status::OK(); } @@ -112,6 +216,46 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const auto* a_data 
= a->Data(); +#if defined(ORT_NEURAL_SPEED) + + if (packed_b_) { + TensorShape b_shape({static_cast(N_), static_cast(K_)}); + + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(false); + std::vector gemm_params(max_len); + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + for (size_t i = 0; i < max_len; i++) { + gemm_params[i].A = a_data + helper.LeftOffsets()[i]; + gemm_params[i].lda = lda; + gemm_params[i].B = packed_b_.get(); + gemm_params[i].C = y_data + helper.OutputOffsets()[i]; + gemm_params[i].ldc = N; + } + auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data()); + // workspace for activation process(dynamic quantization and others) + auto ws_ptr = IAllocator::MakeUniquePtr(allocator, ws_size); + NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool); + return Status::OK(); + } + +#endif // defined(ORT_NEURAL_SPEED) + const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); const auto* scales_data = scales->Data(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h new file mode 100644 index 0000000000000..864abffd131fe --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h @@ -0,0 +1,45 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +--*/ + +#pragma once + +#include "contrib_ops/cpu/quantization/neural_speed_wrapper.h" + +namespace bestla { + +using tAVX512F = gemm::SCoreRowNAvx512f<48, 8>; +using tAMX_BF16 = gemm::HCoreRowNAmxbf16<64, 16>; +using tAVX512_FP16 = gemm::HCoreRowNAvx512fp16<96, 8>; +using tAVX_VNNI = gemm::ICoreRowNAvxvnni<24, 4>; +using tAVX512_VNNI = gemm::ICoreRowNAvx512vnni<48, 8>; +using tAMX_INT8_US = gemm::ICoreRowNAmxint8<64, 16>; +using tAMX_INT8_SS = gemm::ICoreRowNAmxint8SS<64, 16>; +using tAVX2 = gemm::SCoreRowNAvx2<24, 4>; +using tAVX_VNNI_KBlock = gemm::ICoreRowNAvxvnniKBlock<24, 2>; +using tAVX512_VNNI_KBlock = gemm::ICoreRowNAvx512vnniKBlock<48, 4>; +using tAMX_INT8_US_KBlock = gemm::ICoreRowNAmxint8KBlock<48, 16>; +using tAMX_INT8_SS_KBlock = gemm::ICoreRowNAmxint8SSKBlock<48, 16>; + +template +using tWeiNInt = prologue_b::gemm::WeightKBlockNInteger; +template +using tWeiNFloat = prologue_b::gemm::WeightKBlockNFloat; + +class ORTThreading : public parallel::IThreading { + public: + explicit ORTThreading(void* tp); + void parallel_for(const parallel::thread_func& func) const override; + void set_threads(int nthreads) override { + (void)(nthreads); + assert(0); + } + void sync() const override { assert(0); } + void* mTp; +}; + +} // namespace bestla diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc new file mode 100644 index 0000000000000..73aaa4ae61a6e --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.cc @@ -0,0 +1,438 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.cpp + +Abstract: + + GEMM template combinations of neural_speed. +--*/ + +#include "contrib_ops/cpu/quantization/neural_speed_defs.h" +#include "contrib_ops/cpu/quantization/neural_speed_gemm.h" +#include "core/platform/threadpool.h" + +using ThreadPool = onnxruntime::concurrency::ThreadPool; + +namespace bestla { + +ORTThreading::ORTThreading(void* tp) + : IThreading(ThreadPool::DegreeOfParallelism(reinterpret_cast(tp))), mTp(tp) {} + +void ORTThreading::parallel_for(const parallel::thread_func& func) const { + ThreadPool::TrySimpleParallelFor(reinterpret_cast(mTp), mThreadNum, + [&](ptrdiff_t tid) { func(static_cast(tid)); }); +} + +template +static void NSSQ4GemmCompF32(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + if (M <= 16) { + using Parallel = parallel::gemm::SchedulerKBlock; + using Launcher = + wrapper::gemm::LauncherKBlock; + static Launcher kernel; + auto reduceA = kernel.mProA.createStorage(M_, K_, B->mBlockSize); + if (B->IsAsym()) { + reduceA.assign(WorkSpace); + ORTThreading single(nullptr); + kernel.mProA.reduce({A, lda_, &reduceA}, M_, K_, B->mBlockSize, &single); + } + typename Launcher::Param args{gp, + {A, lda_, &reduceA}, + {B}, + {B->template SPtr(), B->SDtype(), B->CStep(), B->template ZPtr(), + reduceA.template RPtr(), reduceA.lda}, + {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } else { + using Parallel = parallel::gemm::SchedulerBase; + using Launcher = + wrapper::gemm::LauncherBase; + static Launcher kernel; + typename Launcher::Param 
args{gp, {A, lda_}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); + } +} + +template +static void NSSQ4GemmCompInt8(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc, int8_t* WorkSpace, + parallel::IThreading* th) { + using Parallel = parallel::gemm::SchedulerKBlockS; + using Launcher = + wrapper::gemm::LauncherIntKBlock; + auto M_ = static_cast(M); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto lda_ = static_cast(lda); + auto ldc_ = static_cast(ldc); + static Launcher kernel; + auto quanA = kernel.mProA.createStorage(M_, K_, B->mBlockSize, B->IsAsym()); + quanA.assign(WorkSpace); + if (M <= 16) { + ORTThreading single(nullptr); + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, &single); + } else { + kernel.mProA.quantize({A, lda_, &quanA}, M_, K_, th); + } + utils::GemmProblem gp(1, M_, N_, K_, B->mBlockSize); + typename Launcher::Param args{gp, {A, lda_, &quanA}, {B}, {C, ldc_, nullptr}}; + parallel::GemmRun(kernel, args, th); +} + +template +static size_t NSSQ4GemmCompF32WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + auto M_ = static_cast(M); + auto K_ = static_cast(K); + (void)(A); + (void)(N); + (void)(C); + (void)(lda); + (void)(ldc); + if (M <= 16) { + using ProA = prologue_a::gemm::ActivationKBlockBaseF32; + static ProA proA; + if (B->IsAsym()) { + auto reduceA = proA.createStorage(M_, K_, B->mBlockSize); + return reduceA.mSize; + } + return 0; + } else { + // using ProA = prologue_a::gemm::ActivationBase; + return 0; + } +} + +template +static size_t NSSQ4GemmCompInt8WorkspaceSize(size_t M, size_t N, size_t K, const float* A, size_t lda, + storage::gemm::StorageWeightKBlockNInteger* B, float* C, size_t ldc) { + (void)(N); + (void)(lda); + (void)(ldc); + (void)(A); + (void)(C); + using ProA = prologue_a::gemm::ActivationF32KBlockQuantize; + static ProA proA; + auto quanA = + proA.createStorage(static_cast(M), static_cast(K), static_cast(B->mBlockSize), B->IsAsym()); + return quanA.mSize; +} + +} // namespace bestla + +using namespace bestla; + +static bool NSSQ4GemmBatchDriver(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, int8_t* WorkSpace, + void* ThreadPool) { + GetCPUDevice(); + bestla::ORTThreading orth(ThreadPool); + bool processed = true; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = bestla::storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == bestla::tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } else if (NTile == bestla::tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + bestla::NSSQ4GemmCompF32(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, + 
DataParams[i].ldc, WorkSpace, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == bestla::tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && + BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, + &orth); + } else if (NTile == bestla::tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && + BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + bestla::NSSQ4GemmCompInt8(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc, WorkSpace, &orth); + } + } + } + } else { + processed = false; + break; + } + } + return processed; +} + +static size_t NSSQ4GemmBatchWorkspaceSize(size_t M, size_t N, size_t K, size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + GetCPUDevice(); + size_t size = 0; + for (size_t i = 0; i < BatchN; i++) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(DataParams[i].B); + auto uptr = std::unique_ptr(ptr); + if (ptr) { + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto kptr = reinterpret_cast(ptr); + auto NTile = + gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + auto BlkSize = kptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + size = std::max(NSSQ4GemmCompF32WorkspaceSize(M, N, K, DataParams[i].A, DataParams[i].lda, kptr, + DataParams[i].C, DataParams[i].ldc), + size); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + size = std::max(NSSQ4GemmCompInt8WorkspaceSize( + M, N, K, DataParams[i].A, DataParams[i].lda, kptr, DataParams[i].C, DataParams[i].ldc), + size); + } + } + } + } + } + return size; +} + +template +static size_t NSQ4BuSize(size_t block_size, size_t N, size_t K, bool isAsym) { + static T proB; + auto stor = proB.createStorage(static_cast(N), static_cast(K), static_cast(block_size), + BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, BTLA_DTYPE::BF16, isAsym); + // TODO(Yu) support more scale dtype + return stor.mSize; +} + +static bool NSQ4GemmUnPackB(float* FpData, const void* 
PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + auto ptr = storage::gemm::PackedWeightParser::deserialBuffer(PackedBuf); + auto uptr = std::unique_ptr(ptr); + ORTThreading orth(ThreadPool); + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto ldb_ = static_cast(ldb); + GetCPUDevice(); + if (ptr) { + auto NTile = gemm::CoreAttr::get_mask_val(ptr->mCoreId, gemm::CoreAttr::NTILE_MASK, gemm::CoreAttr::NTILE_SHIFT); + auto PackRow = gemm::CoreAttr::get_packrow(ptr->mCoreId); + auto CType = gemm::CoreAttr::get_comp(ptr->mCoreId); + auto btype = static_cast(gemm::CompTypeHelper::get_B(CType)); + if (ptr->mPrologueID == BTLA_PROLOGUEB_IDS::WeightKBlockNInteger) { + auto wptr = reinterpret_cast(ptr); + auto BlkSize = wptr->mBlockSize; + if (btype == gemm::CompType::tFP32 && PackRow == 1) { + if (NTile == tAVX512F::NTILE && _cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX2::NTILE && _cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + if (btype == gemm::CompType::tS8 && PackRow == 4) { + if (NTile == tAMX_INT8_SS_KBlock::NTILE && _cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX512_VNNI_KBlock::NTILE && _cd->AVX512_VNNI() && + BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } else if (NTile == tAVX_VNNI_KBlock::NTILE && _cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + static tWeiNInt proB; + proB.unpackWeight(N_, K_, wptr, FpData, ldb_, &orth); + } + } + } + return true; + } + return false; +} + +template +static void NSQ4GemmPackBImpl(void* PackedBuf, size_t BlkSize, const uint8_t* QData, const float* Scale, + const uint8_t* Zp, size_t N, size_t K, bool IsAsym, bool lastCall, size_t ldb, + void* ThreadPool) { + static T proB; + auto N_ = static_cast(N); + auto K_ = static_cast(K); + auto stor = proB.createStorage(N_, K_, static_cast(BlkSize), BTLA_DTYPE::S4_CLIP, BTLA_DTYPE::F32, + BTLA_DTYPE::BF16, IsAsym); + stor.assign(reinterpret_cast(PackedBuf)); + ORTThreading orth(ThreadPool); + proB.packNbitsWeightQ4(N_, K_, IsAsym, QData, static_cast(ldb), Scale, Zp, &stor, &orth); + if (lastCall) { + proB.reduceWeight(&stor, &orth); + } +} + +static size_t NSQ4GemmPackBSize(size_t N, size_t K, size_t BlkSize, bool isAsym, NS_SQNBIT_COMPUTE_TYPE CompType) { + GetCPUDevice(); + if (K % BlkSize != 0) { + return 0; + } + // from low precision to high precision + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. 
+ if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + return NSQ4BuSize>(BlkSize, N, K, isAsym); + } + [[fallthrough]]; + default: + return 0; + } +} + +static bool NSQ4GemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, + size_t K, size_t ldb, size_t BlkSize, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + GetCPUDevice(); + // explicit statement fall through. + switch (CompType) { + case NSCompInt8: + if (!isAsym) { // asym int8 is not optimized, so fall through to others. + if (_cd->AMX_INT8() && BlkSize % tAMX_INT8_SS_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX512_VNNI() && BlkSize % tAVX512_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>( + PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX_VNNI() && BlkSize % tAVX_VNNI_KBlock::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, + K, isAsym, lastCall, ldb, ThreadPool); + return true; + } + } + [[fallthrough]]; + case NSCompBf16: + case NSCompFp16: + case NSCompFp32: + case NSCompUndef: + if (_cd->AVX512F() && BlkSize % tAVX512F::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, + lastCall, ldb, ThreadPool); + return true; + } + if (_cd->AVX2() && BlkSize % tAVX2::KTILE == 0) { + NSQ4GemmPackBImpl>(PackedBuf, BlkSize, QData, Scale, Zp, N, K, isAsym, lastCall, + ldb, ThreadPool); + return true; + } + [[fallthrough]]; + default: + return false; + } +} + +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t BlkSize, int nbits, bool isAsym, + NS_SQNBIT_COMPUTE_TYPE CompType) { + if (nbits == 4) { + auto jsize = NSQ4GemmPackBSize(N, K, BlkSize, isAsym, CompType); + if (jsize) { + return jsize; + } + } + return 0; +} + +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t BlkSize, int nbits, bool isAsym, bool lastCall, + NS_SQNBIT_COMPUTE_TYPE CompType, void* ThreadPool) { + if (nbits == 4) { + if (NSQ4GemmPackB(PackedBuf, QData, Scale, Zp, N, K, ldb, BlkSize, isAsym, lastCall, CompType, ThreadPool)) { + return; + } + } +} + +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSQ4GemmUnPackB(FpData, PackedBuf, N, K, ldb, ThreadPool)) { + return; + } +} + +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + return NSSQ4GemmBatchWorkspaceSize(M, N, K, BatchN, DataParams); +} + +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, 
const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool) { + // only nbits=4 can be packed, so not necessary to check the nbits in DataParams + if (NSSQ4GemmBatchDriver(M, N, K, BatchN, DataParams, reinterpret_cast(WorkSpace), ThreadPool)) { + // PackedWeight is created by bestla + return; + } +} diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h new file mode 100644 index 0000000000000..ebcb3027a209f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_gemm.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + neural_speed_gemm.h + +Abstract: + + Prepack-weight GEMM APIs of neural_speed. +--*/ + +#pragma once + +#include +#include + +/** + * @brief Define compute types of block quantization + */ +enum NS_SQNBIT_COMPUTE_TYPE { + NSCompUndef = 0, /*!< undef */ + NSCompFp32 = 1, /*!< input fp32, accumulator fp32 */ + NSCompFp16 = 2, /*!< input fp16, accumulator fp16 */ + NSCompBf16 = 3, /*!< input bf16, accumulator fp32 */ + NSCompInt8 = 4 /*!< input int8, accumulator int32 */ +}; + +/** + * @brief Data parameters for NBits GEMM routine + * C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * All except C are [in] parameters + */ +struct NS_SQNBITS_GEMM_DATA_PACKED_PARAMS { + const float* A = nullptr; /**< address of A (float32 matrix)*/ + const void* B = nullptr; /**< address of B (packed nbits blob)*/ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldc = 0; /**< leading dimension of C*/ +}; + +/** + * @brief Compute the byte size of the parameter combination + * + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @return size of the packing buffer, 0 if the operation is not yet supported. + */ +size_t NSNBitsGemmPackBSize(size_t N, size_t K, size_t block_size, int nbits, bool is_asym, + NS_SQNBIT_COMPUTE_TYPE comp_type); + +/** + * @brief Prepack tensor data from n-bit quantized data, scale and zero point buffers. + * + * @param PackedBuf packed data buffer + * @param QData quantized data buffer + * @param Scale scale pointer + * @param Zp zero point pointer + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param block_size size of the block to quantize, elements from the same block share the same + * scale and zero point + * @param nbits number of bits used for weight quantization (default 4) + * @param is_asym flag for asymmetric quantization + * @param comp_type specify input data type and accumulator data type + * @param last_call flag to activate the epilogue process of packB. OpKernel::PrePack will query input tensor + * one by one: QData, Scale, Zp (if is_asym is true). But kernel prefers to pack all tensors into one blob data where + * they can share the common attributes like: block_size. 
Meanwhile, kernel has some pre-computations to speed up + * inference which require that all blob data are ready. So, you need to set this flag to true when passing Scale + * (is_asym is false) and Zp(is_asym is true). + * @param thread_pool + */ +void NSNBitsGemmPackB(void* PackedBuf, const uint8_t* QData, const float* Scale, const uint8_t* Zp, size_t N, size_t K, + size_t ldb, size_t block_size, int nbits, bool is_asym, bool last_call, + NS_SQNBIT_COMPUTE_TYPE comp_type, void* thread_pool); + +/** + * @brief Unpack and dequantize to fp32 + * + * @param FpData unpacked float32 data + * @param PackedBuf quantized and packed data + * @param N the number of columns of matrix B. + * @param K the number of rows of matrix B. + * @param ldb leading dimension of B + * @param thread_pool + */ +void NSNBitsGemmUnPackB(float* FpData, const void* PackedBuf, size_t N, size_t K, size_t ldb, void* thread_pool); + +/** + * @brief Get the workspace size required by computation. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @return Workspace size in bytes + */ +size_t NSSQNBitsGemmBatchWorkspaceSize(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams); + +/** + * @brief Batched GEMM: C = A * B + * A, C must be a float32 matrix + * B must be a packed nbits blob + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] WorkSpace temporary buffer + * @param[in] ThreadPool + * @return + */ +void NSSQNBitsGemmBatchPackedB(const size_t M, const size_t N, const size_t K, const size_t BatchN, + const NS_SQNBITS_GEMM_DATA_PACKED_PARAMS* DataParams, void* WorkSpace, + void* ThreadPool = nullptr); diff --git a/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h new file mode 100644 index 0000000000000..d3902f9bd68c7 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/neural_speed_wrapper.h @@ -0,0 +1,39 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+// +//----------------------------------------------------------------------------- +#pragma once +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wmissing-field-initializers" +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-value" +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#pragma GCC diagnostic ignored "-Wunused-function" +#pragma GCC diagnostic ignored "-Wuninitialized" +#pragma GCC diagnostic ignored "-Wclass-memaccess" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-parameter" + +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4457) +#pragma warning(disable : 4189) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4267) +#pragma warning(disable : 4702) +#endif + +#include "bestla/bestla_prologue_a.h" +#include "bestla/bestla_wrapper.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index d22da2a3da87f..2ad20eafc2ef1 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -149,10 +149,17 @@ TEST(MatMulNBits, Float32) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { +#ifdef ORT_NEURAL_SPEED + for (auto accuracy_level : {0, 1, 4}) { + RunTest(M, N, K, block_size, accuracy_level, false, false); + RunTest(M, N, K, block_size, accuracy_level, true, false); + } +#else for (auto accuracy_level : {0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); } +#endif } } } @@ -185,6 +192,174 @@ TEST(MatMulNBits, Float16Large) { #endif +void RunSharedPrepackedWeightsTest(int64_t M, int64_t N, int64_t K, int block_size, bool is_asym, + int64_t acc_lvl) { + // (M x K) X (K x N) + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("accuracy_level", acc_lvl); + test.AddAttribute("block_size", int64_t(block_size)); + test.AddAttribute("bits", QBits); + test.AddAttribute("N", N); + test.AddAttribute("K", K); + + std::vector input0_vals(M * K); + float fv = -135.f; + for (auto& f : input0_vals) { + f = fv / 127; + fv++; + if (fv > 135.f) { + fv = -135.f; + } + } + + size_t kblks = K / block_size; + std::vector input1_vals(N * K / 2); + for (size_t i = 0; i < input1_vals.size(); i++) { + input1_vals[i] = uint8_t(i); + } + std::vector input2_vals(N * kblks, 0.002f); + for (size_t i = 0; i < N * kblks; i++) { + input2_vals[i] += (i % 100) * 0.00003f; + } + std::vector input3_vals(N * kblks / 2, static_cast(0x88)); + + std::vector input1_f_vals(N * K); + if (is_asym) { + for (size_t i = 0; i < N * kblks; i += 2) { + input3_vals[i / 2] = static_cast(i + 1); + } + for (int64_t i = 0; i < K; i += 2) { + for (int64_t j = 0; j < N; j++) { + auto srcv = input1_vals[j * K / 2 + i / 2]; + auto koff = i % (block_size * 2); + auto zpv = input3_vals[j * kblks / 2 + i / block_size / 2]; + auto zp0 = koff < block_size ? 
(zpv & 0xf) - 8 : ((zpv & 0xf0) >> 4) - 8; + auto src0 = (srcv & 0xf) - 8; + auto src1 = ((srcv & 0xf0) >> 4) - 8; + auto scale0 = input2_vals[j * kblks + i / block_size]; + auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; + input1_f_vals[i * N + j] = (static_cast(src0) - zp0) * scale0; + input1_f_vals[(i + 1) * N + j] = (static_cast(src1) - zp0) * scale1; + } + } + } else { + for (int64_t i = 0; i < K; i += 2) { + for (int64_t j = 0; j < N; j++) { + auto srcv = input1_vals[j * K / 2 + i / 2]; + auto src0 = (srcv & 0xf) - 8; + auto src1 = ((srcv & 0xf0) >> 4) - 8; + auto scale0 = input2_vals[j * kblks + i / block_size]; + auto scale1 = input2_vals[j * kblks + (i + 1) / block_size]; + input1_f_vals[i * N + j] = static_cast(src0) * scale0; + input1_f_vals[(i + 1) * N + j] = static_cast(src1) * scale1; + } + } + } + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[k * N + n]; + } + expected_vals[m * N + n] = sum; + } + } + + test.AddInput("A", {M, K}, input0_vals, false); + + test.AddInput("B", {N, static_cast(kblks), static_cast(block_size / 2)}, input1_vals, + true); + test.AddInput("scales", {N, static_cast(kblks)}, input2_vals, true); + if (is_asym) { + test.AddInput("zero_points", {N, static_cast(kblks / 2)}, input3_vals, true); + } + test.AddOutput("Y", {M, N}, expected_vals, false); + if (acc_lvl == 4) { + test.SetOutputAbsErr("Y", 0.1f); + } + + OrtValue b, scale, zp; + Tensor::InitOrtValue(DataTypeImpl::GetType(), + TensorShape({N, static_cast(kblks), static_cast(block_size / 2)}), + input1_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks)}), + input2_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), scale); + if (is_asym) { + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({N, static_cast(kblks / 2)}), + input3_vals.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), zp); + } + SessionOptions so; + // Set up B as a shared initializer to be shared between sessions + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_EQ(so.AddInitializer("scales", &scale), Status::OK()); + if (is_asym) { + ASSERT_EQ(so.AddInitializer("zero_points", &zp), Status::OK()); + } + + // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP + // and we want to ensure that it is available in this build + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + // Assert that no pre-packed weights have been shared thus far + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + auto number_of_elements_in_shared_prepacked_buffers_container = test.GetNumPrePackedWeightsShared(); + // Assert that the 
number of elements in the shared container + // is the same as the number of weights that have been pre-packed + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); + + // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements + // that have been pre-packed will be zero in which case we do not continue with the testing + // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. + if (number_of_pre_packed_weights_counter_session_1 == 0) return; + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + auto ep_vec = cpu_ep(); + test.Run(so, OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &ep_vec, {}, + &number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + // Assert that the same number of weights were pre-packed in both sessions + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Assert that the number of pre-packed weights that were shared equals + // the number of pre-packed weights in the second session + ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, + static_cast(number_of_shared_pre_packed_weights_counter)); + } +} + +#ifdef ORT_NEURAL_SPEED +TEST(MatMulNBits, SharedPrepackedWeights) { + RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, true, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 32, false, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 1); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 128, false, 4); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 1024, false, 4); + RunSharedPrepackedWeightsTest(2, 4096, 4096, 4096, false, 4); +} +#endif } // namespace test } // namespace onnxruntime From 90cf03767dff037fb5cca3b45f51a2f6b021ac23 Mon Sep 17 00:00:00 2001 From: Shubham Bhokare <32080845+shubhambhokare1@users.noreply.github.com> Date: Fri, 9 Feb 2024 12:26:39 -0800 Subject: [PATCH 061/207] Support ONNX export of OpenAi Whisper model (#17316) Build from source and run the command below Example, converting whisper-base ` python -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-base --model_impl openai -e -o -w --chain_model --output ./demo` --- .../transformers/fusion_bart_attention.py | 239 +++++++++++++++++- .../models/whisper/convert_to_onnx.py | 13 +- .../models/whisper/whisper_decoder.py | 18 +- .../models/whisper/whisper_encoder.py | 17 +- .../whisper/whisper_encoder_decoder_init.py | 25 +- .../models/whisper/whisper_helper.py | 57 ++++- .../models/whisper/whisper_openai_helper.py | 76 ++++++ .../python/tools/transformers/onnx_model.py | 48 ++++ .../tools/transformers/onnx_model_bart.py | 2 +- .../python/transformers/test_generation.py | 8 +- 10 files changed, 474 insertions(+), 29 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py index 71801401e9d06..ebecc1db24792 100644 --- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py @@ -74,13 +74,74 @@ def check_runtime_shape_path( return True + def check_runtime_shape_path_openai( + self, + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ): + reshape_qkv_2_path = self.model.match_parent_path( + 
reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0] + ) + if reshape_qkv_2_path is None: + return False + else: + if reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]: + return False + + matmul_qk_path_1 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0] + ) + matmul_qk_path_2 = self.model.match_parent_path( + matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0] + ) + if matmul_qk_path_1 is None or matmul_qk_path_2 is None: + return False + + mul_1 = matmul_qk_path_1[0] + mul_2 = matmul_qk_path_2[0] + if mul_1.input[1] != mul_2.input[1]: + return False + if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]: + return False + + # For decoder attentions only + if add_qk is not None: + add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1]) + if add_qk_path is None: + return False + slice_q_path_1 = self.model.match_parent_path( + add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0] + ) + slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if slice_q_path_1 is None and slice_q_path_2 is None: + return False + _, unsqueeze_1, _, _ = slice_q_path_1 + unsqueeze_2, _, _ = slice_q_path_2 + if unsqueeze_1.input[0] != unsqueeze_2.input[0]: + return False + if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]: + return False + + return True + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + # Track if fusion is occurring for OpenAI implementation of Whisper + model_impl_openai = False + # SkipLayerNormalization has two inputs, and one of them is the root input for attention. 
qkv_nodes = self.model.match_parent_path( normalize_node, ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 1, 0, 0, 0, 0], ) + qkv_nodes_openai = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 0, 0, 0], + ) if qkv_nodes is not None: ( add_out, @@ -90,6 +151,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): reshape_qkv_1, matmul_qkv, ) = qkv_nodes + elif qkv_nodes_openai is not None: + qkv_nodes = qkv_nodes_openai + ( + add_out, + matmul_out, + reshape_qkv_2, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes + # Set model implementation to openai + model_impl_openai = True else: return @@ -137,6 +209,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, None], ) + v_nodes_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, None], + ) v_nodes_with_past_self_attn = self.model.match_parent_path( # Decoder attention with past value concatenated before MatMul matmul_qkv, @@ -149,12 +226,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape"], [1], ) + v_nodes_with_past_cross_attn_openai = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0], + ) past_v, present_v = "", "" reshape_v_2, add_v = None, None if v_nodes is not None: (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes # For initial pass through encoder-decoder_with_past to get starting past values (beam search) present_v = transpose_v.output[0] + elif v_nodes_openai is not None: + v_nodes = v_nodes_openai + (transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes + # For initial pass through encoder-decoder_with_past to get starting past values (beam search) + + # Find the child path to access the correct present_v values + # Openai impl provides present/past v values in 3D format + # whereas ort MultiHeadAttention expects v values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Add -> Reshape -> Transpose -> Present_V + reshape_path = self.model.match_child_path( + add_v, + ["Reshape", "Transpose"], + exclude=[reshape_v_1], + ) + # For decoder attention types + # add_v_node Reshape <- Transpose <-Past_V + # \ / + # \ / + # -> Concat <- + # | + # |--> Reshape -> Transpose -> Present_V + concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_add_v) = reshape_path + if transpose_add_v.output[0] in graph_output_names: + present_v = transpose_add_v.output[0] + if concat_path is not None: + (concat_v, _, transpose_concat_v) = concat_path + if transpose_concat_v.output[0] in graph_output_names: + present_v = transpose_concat_v.output[0] + concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_v_in = concat_nodes + past_v = transpose_concat_v_in.input[0] elif v_nodes_with_past_self_attn is not None: (reshape_v_2, concat_v, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes_with_past_self_attn v_nodes = v_nodes_with_past_self_attn @@ -171,6 +288,18 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) ) present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" + 
elif ( + v_nodes_with_past_cross_attn_openai is not None + and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + v_nodes = v_nodes_with_past_cross_attn_openai + past_v = v_nodes[-1].input[0] + present_v = v_nodes[-1].output[0] + if present_v not in graph_output_names: + identity_node_v = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_v]) + ) + present_v = identity_node_v[0].output[0] if len(identity_node_v) == 1 else "" else: logger.debug("fuse_attention: failed to match v path") return @@ -181,12 +310,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qk_nodes_2 = self.model.match_parent_path( matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0] ) + qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) + add_qk = None if qk_nodes_1 is not None: _, matmul_qk = qk_nodes_1 qk_nodes = qk_nodes_1 elif qk_nodes_2 is not None: _, _, add_qk, _, matmul_qk = qk_nodes_2 qk_nodes = qk_nodes_2 + elif qk_nodes_2_openai is not None: + _, add_qk, matmul_qk = qk_nodes_2_openai + qk_nodes = qk_nodes_2_openai else: return @@ -195,8 +329,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], [0, 0, 0, 0, 0, 1], ) + q_nodes_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, 1], + ) + reshape_q_2 = None if q_nodes is not None: reshape_q_2, transpose_q, reshape_q_1, mul_q, add_q, matmul_q = q_nodes + elif q_nodes_openai is not None: + q_nodes = q_nodes_openai + mul_q, transpose_q, reshape_q_1, add_q, matmul_q = q_nodes else: return @@ -205,6 +348,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0, 0, 1], ) + k_nodes_with_bias_openai = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0], + ) k_nodes_no_bias = self.model.match_parent_path( matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], @@ -222,11 +370,52 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): ["Transpose", "Reshape"], [1, 0], ) + k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path( + # Decoder attention with past key directly used in MatMul + matmul_qk, + ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"], + [1, 0, 0, 0, 0], + ) past_k, present_k = "", "" reshape_k_2, reshape_k_1, matmul_k = None, None, None if k_nodes_with_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, add_k, matmul_k = k_nodes_with_bias k_nodes = k_nodes_with_bias + elif k_nodes_with_bias_openai is not None: + mul_k, transpose_k_1, reshape_k_1, matmul_k = k_nodes_with_bias_openai + k_nodes = k_nodes_with_bias_openai + present_k = matmul_k.output[0] + + # Find the child path to access the correct present_k values + # Openai impl provides present/past k values in 3D format + # whereas ort MultiHeadAttention expects k values in 4D, hence the + # additional Reshape and Transpose nodes are added + # For encoder attention types + # Matmul -> Reshape -> Transpose -> Present_K + reshape_path = self.model.match_child_path( + matmul_k, + ["Reshape", "Transpose"], + exclude=[reshape_k_1], + ) + # For decoder attention types + # matmul_k_node Reshape <- Transpose <- Past_K + # \ / + # \ / + # -> Concat <- + # | + # |--> 
Reshape -> Transpose -> Present_K + concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"]) + if reshape_path is not None: + (_, transpose_matmul_k) = reshape_path + if transpose_matmul_k.output[0] in graph_output_names: + present_k = transpose_matmul_k.output[0] + if concat_path is not None: + (concat_k, _, transpose_concat_k) = concat_path + if transpose_concat_k.output[0] in graph_output_names: + present_k = transpose_concat_k.output[0] + concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0]) + _, transpose_concat_k_in = concat_nodes + past_k = transpose_concat_k_in.input[0] elif k_nodes_no_bias is not None: _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias k_nodes = k_nodes_no_bias @@ -249,12 +438,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) ) present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" + elif ( + k_nodes_no_bias_with_past_cross_attn_openai is not None + and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names + ): + k_nodes = k_nodes_no_bias_with_past_cross_attn_openai + past_k = k_nodes[-1].input[0] + present_k = k_nodes[-1].output[0] + if present_k not in graph_output_names: + identity_node_k = list( + filter(lambda node: node.op_type == "Identity", self.model.input_name_to_nodes()[past_k]) + ) + present_k = identity_node_k[0].output[0] if len(identity_node_k) == 1 else "" else: return past_k = past_k if past_k in graph_input_names else "" present_k = present_k if present_k in graph_output_names else "" - if k_nodes in (k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): + if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn): # Create empty Add node for attention graph bias_dim = self.model.get_initializer(add_v.input[0]).dims[0] empty_bias_name = "empty_bias" @@ -270,13 +471,29 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_name = self.model.create_node_name("Add") add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name) - if not past_k and not self.check_runtime_shape_path( - reshape_qkv_2, - reshape_qkv_1, - reshape_q_2, - reshape_k_2, - reshape_v_2, - root_input, + if ( + model_impl_openai + and not past_k + and not self.check_runtime_shape_path_openai( + reshape_qkv_2, + matmul_qkv, + add_qk, + matmul_qk, + add_q, + ) + ): + return + elif ( + not model_impl_openai + and not past_k + and not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ) ): return @@ -301,8 +518,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1 # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1 encoder_attention = one_root_input and qk_nodes == qk_nodes_1 - decoder_attention = one_root_input and qk_nodes == qk_nodes_2 - decoder_attention_with_past = encoder_attention and past_k and past_v + decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai) + decoder_attention_with_past = ( + (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v + ) decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1 decoder_cross_attention_with_past 
= three_root_inputs and qk_nodes == qk_nodes_1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e15a12c07bed7..bb697fe1e1506 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -38,6 +38,15 @@ def parse_arguments(argv=None): help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), ) + parser.add_argument( + "--model_impl", + required=False, + default="hf", + choices=["hf", "openai"], + type=str, + help="Select implementation for export of encoder and decoder subgraphs", + ) + parser.add_argument( "--cache_dir", required=False, @@ -300,6 +309,7 @@ def parse_arguments(argv=None): def export_onnx_models( model_name_or_path, + model_impl, cache_dir, output_dir, use_gpu, @@ -321,7 +331,7 @@ def export_onnx_models( device = torch.device("cuda:0" if use_gpu else "cpu") models = WhisperHelper.load_model( - model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path + model_name_or_path, model_impl, cache_dir, device, merge_encoder_and_decoder_init, state_dict_path ) config = models["decoder"].config @@ -431,6 +441,7 @@ def main(argv=None): output_paths = export_onnx_models( args.model_name_or_path, + args.model_impl, cache_dir, output_dir, args.use_gpu, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index eca5ce3de15d3..0d69960a095ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -18,6 +18,7 @@ from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export from transformers import WhisperConfig, file_utils +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -67,10 +68,13 @@ def forward( class WhisperDecoder(torch.nn.Module): """A Whisper decoder with past key values""" - def __init__(self, decoder, config): + def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None): super().__init__() self.decoder = decoder self.config = config + self.model_impl = model_impl + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) def forward(self, decoder_input_ids, *past): encoder_outputs = file_utils.ModelOutput() @@ -78,6 +82,14 @@ def forward(self, decoder_input_ids, *past): encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states encoder_outputs["hidden_states"] = dummy_encoder_hidden_states encoder_outputs["attentions"] = None + + if self.model_impl == "openai": + dummy_encoder_hidden_states.unsqueeze(0) + dec_out, present = self.whisper_decoder_openai_init( + decoder_input_ids, dummy_encoder_hidden_states, past=past + ) + return dec_out, present + if len(past) == 0: past_key_values = None else: @@ -158,7 +170,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - encode_sequence_length, + past_decode_sequence_length, head_size, ] @@ -213,7 +225,7 @@ def export_onnx( decoder.config, batch_size=2, encode_sequence_length=3000, - past_decode_sequence_length=5 if isinstance(decoder, WhisperDecoder) else 0, + past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0, device=device, 
use_int32_inputs=use_int32_inputs, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py index 826d6e42c0775..93281848a5c9c 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py @@ -25,12 +25,15 @@ class WhisperEncoder(torch.nn.Module): """Whisper encoder outputs only the last hidden state""" - def __init__(self, encoder, config: WhisperConfig): + def __init__(self, encoder, config: WhisperConfig, model_impl: str = "hf"): super().__init__() self.encoder = encoder self.config = config + self.model_impl = model_impl def forward(self, input_features): + if self.model_impl == "openai": + return self.encoder(input_features) return self.encoder.model.encoder(input_features)[0] @@ -40,7 +43,11 @@ def __init__(self, input_features): @staticmethod def create_dummy( - batch_size: int, sequence_length: int, feature_size: int, device: torch.device, use_int32_inputs: bool + batch_size: int, + sequence_length: int, + feature_size: int, + device: torch.device, + use_int32_inputs: bool = False, ): """Create dummy inputs for Whisper encoder. @@ -61,9 +68,9 @@ def create_dummy( return WhisperEncoderInputs(input_features) def to_list(self) -> List: - if self.input_features is None: + if self.input_ids is None: return [] - return [self.input_features] + return [self.input_ids] class WhisperEncoderHelper: @@ -74,6 +81,7 @@ def export_onnx( onnx_model_path: str, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export encoder to ONNX @@ -90,6 +98,7 @@ def export_onnx( sequence_length=3000, feature_size=config.num_mel_bins, device=device, + use_int32_inputs=use_int32_inputs, ) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index a145178dbf37e..351173f525727 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -4,6 +4,7 @@ # license information. 
# -------------------------------------------------------------------------- +import copy import logging import os import tempfile @@ -19,6 +20,7 @@ from transformers import WhisperConfig from whisper_decoder import WhisperDecoderInit from whisper_encoder import WhisperEncoder, WhisperEncoderInputs +from whisper_openai_helper import WhisperDecoderInitOpenai from onnxruntime import InferenceSession @@ -34,11 +36,16 @@ def __init__( decoder: torch.nn.Module, config: WhisperConfig, decoder_start_token_id: Optional[int] = None, + model_impl: str = "hf", + model: torch.nn.Module = None, ): super().__init__() self.config = config - self.whisper_encoder = WhisperEncoder(encoder, config) + self.whisper_encoder = WhisperEncoder(encoder, config, model_impl=model_impl) self.whisper_decoder_init = WhisperDecoderInit(decoder, config, decoder_start_token_id) + if model is not None: + self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder) + self.model_impl = model_impl def forward( self, @@ -47,9 +54,14 @@ def forward( ): encoder_hidden_states: torch.FloatTensor = self.whisper_encoder(encoder_input_ids) # Decoder out: (logits, past_key_values, encoder_hidden_state) - decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) - present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) - present = present_self + present_cross + if self.model_impl == "openai": + encoder_hidden_states.unsqueeze(0) + decinit_out, present = self.whisper_decoder_openai_init(decoder_input_ids, encoder_hidden_states) + return decinit_out, encoder_hidden_states, present + else: + decinit_out = self.whisper_decoder_init(decoder_input_ids, encoder_hidden_states) + present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(decinit_out[1]) + present = present_self + present_cross return decinit_out[0], encoder_hidden_states, present @@ -72,7 +84,6 @@ def create_dummy( sequence_length=3000, feature_size=config.num_mel_bins, device=device, - use_int32_inputs=use_int32_inputs, ) decoder_input_ids = None if use_decoder_input_ids: @@ -120,7 +131,9 @@ def export_onnx( ) input_list = inputs.to_list() - out = model(inputs.encoder_input_ids, inputs.decoder_input_ids) + # TODO : Investigate whether copy of model if needed + cloned_model = copy.deepcopy(model).to(device) + out = cloned_model(inputs.encoder_input_ids, inputs.decoder_input_ids) present = out[2] present_names = PastKeyValuesHelper.get_input_names(present, encoder=True) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index f65060f4f93d3..e2dc79ca247ce 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -72,9 +72,49 @@ def get_onnx_path( directory = os.path.join(output_dir, model_name) if new_folder else output_dir return os.path.join(directory, model_name + ".onnx") + @staticmethod + def load_model_openai( + model_name_or_path: str, + cache_dir: str, + device: torch.device, + ) -> torch.nn.Module: + """Load model given a pretrained name or path, then build models for ONNX conversion. + + Args: + model_name_or_path (str): pretrained model name or path + cache_dir (str): cache directory + device (torch.device): device to run the model + merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True. 
+ Returns: + Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion. + """ + from whisper import _ALIGNMENT_HEADS, _MODELS, _download + from whisper.model import ModelDimensions, Whisper + + in_memory = False + + model_name = model_name_or_path.split("/")[-1][8:] + checkpoint_file, alignment_heads = None, None + if model_name in _MODELS: + checkpoint_file = _download(_MODELS[model_name], cache_dir, in_memory) + alignment_heads = _ALIGNMENT_HEADS[model_name] + + with open(checkpoint_file, "rb") as fp: + checkpoint = torch.load(fp, map_location=device) + del checkpoint_file + + dims = ModelDimensions(**checkpoint["dims"]) + model = Whisper(dims) + model.load_state_dict(checkpoint["model_state_dict"]) + + if alignment_heads is not None: + model.set_alignment_heads(alignment_heads) + return model.to(device) + @staticmethod def load_model( model_name_or_path: str, + model_impl: str, cache_dir: str, device: torch.device, merge_encoder_and_decoder_init: bool = True, @@ -94,18 +134,29 @@ def load_model( if version.parse(transformers_version) >= version.parse("4.36.0"): extra_kwargs["attn_implementation"] = "eager" model = WhisperForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir, **extra_kwargs) + + if model_impl == "openai": + openai_model = WhisperHelper.load_model_openai(model_name_or_path, cache_dir, device) + model_encoder, model_decoder = openai_model.encoder, openai_model.decoder + passed_model = openai_model + else: + model_encoder, model_decoder = model, model + passed_model = None + if state_dict_path: model.load_state_dict(torch.load(state_dict_path), strict=False) - decoder = WhisperDecoder(model, model.config) + decoder = WhisperDecoder(model_decoder, model.config, model_impl=model_impl, model=passed_model) decoder.eval().to(device) if merge_encoder_and_decoder_init: encoder_decoder_init = WhisperEncoderDecoderInit( - model, - model, + model_encoder, + model_decoder, model.config, decoder_start_token_id=None, + model_impl=model_impl, + model=passed_model, ) return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder} else: diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py new file mode 100644 index 0000000000000..941f61cf7cc29 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_openai_helper.py @@ -0,0 +1,76 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class WhisperDecoderInitOpenai(torch.nn.Module): + """WhisperDecoderInit for Openai.""" + + def __init__( + self, + model: torch.nn.Module, + decoder: torch.nn.Module, + ): + super().__init__() + self.whisper_model = model + self.whisper_decoder = decoder + self.kv_cache = {} + + @torch.no_grad() + def forward( + self, + tokens, + audio_features, + past=None, + ): + # Create a kv_cache for past_values + past_kv_cache = dict() + if past is not None: + # Convert past values from 4D to 3D + past = [torch.transpose(val, 1, 2) for val in past] + past = [val.reshape(val.shape[:2] + (-1,)) for val in past] + half_idx = len(past) // 2 + for idx, block in enumerate(self.whisper_decoder.blocks): + past_kv_cache[block.attn.key] = past[2 * idx] + past_kv_cache[block.attn.value] = past[2 * idx + 1] + past_kv_cache[block.cross_attn.key] = past[2 * idx + half_idx] + past_kv_cache[block.cross_attn.value] = past[2 * idx + half_idx + 1] + + if not self.kv_cache: + self.kv_cache, _ = self.whisper_model.install_kv_cache_hooks() + + logits = self.whisper_decoder(tokens, audio_features, kv_cache=past_kv_cache) + + # Add concat node for past values + if past is not None: + for block in self.whisper_decoder.blocks: + self.kv_cache[block.attn.key] = torch.cat( + [past_kv_cache[block.attn.key], self.kv_cache[block.attn.key]], dim=1 + ).detach() + self.kv_cache[block.attn.value] = torch.cat( + [past_kv_cache[block.attn.value], self.kv_cache[block.attn.value]], dim=1 + ).detach() + + present_self, present_cross = [], [] + # Group self and cross values + for block in self.whisper_decoder.blocks: + present_self.append(self.kv_cache[block.attn.key]) + present_self.append(self.kv_cache[block.attn.value]) + if past is None: + present_cross.append(self.kv_cache[block.cross_attn.key]) + present_cross.append(self.kv_cache[block.cross_attn.value]) + + present_self = present_self + present_cross + # Add reshape and transpose ops to convert from 3D to 4D + present_self = [ + present_val.reshape(present_val.shape[:2] + (-1, 64)).transpose(1, 2) for present_val in present_self + ] + return logits, present_self diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index e9029f4f620f8..a8fc6e661933e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -430,6 +430,54 @@ def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, r return None + def match_child_path( + self, + node, + child_op_types, + child_output_index=None, + return_indice=None, + exclude=[], # noqa: B006 + ): + """ + Find a sequence of input edges based on constraints on parent op_type and index. + When input_index is None, we will find the first parent node based on constraints, + and return_indice will be appended the corresponding input index. + + Args: + node (str): current node name. + child_op_types (str): constraint of child node op_type of each input edge. + child_output_index (list): constraint of input index of each input edge. None means no constraint. + return_indice (list): a list to append the input index + When there is no constraint on input index of an edge. + + Returns: + children: a list of matched children node. 
+ """ + if child_output_index is not None: + assert len(child_output_index) == len(child_op_types) + + current_node = node + matched_children = [] + for i, op_type in enumerate(child_op_types): + matched_child = None + node_children = self.get_children(current_node) + for child_i, child in enumerate(node_children): + if child.op_type == op_type and child not in exclude: + if child_output_index is not None and child_output_index[i] != child_i: + logger.debug( + f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}", + stack_info=True, + ) + return None + matched_child = child + if matched_child is None: + logger.debug(f"Failed to match child op_type={op_type}", stack_info=True) + return None + + matched_children.append(matched_child) + current_node = matched_child + return matched_children + def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True): if output_name_to_node is None: output_name_to_node = self.output_name_to_node() diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 2a48722d17a19..61a786d7af60b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -121,7 +121,7 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): class BartOnnxModel(BertOnnxModel): - def __init__(self, model, num_heads, hidden_size): + def __init__(self, model, num_heads, hidden_size, model_impl="hf"): super().__init__(model, num_heads, hidden_size) self.attention_mask = AttentionMask(self) self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask) diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index c9db1fbc02931..40ea8cf774918 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -361,7 +361,8 @@ def run_configs(self, optional_arguments): # INT8 CPU arguments = self.base_arguments + self.int8_cpu_arguments + optional_arguments - self.run_export(arguments) + if "--model_impl" not in arguments: + self.run_export(arguments) @pytest.mark.slow def test_required_args(self): @@ -393,6 +394,11 @@ def test_cross_qk_overall(self): ] self.run_configs(decoder_input_ids) + @pytest.mark.slow + def test_openai_impl_whisper(self): + optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"] + self.run_configs(optional_args) + if __name__ == "__main__": unittest.main() From 0e984ef0d1c41cbc54a453f6de0c5e8784df5e74 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 9 Feb 2024 15:27:04 -0800 Subject: [PATCH 062/207] Small fixes in Resize CPU antialias (#19476) ### Description Add a comment, pass NCHW = true when setting upsample_antialias ### Motivation and Context Small bugs. 
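For the Whisper export changes above, the new model_impl argument is what routes loading through the OpenAI implementation instead of the Hugging Face one. A minimal usage sketch of the updated WhisperHelper.load_model call follows; the model name and cache directory are illustrative placeholders, the import assumes the transformers/models/whisper tools directory is on sys.path, and the openai-whisper package is assumed to be installed when model_impl="openai" is requested.

    # Hedged sketch of the updated loader API; model id and cache path are
    # illustrative assumptions, not values taken from this patch.
    import torch
    from whisper_helper import WhisperHelper

    device = torch.device("cpu")
    models = WhisperHelper.load_model(
        "openai/whisper-tiny",   # model_name_or_path (illustrative)
        "openai",                # model_impl: "hf" or "openai"
        "./cache",               # cache_dir (illustrative)
        device,
        merge_encoder_and_decoder_init=True,
    )
    decoder = models["decoder"]                            # WhisperDecoder wrapping the selected decoder
    encoder_decoder_init = models["encoder_decoder_init"]  # merged encoder + decoder-init module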
--- onnxruntime/core/providers/cpu/tensor/upsample_antialias.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index 59b512def619d..e1dcaf500a325 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -700,7 +700,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, BiCubicParamsAntiAlias::type> p; p.cubic_coeff_a = cubic_coeff_a; SetupUpsampleFilterAntiAlias(p, input_paras, output_paras, scale_paras, roi, - alloc, get_original_coordinate, exclude_outside, false); + alloc, get_original_coordinate, exclude_outside, true); return UpsampleBaseAntiAlias(p, batch_size, num_channels, input_height, input_width, output_height, output_width, use_extrapolation, extrapolation_value, From 1182b5509ba2604856d02cf22795d6874252892e Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Sat, 10 Feb 2024 00:34:34 -0800 Subject: [PATCH 063/207] Disable streams for the DML EP (#19481) There's currently a bug in the allocation planner when reusing buffers and more than one stream is used that makes it possible (although rarely) to reach a reference count of 0 for a buffer that is still being used. Since DML doesn't benefit from multiple streams, disabling it is the safest option for now. This is a high-priority issue that we need to fix for 1.17.1 since it breaks Stable Diffusion. Identifying the perfect fix and fixing the underlying issue would be too risky for a patch release, especially given the limited time that we have. https://github.com/microsoft/onnxruntime/issues/19480 --- cmake/adjust_global_compile_flags.cmake | 9 ++- .../test/framework/allocation_planner_test.cc | 21 +++++-- onnxruntime/test/framework/bfc_arena_test.cc | 2 + .../test/framework/execution_frame_test.cc | 55 +++++++++++++++++-- 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 2c7bf9f1c2f5c..a56864ebf4644 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# enable stream for all the non-minimal build -if (NOT onnxruntime_MINIMAL_BUILD) +# Enable stream for all the non-minimal build, except for DML. There's currently a bug +# in the allocation planner when reusing buffers and more than one stream is used that +# makes it possible (although rarely) to reach a reference count of 0 for a buffer that is +# still being used. Since DML doesn't benefit from multiple streams, disabling it is the +# safest option for now.
+# https://github.com/microsoft/onnxruntime/issues/19480 +if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML) add_compile_definitions(ORT_ENABLE_STREAM) endif() diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index b174ee4138be3..d7b1de5c930c5 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test { if (invoke_createPlan_explicityly) { onnxruntime::GraphViewer graph_viewer{graph_}; - status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_, - kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context, - MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/ - ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_); + status = SequentialPlanner::CreatePlan( + nullptr, + graph_viewer, + outer_scope_node_args, + execution_providers_, + kernel_create_info_map, + {}, + {}, + state_->GetOrtValueNameIdxMap(), + test_context, +#ifdef ORT_ENABLE_STREAM + MockStreamHandleRegsitry(), +#endif + /* {{kCpuExecutionProvider, 1}}, {},*/ + ORT_TSTR(""), + DefaultLoggingManager().DefaultLogger(), + plan_); EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); // AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size()); diff --git a/onnxruntime/test/framework/bfc_arena_test.cc b/onnxruntime/test/framework/bfc_arena_test.cc index 0d3e4449da939..e9f734057da1c 100644 --- a/onnxruntime/test/framework/bfc_arena_test.cc +++ b/onnxruntime/test/framework/bfc_arena_test.cc @@ -337,6 +337,7 @@ struct StreamMock : public Stream { Status CleanUpOnRunEnd() override { return Status::OK(); } }; +#ifdef ORT_ENABLE_STREAM TEST(StreamAwareArenaTest, TwoStreamAllocation) { StreamAwareArena a(std::unique_ptr(new CPUAllocator()), 1 << 30, false); CheckStats(&a, 0, 0, 0, 0); @@ -413,6 +414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) { EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked"; a.Free(p2); } +#endif TEST(BFCArenaTest, TestExtendStrategy) { int64_t extend_delta_bytes = 0; diff --git a/onnxruntime/test/framework/execution_frame_test.cc b/onnxruntime/test/framework/execution_frame_test.cc index ec572ce9deed8..60752d7456d97 100644 --- a/onnxruntime/test/framework/execution_frame_test.cc +++ b/onnxruntime/test/framework/execution_frame_test.cc @@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) { ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager)); vector outputs; - ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state); + ExecutionFrame frame( + {}, + {}, + {}, + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); int start_index = frame.GetNodeOffset(node->Index()); ASSERT_EQ(start_index, 0); @@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) { ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK()); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), 
AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0); Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable() : nullptr; @@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) { std::vector(6, 1.0f), &v3); std::vector outputs; - ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x1_idx, x2_idx, x3_idx}), + AsSpan({v1, v2, v3}), + AsSpan({t3_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3); OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4); @@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) { CreateMLValue(cpu_allocator, std::vector{2, 2}, std::vector(4, 1.0f), &t_value); vector outputs; - ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state); + ExecutionFrame frame( + AsSpan({x_idx}), + AsSpan({x_value}), + AsSpan({y_idx}), + outputs, + {}, +#ifdef ORT_ENABLE_STREAM + {}, +#endif + state); ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor()); ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value)); From d00adb7989635a4046d896fab6358ad4c7b695db Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Sun, 11 Feb 2024 19:18:26 -0800 Subject: [PATCH 064/207] Align bins_space_ storage (#17552) Otherwise, `new (BinFromIndex(b)) Bin(this, bin_size);` in bfc_arena.cc would cause a -fsanitize=alignment (part of -fsanitize=undefined) failure like runtime error: constructor call on misaligned address 0xXXX for type 'Bin', which requires 8 byte alignment --- onnxruntime/core/framework/bfc_arena.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/bfc_arena.h b/onnxruntime/core/framework/bfc_arena.h index e16b90ded3381..5e4cd9f62f11b 100644 --- a/onnxruntime/core/framework/bfc_arena.h +++ b/onnxruntime/core/framework/bfc_arena.h @@ -482,7 +482,7 @@ class BFCArena : public IAllocator { Bin* BinForSize(size_t bytes) { return BinFromIndex(BinNumForSize(bytes)); } - char bins_space_[sizeof(Bin) * kNumBins]; + alignas(Bin) char bins_space_[sizeof(Bin) * kNumBins]; // The size of the current region allocation. 
SafeInt curr_region_allocation_bytes_; From c831031ad54657916f0b0867327f11eb1691ee14 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Mon, 12 Feb 2024 09:24:36 -0800 Subject: [PATCH 065/207] Remove cuda gencode 90 to reduce onnxruntime-training package size (#19486) --- .../azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml index f244851f8cc37..d9ab85ee80ce3 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml @@ -15,7 +15,7 @@ stages: torch_version: '2.0.0' opset_version: '15' cuda_version: '11.8' - cmake_cuda_architectures: 60;61;70;75;80;86;90 + cmake_cuda_architectures: 60;61;70;75;80;86 docker_file: Dockerfile.manylinux2_28_training_cuda11_8 agent_pool: Onnxruntime-Linux-GPU upload_wheel: 'yes' From 9cb97ee507b9b45d4a896f663590083e7e7568ac Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 12 Feb 2024 09:39:49 -0800 Subject: [PATCH 066/207] Disable CPU EP's allocator's arena when address sanitizer is enabled (#19485) ### Description Disable CPU EP's allocator's arena when address sanitizer is enabled, because it masks problems. For example, the code in onnxruntime/test/quantization/quantization_test.cc has a memory leak problem: it allocated a buffer but didn't free it, yet most memory leak check tools cannot detect that because the buffer was from an arena and the arena was finally freed. ### Motivation and Context Provide better memory leak check coverage. --- .../core/providers/cpu/cpu_execution_provider.cc | 3 ++- onnxruntime/test/framework/allocator_test.cc | 3 ++- onnxruntime/test/framework/session_state_test.cc | 3 ++- onnxruntime/test/framework/tensor_test.cc | 4 ++-- onnxruntime/test/quantization/quantization_test.cc | 14 ++++++-------- 5 files changed, 14 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index cbdf79caf3afd..813fdc54ecd0d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/cpu_execution_provider.h" +#include #include "core/framework/op_kernel.h" #include "core/framework/kernel_registry.h" #include "core/mlas/inc/mlas.h" @@ -29,7 +30,7 @@ CPUExecutionProvider::CPUExecutionProvider(const CPUExecutionProviderInfo& info) std::vector CPUExecutionProvider::CreatePreferredAllocators() { bool create_arena = info_.create_arena; -#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) +#if defined(USE_JEMALLOC) || defined(USE_MIMALLOC) || defined(ABSL_HAVE_ADDRESS_SANITIZER) // JEMalloc/mimalloc already have memory pool, so just use device allocator. create_arena = false; #elif !(defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) diff --git a/onnxruntime/test/framework/allocator_test.cc b/onnxruntime/test/framework/allocator_test.cc index 2c1cd48d3d02f..8961058628490 100644 --- a/onnxruntime/test/framework/allocator_test.cc +++ b/onnxruntime/test/framework/allocator_test.cc @@ -1,5 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License.
+#include #include "core/framework/allocator.h" @@ -15,7 +16,7 @@ TEST(AllocatorTest, CPUAllocatorTest) { EXPECT_EQ(cpu_arena->Info().id, 0); // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) +#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtArenaAllocator); #else EXPECT_EQ(cpu_arena->Info().alloc_type, OrtAllocatorType::OrtDeviceAllocator); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 0c2d8bcb2eb93..ed698ab920147 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "asserts.h" #include "core/framework/execution_providers.h" @@ -215,7 +216,7 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { // if the relevant session option config flag is set // For this test we need to enable the arena-based allocator which is not supported on x86 builds, so // enable this test only on x64 builds -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_MIMALLOC) +#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { AllocatorPtr cpu_allocator = std::make_shared(); // Part 1: Feature turned ON (i.e.) 
allocate from non-arena memory diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index 38e3f184ebc18..9202543b75a6f 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -6,7 +6,7 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" - +#include #include namespace onnxruntime { @@ -138,7 +138,7 @@ TEST(TensorTest, EmptyTensorTest) { EXPECT_EQ(location.id, 0); // arena is disabled for CPUExecutionProvider on x86 and JEMalloc -#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) +#if (defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)) && !defined(USE_JEMALLOC) && !defined(USE_MIMALLOC) && !defined(ABSL_HAVE_ADDRESS_SANITIZER) EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtArenaAllocator); #else EXPECT_EQ(location.alloc_type, OrtAllocatorType::OrtDeviceAllocator); diff --git a/onnxruntime/test/quantization/quantization_test.cc b/onnxruntime/test/quantization/quantization_test.cc index bdfac77b336d4..773f56de5361b 100644 --- a/onnxruntime/test/quantization/quantization_test.cc +++ b/onnxruntime/test/quantization/quantization_test.cc @@ -99,24 +99,22 @@ void EnsureQuantizedTensorParam(const float scale, const T zero_point) { // First, create the scale tensor: auto alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - auto num_bytes = shape.Size() * sizeof(float); - void* data = alloc->Alloc(num_bytes); - float* float_data = static_cast(data); + IAllocatorUniquePtr buffer = IAllocator::MakeUniquePtr(alloc, shape.Size()); + float* float_data = buffer.get(); float_data[0] = scale; Tensor scale_tensor(DataTypeImpl::GetType(), shape, - data, + float_data, alloc->Info(), /*offset=*/0); // Next, create the zero_point tensor: - auto T_num_bytes = shape.Size() * sizeof(T); - void* T_data = alloc->Alloc(T_num_bytes); - T* typed_data = static_cast(T_data); + IAllocatorUniquePtr buffer2 = IAllocator::MakeUniquePtr(alloc, shape.Size()); + T* typed_data = buffer2.get(); typed_data[0] = zero_point; Tensor zero_point_tensor(DataTypeImpl::GetType(), shape, - T_data, + typed_data, alloc->Info(), /*offset=*/0); From 90e2e8561f86efa634a7fe364d87e2d1dc159a69 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Tue, 13 Feb 2024 02:01:08 +0530 Subject: [PATCH 067/207] Ovep 1.17.1 (#19482) ### Description Handle bugs for API backward compatability. 
Update to consume the onnx model path rather the onnx serialised model to OV compile_model API --- .../openvino/backends/basic_backend.cc | 11 +++++++---- .../core/providers/openvino/ov_interface.cc | 4 ++-- .../core/providers/openvino/ov_interface.h | 2 +- .../core/session/provider_bridge_ort.cc | 18 ++++++++++-------- 4 files changed, 20 insertions(+), 15 deletions(-) diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index e6c093d584031..0779940983aea 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -70,10 +70,13 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else - if (global_context_.disable_dynamic_shapes && dev_prec != "CPU_FP16") { - const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork( - model, hw_target, device_config, subgraph_context_.subgraph_name); + if (!subgraph_context_.has_dynamic_input_shape && + global_context_.onnx_model_path_name != "" && + dev_prec != "CPU_FP16") { + exe_network_ = global_context_.ie_core.LoadNetwork(global_context_.onnx_model_path_name, + hw_target, + device_config, + subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 931173fd7ef47..ea481791111fc 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -87,13 +87,13 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, } } -OVExeNetwork OVCore::LoadNetwork(const std::string& model, +OVExeNetwork OVCore::LoadNetwork(const std::string onnx_model_path, std::string& hw_target, ov::AnyMap& device_config, std::string name) { ov::CompiledModel obj; try { - obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); + obj = oe.compile_model(onnx_model_path, hw_target, device_config); OVExeNetwork exe(obj); return exe; } catch (const Exception& e) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 3db19463809cf..cf4d867d4df55 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -45,7 +45,7 @@ class OVCore { std::string& hw_target, ov::AnyMap& device_config, std::string name); - OVExeNetwork LoadNetwork(const std::string& model_stream, + OVExeNetwork LoadNetwork(const std::string model_path, std::string& hw_target, ov::AnyMap& device_config, std::string name); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index bb8732784945d..3bec9aa146f76 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1682,7 +1682,11 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; + if 
(legacy_ov_options->enable_npu_fast_compile) { + ov_options_converted_map["enable_npu_fast_compile"] = "false"; + } else { + ov_options_converted_map["enable_npu_fast_compile"] = "true"; + } if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; @@ -1701,14 +1705,12 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling; - if (legacy_ov_options->enable_dynamic_shapes != '\0') { - std::string enable_dynamic_shapes = reinterpret_cast(legacy_ov_options->enable_dynamic_shapes); - if (enable_dynamic_shapes == "true" || enable_dynamic_shapes == "True") { - ov_options_converted_map["disable_dynamic_shapes"] = "false"; - } else if (enable_dynamic_shapes == "false" || enable_dynamic_shapes == "False") { - ov_options_converted_map["disable_dynamic_shapes"] = "true"; - } + if (legacy_ov_options->enable_dynamic_shapes) { + ov_options_converted_map["disable_dynamic_shapes"] = "false"; + } else { + ov_options_converted_map["disable_dynamic_shapes"] = "true"; } + // Add new provider option below ov_options_converted_map["num_streams"] = "1"; return ov_options_converted_map; From 7fa6f4fca4a05ac025995da70b7a5fa92ee46d83 Mon Sep 17 00:00:00 2001 From: snadampal <87143774+snadampal@users.noreply.github.com> Date: Mon, 12 Feb 2024 17:20:36 -0600 Subject: [PATCH 068/207] add arm64 bfloat16 fastmath mode option for transformers benchmarking script (#19294) Add arm64 bfloat16 fastmath mode option for transformers benchmarking script. ### Motivation and Context onnxruntime now supports bfloat16 fastmath gemm kernels for arm64 platforms with bfloat16 instruction support. This PR updates benchmark scripts to test that mode. --- onnxruntime/python/tools/transformers/benchmark.py | 13 +++++++++++++ .../python/tools/transformers/benchmark_helper.py | 4 ++++ .../python/tools/transformers/run_benchmark.sh | 9 ++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) mode change 100644 => 100755 onnxruntime/python/tools/transformers/run_benchmark.sh diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 22f2cfb8a01ca..89f9947688583 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -36,6 +36,8 @@ python benchmark.py -e torchscript onnxruntime -p "int8" -o Run OnnxRuntime with the ROCM provider and graph optimization script: python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm + Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support: + python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm It is recommended to use run_benchmark.sh to launch benchmark. 
""" @@ -106,6 +108,7 @@ def run_onnxruntime( use_raw_attention_mask, model_fusion_statistics, model_source, + enable_arm64_bfloat16_fastmath_mlas_gemm, args, ): import onnxruntime @@ -209,6 +212,7 @@ def run_onnxruntime( enable_all_optimization=True, num_threads=num_threads, verbose=verbose, + enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm, ) if ort_session is None: continue @@ -760,6 +764,14 @@ def parse_arguments(): help="Manually set the model's layer number", ) + parser.add_argument( + "--enable_arm64_bfloat16_fastmath_mlas_gemm", + required=False, + action="store_true", + help="Enable bfloat16 mlas gemm kernels on aarch64. Supported only for CPU EP ", + ) + parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False) + FusionOptions.add_arguments(parser) args = parser.parse_args() @@ -905,6 +917,7 @@ def main(): use_raw_attention_mask, model_fusion_statistics, args.model_source, + args.enable_arm64_bfloat16_fastmath_mlas_gemm, args, ) except Exception: diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index c5edae42590fd..c7d93470a729e 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -85,6 +85,7 @@ def create_onnxruntime_session( num_threads=-1, enable_profiling=False, verbose=False, + enable_mlas_gemm_fastmath_arm64_bfloat16=False, provider_options={}, # map execution provider name to its option # noqa: B006 ): session = None @@ -136,6 +137,9 @@ def create_onnxruntime_session( if provider_options: providers = [(name, provider_options[name]) if name in provider_options else name for name in providers] + if enable_mlas_gemm_fastmath_arm64_bfloat16: + sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1") + session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers) except Exception: logger.error("Exception", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh old mode 100644 new mode 100755 index f0422839c11eb..64d6ecde618f6 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -34,6 +34,9 @@ run_gpu_fp16=true run_cpu_fp32=false run_cpu_int8=false +# Set this to true to enable bfloat16 fastmath gemm kernels on aarch64 platforms with bfloat16 support +arm64_bfloat16_fastmath_mode=false + average_over=1000 # CPU takes longer time to run, only run 100 inferences to get average latency. if [ "$run_cpu_fp32" = true ] || [ "$run_cpu_int8" = true ]; then @@ -63,7 +66,7 @@ models_to_test="bert-base-cased roberta-base distilbert-base-uncased" # export CUDA_VISIBLE_DEVICES=1 # This script will generate a logs file with a list of commands used in tests. 
-echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" >> benchmark.log +echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" arm64_bfloat16_fastmath_mode=$arm64_bfloat16_fastmath_mode >> benchmark.log # Set it to false to skip testing. You can use it to dry run this script with the log file. run_tests=true @@ -127,6 +130,10 @@ if [ "$force_layer_number" = true ] ; then benchmark_options="$benchmark_options --force_num_layers $layer_number" fi +if [ "$arm64_bfloat16_fastmath_mode" = true ] ; then + benchmark_options="$benchmark_options --enable_arm64_bfloat16_fastmath_mlas_gemm" +fi + # ------------------------------------------- run_one_test() { if [ "$run_ort" = true ] ; then From a622710fe171fab399bee41c8c3cd1dc16dd8f62 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Mon, 12 Feb 2024 19:11:40 -0800 Subject: [PATCH 069/207] Add option to skip session run in perf_test tool (#19501) Enable a option to exit after session creation so that user can measure session creation time to measure impact of enabling any initialization optimizations. --- onnxruntime/test/perftest/command_args_parser.cc | 6 +++++- onnxruntime/test/perftest/main.cc | 7 +++++++ onnxruntime/test/perftest/performance_runner.cc | 5 +++++ onnxruntime/test/perftest/performance_runner.h | 2 ++ onnxruntime/test/perftest/test_configuration.h | 1 + 5 files changed, 20 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7cfbe0a84e3e6..3874901f86387 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -128,6 +128,7 @@ namespace perftest { "\t\t The number of affinities must be equal to intra_op_num_threads - 1\n\n" "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n" "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" + "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -190,7 +191,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqz"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -373,6 +374,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'Z': test_config.run_config.disable_spinning_between_run = true; break; + case 'n': + test_config.run_config.exit_after_session_creation = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 36f08167c2217..43bf54963cabb 100644 --- a/onnxruntime/test/perftest/main.cc +++ 
b/onnxruntime/test/perftest/main.cc @@ -43,6 +43,13 @@ int real_main(int argc, char* argv[]) { } std::random_device rd; perftest::PerformanceRunner perf_runner(env, test_config, rd); + + // Exit if user enabled -n option so that user can measure session creation time + if (test_config.run_config.exit_after_session_creation) { + perf_runner.LogSessionCreationTime(); + return 0; + } + auto status = perf_runner.Run(); if (!status.IsOK()) { printf("Run failed:%s\n", status.ErrorMessage().c_str()); diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 9f2cbcf6a21f1..37bf80c80e90b 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -115,6 +115,11 @@ void PerformanceResult::DumpToFile(const std::basic_string& path, boo } } +void PerformanceRunner::LogSessionCreationTime() { + std::chrono::duration session_create_duration = session_create_end_ - session_create_start_; + std::cout << "\nSession creation time cost: " << session_create_duration.count() << " s\n"; +} + Status PerformanceRunner::Run() { if (!Initialize()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "failed to initialize."); diff --git a/onnxruntime/test/perftest/performance_runner.h b/onnxruntime/test/perftest/performance_runner.h index da2df9c39f44c..cb1cb661550a7 100644 --- a/onnxruntime/test/perftest/performance_runner.h +++ b/onnxruntime/test/perftest/performance_runner.h @@ -46,6 +46,8 @@ class PerformanceRunner { ~PerformanceRunner(); Status Run(); + void LogSessionCreationTime(); + inline const PerformanceResult& GetResult() const { return performance_result_; } inline void SerializeResult() const { diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 5a49414a49004..74c8eb472cb3e 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -63,6 +63,7 @@ struct RunConfig { std::string intra_op_thread_affinities; bool disable_spinning = false; bool disable_spinning_between_run = false; + bool exit_after_session_creation = false; }; struct PerformanceTestConfig { From 61e07a46e17c6dbc1adbd824d6b36aaac6f26c7b Mon Sep 17 00:00:00 2001 From: Patrice Vignola Date: Mon, 12 Feb 2024 19:36:08 -0800 Subject: [PATCH 070/207] [DML EP] Support split hidden size for RotaryEmbedding (#18852) RotaryEmbedding now supports the `[batchSize, numHeads, sequenceLength, headSize]` format for its input, which is used in Mistral. --- .../Operators/DmlOperatorRotaryEmbedding.cpp | 36 +++++++++++++------ 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 30c339b845b36..44004b5d77f70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + // When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. 
Otherwise, + // it has the shape [batchSize, sequenceLength, hiddenSize] + const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4; + // When positionIds is a scalar, it represents the start offset for each sequence const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; @@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); - const uint32_t batchSize = inputDataSizes[1]; + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputDataSizes[3] / headSize; + const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; @@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + + if (inputIs4D) + { + const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; + stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + } + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. 
DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; @@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; @@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Swap the 2 halves and join them together DML_JOIN_OPERATOR_DESC joinInputDesc{}; joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; joinInputDesc.Axis = splitInputDesc.Axis; joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; @@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; - mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.ATensor = &joinedDataDmlTensorDesc; mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; - mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; // Multiply the non-rotated data with the cos and the rotated data with the sin DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; - mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc; mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; - mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; addDesc.ATensor = &inputOutputDmlTensorDesc; addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; // Construct the graph From 4dfba53bfb54e8fa50af52a536920d62840020a4 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 12 Feb 2024 20:54:04 -0800 Subject: [PATCH 071/207] [QNN EP] Build x64 python wheel for QNN EP (#19499) ### Description Adds a job to the python packaging pipeline that builds x64 python wheels for QNN EP. 
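For illustration only (not part of this change): the wheel built here is consumed by creating an ONNX Runtime session with the QNN execution provider. A minimal sketch of that usage follows, assuming a placeholder `model.onnx`, the HTP backend DLL name, and that the `ep.context_enable` session config key is the intended way to produce the cached QNN model mentioned in the motivation below.

```python
# Illustrative sketch, not part of this patch.
# Assumptions: "model.onnx" is a placeholder model path, "QnnHtp.dll" is the QNN
# backend available on the target machine, and "ep.context_enable" is the session
# config key used to generate the cached QNN (EP context) model.
import onnxruntime as ort

so = ort.SessionOptions()
# Ask the QNN EP to emit a context cache model alongside the session (assumed key).
so.add_session_config_entry("ep.context_enable", "1")

session = ort.InferenceSession(
    "model.onnx",
    sess_options=so,
    providers=[("QNNExecutionProvider", {"backend_path": "QnnHtp.dll"})],
)
```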
### Motivation and Context Necessary to create a cached QNN model on Windows x64, which is done by creating a properly configured onnxruntime session with QNN EP. --- .../azure-pipelines/py-packaging-pipeline.yml | 6 + .../templates/py-packaging-stage.yml | 11 ++ .../templates/py-win-x64-qnn.yml | 177 ++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 5349b1ca67ab1..6b0ae085fa4db 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -34,6 +34,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -70,5 +75,6 @@ stages: enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 146e3e58444c1..5ac5bda8b0964 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -40,6 +40,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type type: string @@ -459,3 +464,9 @@ stages: QNN_SDK: 'qnn-v2.18.0.240101_win' PYTHON_VERSION: '3.11' NUMPY_VERSION: '1.25.2' + + - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - template: py-win-x64-qnn.yml + parameters: + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: 'qnn-v2.18.0.240101_win' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml new file mode 100644 index 0000000000000..30f21e933ee36 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -0,0 +1,177 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'Onnxruntime-QNNEP-Windows-2022-CPU' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. 
+ type: string + default: '' + +jobs: +- job: Win_py_x64_qnn_Wheels + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + strategy: + matrix: + Python38_x64: + PythonVersion: '3.8' + Python39_x64: + PythonVersion: '3.9' + Python310_x64: + PythonVersion: '3.10' + Python311_x64: + PythonVersion: '3.11' + Python312_x64: + PythonVersion: '3.12' + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import sys + np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: download-deps.yml + + - task: PythonScript@0 + displayName: 'Update deps.txt' + inputs: + scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py + arguments: --new_dir $(Build.BinariesDirectory)/deps + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Install ONNX' + inputs: + filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' + workingDirectory: '$(Build.BinariesDirectory)' + arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'x64' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: 
'$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' From 5e70c6b3a68b7826e0fe8b3ebf8185ea59ace6de Mon Sep 17 00:00:00 2001 From: George Wu Date: Mon, 12 Feb 2024 22:53:04 -0800 Subject: [PATCH 072/207] allow protobuf lite build for TRT EP (#19498) allow protobuf-lite builds with TensorRT EP as long as it's built with the trt built-in parser and not the oss-parser. This is because trt built-in parser statically links protobuf so there aren't any conflicts for protobuf-lite. --- cmake/CMakeLists.txt | 3 +-- tools/ci_build/build.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 90fe8276ea9c7..ff1c7a84f077f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -117,8 +117,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8567d595b7429..96567c8767a82 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1236,9 +1236,15 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), ] - # TensorRT and OpenVINO providers currently only support - # full_protobuf option. - if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + # VitisAI and OpenVINO providers currently only support + # full_protobuf option. 
TensorRT provider only requires it if built with oss_parser + if ( + args.use_full_protobuf + or (args.use_tensorrt and args.use_tensorrt_oss_parser) + or args.use_openvino + or args.use_vitisai + or args.gen_doc + ): cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if args.use_tvm and args.llvm_path is not None: From 5c7e6b2e2af922e3dc7e58c584e454074a394d3d Mon Sep 17 00:00:00 2001 From: Yifan Li <109183385+yf711@users.noreply.github.com> Date: Mon, 12 Feb 2024 23:04:08 -0800 Subject: [PATCH 073/207] [EP Perf] Add CI option to enable TRT-OSS parser (#19448) ### Description * Introducing CI option to enable TRT-OSS parser, during ep perf testing: ![image](https://github.com/microsoft/onnxruntime/assets/109183385/a9ba6393-6b94-4b8f-8ca4-ba7bc7954504) By default, open-sourced onnx-tensorrt parser listed under [cmake/deps.txt](https://github.com/microsoft/onnxruntime/blob/main/cmake/deps.txt#L39-L40) will be used if enabling this option. ### To verify this option and check the difference during ORT image build: If this option is enabled: image If this option is not enabled (by default): image * update default usage of cmake/trt version to the latest ### Motivation and Context Make it easier to test oss parser and find potential gap between tensorrt builtin/oss parser. Schedule runs with oss parser will be set after this PR gets merged --- .../tools/tensorrt/perf/build/build_image.py | 18 ++++++- .../tensorrt/perf/build/ort_build_latest.py | 48 +++++++++++-------- ...linux-gpu-tensorrt-daily-perf-pipeline.yml | 21 ++++++-- .../Dockerfile.ubuntu_cuda11_8_tensorrt8_6 | 5 +- 4 files changed, 63 insertions(+), 29 deletions(-) diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index b98aafc27579a..2ae64a72d08fe 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -45,7 +45,7 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: :return: A list of common 'docker build' arguments. """ - return [ + command = [ "--no-cache", "-t", f"{args.image_name}", @@ -54,6 +54,14 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: "--build-arg", f"ONNXRUNTIME_BRANCH={args.branch}", ] + if args.use_tensorrt_oss_parser: + command.extend( + [ + "--build-arg", + "PARSER_CONFIG=--use_tensorrt_oss_parser", + ] + ) + return command def is_valid_ver_str(version: str, min_comps: int = 0, max_comps: int = 0) -> bool: @@ -187,7 +195,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("-r", "--repo_path", required=True, help="Path to the onnxruntime repository") parser.add_argument("-i", "--image_name", required=True, help="The resulting Docker image name") parser.add_argument("-b", "--branch", default="main", help="Name of the onnxruntime git branch to checkout") - parser.add_argument("-t", "--trt_version", default="8.4.1.5", help="TensorRT version (e.g., 8.4.1.5)") + parser.add_argument("-t", "--trt_version", default="8.6.1.6", help="TensorRT version (e.g., 8.6.1.6)") parser.add_argument("-a", "--cuda_arch", default="75", help="CUDA architecture (e.g., 75)") # Command-line options for installing TensorRT from binaries. 
@@ -208,6 +216,12 @@ def parse_arguments() -> argparse.Namespace: help="CUDA version (e.g., 8.6) used to find TensorRT EA binary tar.gz package", ) parser.add_argument("--trt_bins_dir", default="", help="Directory containing TensorRT tar.gz package") + parser.add_argument( + "--use_tensorrt_oss_parser", + action="store_true", + default=False, + help="Use TensorRT OSS Parser", + ) return parser.parse_args() diff --git a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py index 6e20071683d90..c7d4a7836132a 100755 --- a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py @@ -13,6 +13,12 @@ def parse_arguments(): parser.add_argument("-b", "--branch", required=False, default="master", help="Github branch to test perf off of") parser.add_argument("-s", "--save", required=False, help="Directory to archive wheel file") parser.add_argument("-a", "--use_archived", required=False, help="Archived wheel file") + parser.add_argument( + "--use_tensorrt_oss_parser", + action="store_true", + default=False, + help="Use TensorRT OSS Parser", + ) args = parser.parse_args() return args @@ -35,14 +41,14 @@ def install_new_ort_wheel(ort_master_path): def main(): args = parse_arguments() - cmake_tar = "cmake-3.18.4-Linux-x86_64.tar.gz" + cmake_tar = "cmake-3.28.3-linux-x86_64.tar.gz" if not os.path.exists(cmake_tar): - subprocess.run(["wget", "-c", "https://cmake.org/files/v3.18/" + cmake_tar], check=True) + subprocess.run(["wget", "-c", "https://cmake.org/files/v3.28/" + cmake_tar], check=True) tar = tarfile.open(cmake_tar) tar.extractall() tar.close() - os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.18.4-Linux-x86_64"), "bin") + ":" + os.environ["PATH"] + os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.28.3-linux-x86_64"), "bin") + ":" + os.environ["PATH"] os.environ["CUDACXX"] = os.path.join(args.cuda_home, "bin", "nvcc") ort_master_path = args.ort_master_path @@ -57,24 +63,24 @@ def main(): subprocess.run(["git", "fetch"], check=True) subprocess.run(["git", "checkout", args.branch], check=True) subprocess.run(["git", "pull", "origin", args.branch], check=True) - subprocess.run( - [ - "./build.sh", - "--config", - "Release", - "--use_tensorrt", - "--tensorrt_home", - args.tensorrt_home, - "--cuda_home", - args.cuda_home, - "--cudnn", - "/usr/lib/x86_64-linux-gnu", - "--build_wheel", - "--skip_tests", - "--parallel", - ], - check=True, - ) + command = [ + "./build.sh", + "--config", + "Release", + "--use_tensorrt", + "--tensorrt_home", + args.tensorrt_home, + "--cuda_home", + args.cuda_home, + "--cudnn", + "/usr/lib/x86_64-linux-gnu", + "--build_wheel", + "--skip_tests", + "--parallel", + ] + if args.use_tensorrt_oss_parser: + command.append("--use_tensorrt_oss_parser") + subprocess.run(command, check=True) ort_wheel_file = install_new_ort_wheel(ort_master_path) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index e75bb68a8bfeb..eaadc6ad728c0 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -15,6 +15,11 @@ parameters: - 8.6.1.6 - BIN +- name: UseTensorrtOssParser + displayName: Use TensorRT-OSS Parser + type: boolean + default: false + - name: ModelGroups type: object 
default: @@ -73,7 +78,7 @@ jobs: value: ort-image-$(Build.BuildId) steps: - - ${{ if eq(parameters.TrtVersion, 'BIN') }}: + - ${{ if and(eq(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: - script: 'ls -al $(trtBinsDir)' displayName: 'Show available TensorRT .tar.gz packages' @@ -83,11 +88,19 @@ jobs: - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --tar_cudnn_version=$(tarCudnnVersion) --trt_bins_dir=.' displayName: 'Install TensorRT from binaries and build latest ORT Image' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' - - ${{ else }}: + + # Build ORT with TensorRT built-in parser + - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75' - displayName: 'Build latest ORT Image' + displayName: 'Build latest ORT Image with TensorRT built-in parser' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' - + + # Build ORT with TensorRT OSS parser + - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, true)) }}: + - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --use_tensorrt_oss_parser' + displayName: 'Build latest ORT Image with TensorRT OSS parser' + workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' + - ${{ if eq(parameters.MemTest, true) }}: - script: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh -d $(image) -p $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/ -w /code/ -l false' displayName: 'Run Memory Test' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index 04a6af962b5e6..f1ffba3b3e1c9 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -82,8 +82,9 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" -RUN /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' +ENV CUDA_MODULE_LOADING "LAZY" +ARG PARSER_CONFIG="" +RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' # Switch to root to continue following steps of CI USER root From 
1e10cdb2b9b9aa8ae6b248bca5ea139185e92947 Mon Sep 17 00:00:00 2001 From: fxmarty <9808326+fxmarty@users.noreply.github.com> Date: Wed, 14 Feb 2024 00:49:19 +0100 Subject: [PATCH 074/207] Fix subgraph quantization regression in onnxruntime 1.17 (#19421) As per title, fixes https://github.com/microsoft/onnxruntime/issues/19418 ONNX Runtime 1.17 broke the quantization of ONNX models with subgraphs where initializers are placed on the top-level graph, while different subgraphs use the same initializer. --- .../tools/quantization/onnx_quantizer.py | 10 ++- .../test/python/quantization/test_subgraph.py | 64 +++++++++++++++++++ 2 files changed, 72 insertions(+), 2 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_subgraph.py diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index a72d21c03a8a6..ecfbaa569ca0a 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -1332,9 +1332,15 @@ def _dequantize_value(self, value_name): if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names): quantized_value = self.quantized_value_map[value_name] # Add DequantizeLinear Node for this input + scale_init = find_by_name(quantized_value.scale_name, self.model.initializer()) - # axis is not specified so scale_init must be a scalar. - assert onnx.numpy_helper.to_array(scale_init).size == 1 + + # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top level graph, so the check below can not be done. + if self.model.model.producer_name != "onnx-quantizer" or ( + self.model.model.producer_name == "onnx-quantizer" and scale_init is not None + ): + # axis is not specified so scale_init must be a scalar. + assert onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py new file mode 100644 index 0000000000000..c425bf956f976 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_subgraph.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + +import os +import tempfile +import unittest +import urllib.request + +import onnx + +from onnxruntime.quantization import quantize_dynamic + + +class TestDynamicQuantizationSubgraph(unittest.TestCase): + def test_dynamic_quantization_subgraph(self): + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx") + quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx") + urllib.request.urlretrieve( + "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path + ) + + quantize_dynamic( + model_input=onnx_path, + model_output=quantized_onnx_path, + per_channel=True, + op_types_to_quantize=[ + "Conv", + "MatMul", + "Attention", + "LSTM", + "Gather", + "Transpose", + "EmbedLayerNormalization", + ], + extra_options={"EnableSubgraph": True}, + ) + model = onnx.load(quantized_onnx_path) + + # The initializer `shared.weight_merged_0` is attached to the top-level graph, and used in a Gather node in each subgraphs. + # We expect the quantized Gather (after which a DequantizeLinear is attached) initializer to also be attached to the top-level graph. + found_gather_quantized = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_quantized": + found_gather_quantized = True + break + self.assertTrue(found_gather_quantized) + + found_gather_scale = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_scale": + found_gather_scale = True + break + self.assertTrue(found_gather_scale) + + # No initializers related to the Gather node should be attached to the subgraphs. + for node in model.graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + for initializer in attr.g.initializer: + self.assertTrue("shared.weight" not in initializer.name) From f048fb5b14f5495950fb984dc474c8930861e474 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:59:15 -0800 Subject: [PATCH 075/207] Bump nuget/setup-nuget from 1 to 2 (#19411) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [nuget/setup-nuget](https://github.com/nuget/setup-nuget) from 1 to 2.
Release notes

Sourced from nuget/setup-nuget's releases.

v2.0.0

What's Changed

New Contributors

Full Changelog: https://github.com/NuGet/setup-nuget/compare/v1.2.0...v1.3.0

... (truncated)

Commits
  • a21f25c Update dist for release (#118)
  • 5166d73 build(deps-dev): bump @typescript-eslint/parser from 6.19.0 to 6.20.0 (#117)
  • b915545 build(deps-dev): bump ts-jest from 29.1.1 to 29.1.2 (#113)
  • 00081d4 build(deps-dev): bump nock from 13.4.0 to 13.5.1 (#115)
  • e44f8a5 build(deps-dev): bump @types/node from 20.11.5 to 20.11.10 (#116)
  • f685ada build(deps-dev): bump prettier from 3.1.1 to 3.2.4 (#109)
  • aee2c69 build(deps-dev): bump @types/node from 20.10.4 to 20.11.5 (#110)
  • 2bd1cef build(deps-dev): bump eslint-plugin-jest from 27.6.0 to 27.6.3 (#106)
  • c5ed90c build(deps-dev): bump @typescript-eslint/parser from 6.13.2 to 6.19.0 (#107)
  • 34040aa build(deps-dev): bump eslint from 8.55.0 to 8.56.0 (#94)
  • Additional commits viewable in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=nuget/setup-nuget&package-manager=github_actions&previous-version=1&new-version=2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/publish-csharp-apidocs.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index c03399f4693be..5bc21595bf882 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -37,7 +37,7 @@ jobs: wget https://github.com/dotnet/docfx/releases/download/v${DOCFXVERSION}/docfx-linux-x64-v${DOCFXVERSION}.zip -O build/docfx/docfx.zip unzip build/docfx/docfx.zip -d build/docfx - name: Install NuGet - uses: nuget/setup-nuget@v1 + uses: nuget/setup-nuget@v2 - name: Build Documentation run: | build/docfx/docfx metadata csharp/ApiDocs/docfx.json From 18f76bd25ded7a6ec4b8675e1c2813753fec5343 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Tue, 13 Feb 2024 15:59:24 -0800 Subject: [PATCH 076/207] Bump gradle/wrapper-validation-action from 1 to 2 (#19412) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [gradle/wrapper-validation-action](https://github.com/gradle/wrapper-validation-action) from 1 to 2.
Release notes

Sourced from gradle/wrapper-validation-action's releases.

v2.0.0

What's Changed

The version of the Node.js runtime was updated to 20, and the majority of dependencies were updated to the latest versions. From now on, the wrapper-validation-action will require a Node.js 20 runtime environment.

There are no functional changes in this release. This release is tagged with the v2 version label.

  • [NEW] Update Node.js runtime to version 20 (#170)

v2.0.0-rc.1

This is a release candidate for v2.0.0. It is also available under the v2 version label.

What's Changed

The version of the Node.js runtime was updated to 20, and the majority of dependencies were updated to the latest versions. From now on, the wrapper-validation-action will require a Node.js 20 runtime environment.

There are no functional changes in this release.

  • [NEW] Update Node.js runtime to version 20 (#170)

v1.1.0

The action now adds the path of the failed wrapper Jar as a failed-wrapper Step output parameter. This makes the value available for reporting in later Steps/Jobs.

v1.0.6

Gradle Wrapper Validation

v1.0.5

Gradle Wrapper Validation

  • Update dependencies for Node 16 (#53)
  • Update dependencies with security vulnerabilities (#67)
  • Update various other dependencies (#45, #47, #48, #54)

v1.0.4

Gradle Wrapper Validation

v1.0.3

Gradle Wrapper Validation

Update minimist version to 1.2.5

v1.0.2

... (truncated)

Commits
  • 27152f6 Update to Node 20 (#170)
  • d8758a9 Build output
  • e916071 Update NPM dependencies
  • d9359e4 Add asdf config file
  • 77d43de Update upload-artifact version
  • 2f8436d Use setup-node@v4 instead of pinning to a revision
  • bfa0fe4 Consistently use npm cache for workflows
  • 8be8473 Update workflows and action to NodeJS 20
  • c8fad9e Bump @​babel/traverse from 7.14.7 to 7.23.2
  • 342dbeb Update README to use actions/checkout@v4
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=gradle/wrapper-validation-action&package-manager=github_actions&previous-version=1&new-version=2)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/gradle-wrapper-validation.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 03ea773a25130..bc2d8117930bc 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -11,4 +11,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: gradle/wrapper-validation-action@v1 + - uses: gradle/wrapper-validation-action@v2 From 544407038d96521617fe633cf97153d3e75561f5 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 14 Feb 2024 10:05:16 -0800 Subject: [PATCH 077/207] SimplifiedLayerNormalization Fusion BFloat16 support for Llama-v2 on A100 (#18898) ### Description Adds bfloat16 as a supported dtype for SimplifiedLayerNormFusion which will provide speedup for Llama-v2 on A100 using bfloat16 numerical format. _layernorm_optimized_training.onnx exported in bfloat16 vs. float16:_ ![image](https://github.com/microsoft/onnxruntime/assets/31260940/8c0a5f0f-5fcb-4637-bcd9-f34272ec0284) ### Repro Instructions ```python from torch import nn from onnxruntime.training.ortmodule import ORTModule, DebugOptions, LogLevel import torch dtype = torch.bfloat16 # dtype = torch.float16 class Net(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(784, 10, dtype=dtype) self.layernorm = nn.LayerNorm([784], dtype=dtype) def forward(self, x): x = x.view(x.shape[0], -1) x = self.layernorm(x) x = self.fc(x) return x model = Net() model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='layernorm', log_level=LogLevel.INFO)) model.to("cuda") images = torch.randn((8, 28, 28), dtype=dtype).to("cuda") output = model(images) ``` ### Motivation and Context ONNX Runtime integration with Llama-v2 family of LLMs. --------- Co-authored-by: Prathik Rao --- docs/OperatorKernels.md | 2 +- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 2 ++ .../core/optimizer/layer_norm_fusion.cc | 2 +- .../test/contrib_ops/layer_norm_op_test.cc | 23 +++++++++++++++++++ .../cuda/cuda_training_kernels.cc | 2 ++ 5 files changed, 29 insertions(+), 2 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2ea557b7d61fe..2c00fe3d26752 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -765,7 +765,7 @@ Do not modify directly.* |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 8f368251f12c7..be8c0dc86c135 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -120,6 +120,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); @@ -318,6 +319,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 159e3b23d1ab0..b6ad4fde6c1f7 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. 
-static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"}; +static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}; // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 84bbee35eed5a..98fb62e435f31 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -7,6 +7,7 @@ #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/util/include/default_providers.h" #include "test/providers/provider_test_utils.h" @@ -75,6 +76,28 @@ TEST(LayerNormTest, LayerNorm) { test.Run(); } +TEST(LayerNormTest, LayerNorm_BFloat16Input) { +// prevents test from running on non-BF16-supporting hardware +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; + return; + } +#endif + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f})); + test.AddOutput("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f})); + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index dcf733153bdad..8b2bc7e2ef2b3 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -196,6 +196,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); @@ -452,6 +453,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, From f53d2c2465d81cdb4e14c7241eab327184192c88 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 14 Feb 2024 18:08:11 +0000 Subject: [PATCH 078/207] Phi2 script fixes (#19500) ### Description This PR is 
intended to support Phi2 passes in Olive. Merge it before https://github.com/microsoft/Olive/pull/938 ### Motivation and Context --- .../tools/transformers/fusion_options.py | 7 ++ .../models/phi2/convert_to_onnx.py | 3 - .../tools/transformers/onnx_model_phi.py | 98 +++++++++++-------- 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 4c43e4487bfb1..edac1989e4e9e 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -29,6 +29,13 @@ class AttentionOpType(Enum): def __str__(self): return self.value + # Override __eq__ to return string comparison + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return other.value == self.value + class FusionOptions: """Options of fusion in graph optimization""" diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index b7881d064067d..796d6ec55ef80 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -138,9 +138,6 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. node_block_list = ( [ - "GroupQueryAttention_29", - "GroupQueryAttention_30", - "GroupQueryAttention_31", "Attention_29", "Attention_30", "Attention_31", diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py index e68c3120e3f09..0fdce29ae0fa0 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_phi.py +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -80,14 +80,17 @@ def set_attention_op_type(self, attn_op_type: AttentionOpType): def get_uname(self, layer_id, name): return name + "_" + str(layer_id) - def get_io_by_name(self, node, name): - for input in node.input: - if input == name or input.endswith(name) or input.startswith(name): - return input - for output in node.output: - if output == name or output.endswith(name) or output.startswith(name): - return output - raise Exception(f"input {name} not found in node {node.name}") + def get_edge_by_name(self, edges, name): + for edge in edges: + if edge == name or edge.endswith(name) or edge.startswith(name): + return edge + raise ValueError(f"Edge {name} not found") + + def get_input_by_name(self, node, name): + return self.get_edge_by_name(node.input, name) + + def get_output_by_name(self, node, name): + return self.get_edge_by_name(node.output, name) def process_initializer(self, initializer_name, functor, custom_name=None): i = self.model.get_initializer(initializer_name) @@ -287,7 +290,6 @@ def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): self.num_attention_heads = num_heads self.hidden_size = hidden_size - self.phi2_edge_dict = self.get_phi2_edge_dict() self.func_name = "modeling_phi_PhiModel_model_1" def get_phi2_edge_dict(self) -> dict: @@ -296,11 +298,20 @@ def get_phi2_edge_dict(self) -> dict: edge_dict["l_input_ids_"] = "input_ids" edge_dict["key_states"] = "past_key_0" edge_dict["value_states"] = "past_value_0" - for i in range(self.num_hidden_layers): + for i in range(1, self.num_hidden_layers, 1): edge_dict[f"key_states_{i}"] = f"past_key_{i}" edge_dict[f"value_states_{i}"] = 
f"past_value_{i}" edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + + outputs = [o.name for o in self.model.graph.output] + if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs: + edge_dict["model_layers_0_1_1"] = "present_key_0" + edge_dict["model_layers_0_1_2"] = "present_value_0" + else: + assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs + edge_dict["model_layers_0_1"] = "present_key_0" + edge_dict["model_layers_0_1_1"] = "present_value_0" return edge_dict def simplify_phi2_op_type(self): @@ -441,7 +452,7 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType): break assert function_name is not None self.unroll_function(function_name) - self.update_edges(self.phi2_edge_dict) + self.update_edges(self.get_phi2_edge_dict()) self.simplify_phi2_op_type() self.remove_dropout_layer() if attn_op_type == AttentionOpType.PagedAttention: @@ -465,7 +476,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[0] output = node.output[0] - embedding = self.get_io_by_name(node, "embed_tokens.weight") + embedding = self.get_input_by_name(node, "embed_tokens.weight") layer_known_edges_names = [input, output, embedding] @@ -499,8 +510,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[0] output = node.output[0] - ln_weight = self.get_io_by_name(node, "final_layernorm.weight") - ln_bias = self.get_io_by_name(node, "final_layernorm.bias") + ln_weight = self.get_input_by_name(node, "final_layernorm.weight") + ln_bias = self.get_input_by_name(node, "final_layernorm.bias") layer_known_edges_names = [input, output, ln_weight, ln_bias] @@ -532,8 +543,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[2] output = node.output[0] - fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) - fc_bias = self.get_io_by_name(node, "lm_head.bias") + fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_input_by_name(node, "lm_head.bias") layer_known_edges_names = [input, output, fc_weight, fc_bias] @@ -670,15 +681,15 @@ def fuse( layer_id = self.get_layer_id(node) i_hidden_states = node.input[0] - i_key_cache = self.get_io_by_name(node, "past_key") - i_value_cache = self.get_io_by_name(node, "past_value") + i_key_cache = self.get_input_by_name(node, "past_key") + i_value_cache = self.get_input_by_name(node, "past_value") - o_hidden_states = node.output[3] - o_key_cache = self.get_io_by_name(node, "present_key") - o_value_cache = self.get_io_by_name(node, "present_value") + o_hidden_states = node.output[-1] + o_key_cache = self.get_output_by_name(node, "present_key") + o_value_cache = self.get_output_by_name(node, "present_value") - ln_weight = self.get_io_by_name(node, "input_layernorm.weight") - ln_bias = self.get_io_by_name(node, "input_layernorm.bias") + ln_weight = self.get_input_by_name(node, "input_layernorm.weight") + ln_bias = self.get_input_by_name(node, "input_layernorm.bias") attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( None, @@ -693,45 +704,45 @@ def fuse( if self.attn_op_type != AttentionOpType.Attention: attn_q_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() ) attn_k_weight = self.process_initializer( - 
self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() ) attn_v_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() ) - attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias") - attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias") - attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias") + attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias") cos_cache = self.process_initializer( - self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() ) sin_cache = self.process_initializer( - self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() ) else: attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm( - self.get_io_by_name(node, "self_attn.q_proj.weight"), - self.get_io_by_name(node, "self_attn.k_proj.weight"), - self.get_io_by_name(node, "self_attn.v_proj.weight"), - self.get_io_by_name(node, "self_attn.q_proj.bias"), - self.get_io_by_name(node, "self_attn.k_proj.bias"), - self.get_io_by_name(node, "self_attn.v_proj.bias"), + self.get_input_by_name(node, "self_attn.q_proj.weight"), + self.get_input_by_name(node, "self_attn.k_proj.weight"), + self.get_input_by_name(node, "self_attn.v_proj.weight"), + self.get_input_by_name(node, "self_attn.q_proj.bias"), + self.get_input_by_name(node, "self_attn.k_proj.bias"), + self.get_input_by_name(node, "self_attn.v_proj.bias"), self.get_uname(layer_id, "attn_qkv_weight"), self.get_uname(layer_id, "attn_qkv_bias"), ) attn_out_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() ) - attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias") + attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias") - mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) - mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) - mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias") - mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias") + mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) + mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) + mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias") + mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias") layer_known_edges_names = [] layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) @@ -771,6 +782,7 @@ def fuse( subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + # vllm engine requires full position ids as the input pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else 
"step" subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_")) subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_")) From fbff99a432caef529f90d20137fa5aee33f38fcf Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 14 Feb 2024 10:08:46 -0800 Subject: [PATCH 079/207] Change Jave Test Threshold (#19508) ### Description Increase the threshold to 1e-5 to avoid test failed in CUDA when difference is slightly larger than 1e-6. May because TF32 is used in those CUDA tests. ### Motivation and Context https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1291322&view=logs&j=f2f63060-d9d6-52d0-adee-b97db5a9ab91&t=28e21ca6-87a4-5e1e-0441-72b5e8326f2d ProviderOptionsTest > testCUDAOptions() FAILED org.opentest4j.AssertionFailedError: array contents differ at index [103], expected: <0.0102678> but was: <0.010266338> at app//org.junit.jupiter.api.AssertionFailureBuilder.build(AssertionFailureBuilder.java:151) at app//org.junit.jupiter.api.AssertionFailureBuilder.buildAndThrow(AssertionFailureBuilder.java:132) at app//org.junit.jupiter.api.AssertArrayEquals.failArraysNotEqual(AssertArrayEquals.java:440) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:290) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:123) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:119) at app//org.junit.jupiter.api.Assertions.assertArrayEquals(Assertions.java:1360) at app//ai.onnxruntime.providers.ProviderOptionsTest.runProvider(ProviderOptionsTest.java:99) at app//ai.onnxruntime.providers.ProviderOptionsTest.testCUDAOptions(ProviderOptionsTest.java:43) https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1293200&view=logs&jobId=f2f63060-d9d6-52d0-adee-b97db5a9ab91&j=f2f63060-d9d6-52d0-adee-b97db5a9ab91&t=28e21ca6-87a4-5e1e-0441-72b5e8326f2d InferenceTest > testCUDA() FAILED org.opentest4j.AssertionFailedError: array contents differ at index [103], expected: <0.0102678> but was: <0.010266337> at app//org.junit.jupiter.api.AssertionFailureBuilder.build(AssertionFailureBuilder.java:151) at app//org.junit.jupiter.api.AssertionFailureBuilder.buildAndThrow(AssertionFailureBuilder.java:132) at app//org.junit.jupiter.api.AssertArrayEquals.failArraysNotEqual(AssertArrayEquals.java:440) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:290) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:123) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:119) at app//org.junit.jupiter.api.Assertions.assertArrayEquals(Assertions.java:1360) at app//ai.onnxruntime.InferenceTest.runProvider(InferenceTest.java:676) at app//ai.onnxruntime.InferenceTest.testCUDA(InferenceTest.java:615) --- java/src/test/java/ai/onnxruntime/InferenceTest.java | 2 +- .../test/java/ai/onnxruntime/providers/ProviderOptionsTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 7fef2dc784b7b..9925197e4507c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -673,7 +673,7 @@ private void runProvider(OrtProvider provider) throws OrtException { // CoreML gives slightly different answers on a 2020 13" M1 MBP 
assertArrayEquals(expectedOutput, resultArray, 1e-2f); } else { - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 1ed883ace36e5..0e3bc15ba9c70 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -96,7 +96,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); } From 1508c2ee39023274417417b290303cf058ceedd6 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 14 Feb 2024 10:31:03 -0800 Subject: [PATCH 080/207] Restrict L2 Cache Core check to Intel devices (#19483) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description Limit SoC core detection via 2 level cache core logic to Intel and Hybrid processors. ### Motivation and Context The following code was added to add support for a new class of CPU cores present in Intel’s next generation Intel Core Ultra mobile processors. This code is essential to avoid placing threads on low performing SoC cores that don’t have L3 cache. SoC cores are meant to specialize in system bringup and help improve responsiveness and power usage, in other words they are not meant to run compute heavy AI workloads. In order to avoid broad exposure of this logic, it is currently designed to be restricted to Intel platforms that have hybrid enabled. --------- Co-authored-by: Sheil Kumar --- winml/lib/Api/HardwareCoreEnumerator.cpp | 25 ++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index a89ac561f8860..fa069c7fb66a7 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -14,7 +14,7 @@ struct LogicalProcessorInformation { struct CoreCounter { uint32_t PhysicalCores = 0; - uint32_t SocDieCores = 0; + uint32_t Num2CacheCores = 0; }; static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { @@ -75,7 +75,7 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() { read += currentProcessorInfo->Size; } - cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + cores.Num2CacheCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); return cores; } @@ -83,8 +83,25 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. 
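To make the core-count arithmetic in the comments above concrete, here is a small C++ sketch. The counts (6 P-cores, 8 E-cores, 2 SoC cores) are hypothetical, not measurements of any specific processor; the logic mirrors the Intel-hybrid branch added below.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical hybrid SKU: the numbers below are invented for illustration only.
  const uint32_t p_cores = 6, e_cores = 8, soc_cores = 2;
  const uint32_t physical_cores = p_cores + e_cores + soc_cores;  // 16
  const uint32_t num_l2_only_cores = soc_cores;  // SoC cores have L2 but no L3 cache
  const bool is_intel_hybrid = true;             // assume the CPUID vendor/hybrid checks below pass
  const uint32_t intra_op_threads =
      is_intel_hybrid ? physical_cores - num_l2_only_cores  // 16 - 2 = 14
                      : physical_cores;                     // 16 on non-Intel or non-hybrid parts
  std::cout << "intra-op threads: " << intra_op_threads << "\n";
  return 0;
}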
auto cores = GetNumberOPhysicalAndEngineeringCores(); - // We want to use the number of physical cores, but exclude soc cores - return cores.PhysicalCores - cores.SocDieCores; + + const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + int regs_leaf0[4]; + int regs_leaf7[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf7, 0x7); + + auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && + (kVendorID_Intel[2] == regs_leaf0[3]); + + auto isHybrid = (regs_leaf7[3] & (1 << 15)); + + if (isIntel && isHybrid) { + // We want to use the number of physical cores, but exclude soc cores + // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores + return cores.PhysicalCores - cores.Num2CacheCores; + } + + return cores.PhysicalCores; } } // namespace WINMLP From 3b03b2e046092522e84f0b9aebac1b394a3e4b13 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 14 Feb 2024 11:19:33 -0800 Subject: [PATCH 081/207] Upgrade default ORTModule opset from 15 to 17 (#19315) ### Description This PR upgrades ORTModule's default opset from 15 to 17. Opset 17 is the final opset supported by torchscript exporter (https://github.com/pytorch/pytorch/pull/107829) ### Motivation and Context Engineering excellence contribution for ORT Training DRI. --------- Co-authored-by: Prathik Rao --- .../python/training/ortmodule/__init__.py | 2 +- .../ortmodule/_custom_op_symbolic_registry.py | 24 +++++++++++++++++++ .../test/optimizer/compute_optimizer_test.cc | 8 +++---- .../test/optimizer/graph_transform_test.cc | 6 ++--- .../test/optimizer/shape_optimizer_test.cc | 20 ++++++++-------- .../python/orttraining_test_ortmodule_api.py | 4 ++-- .../orttraining_test_ortmodule_onnx_ops.py | 2 +- ...orttraining-py-packaging-pipeline-cuda.yml | 2 +- ...ttraining-py-packaging-pipeline-cuda12.yml | 2 +- .../docker/Dockerfile.manylinux2_28_rocm | 2 +- ...Dockerfile.manylinux2_28_training_cuda11_8 | 2 +- ...Dockerfile.manylinux2_28_training_cuda12_2 | 2 +- .../pai/rocm-ci-pipeline-env.Dockerfile | 2 +- 13 files changed, 51 insertions(+), 27 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index fbf1b7c2bac42..4a03465cf2ead 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -39,7 +39,7 @@ def _defined_from_envvar(name, default_value, warn=True): # NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and # assign them new values. Importing them directly do not propagate changes. 
################################################################################ -ONNX_OPSET_VERSION = 15 +ONNX_OPSET_VERSION = 17 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1" ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions") _FALLBACK_INIT_EXCEPTION = None diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 9288027f0188c..f81aef5f6b9c4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -821,3 +821,27 @@ def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): operator_s="upsample_bicubic2d", overload_name_s="vec", ) + + +@register_symbolic("layer_norm") +@parse_args("v", "is", "v", "v", "f", "none") +def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + + res, new_running_mean, new_running_var = g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + outputs=3, # force all 3 outputs to be exported in training mode + operator_s="layer_norm", + overload_name_s="vec", + ) + + return res diff --git a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc index cf510ea43c89f..509937bdd0c3a 100644 --- a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc @@ -135,7 +135,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_Allowed) { } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -206,7 +206,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_LabelNameNotMat } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -277,7 +277,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_ReduceNone) { } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -344,7 +344,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_NoIgnoreIndex) } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index b774fec11cc8d..bab7c09839273 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1523,7 +1523,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs) { builder.AddNode("Identity", {add2_out}, {graph_out}); }; - const std::vector opsets{12, 13, 14, 15}; + const 
std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), @@ -1616,7 +1616,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs_LastAddNotHaveScaleI builder.AddNode("Identity", {add2_out}, {graph_out}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), @@ -1710,7 +1710,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionTwoInputs) { builder.AddNode("Identity", {add1_out}, {graph_output2}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), diff --git a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc index ea05b29c8668b..a1629eb73eeb6 100644 --- a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc @@ -67,7 +67,7 @@ TEST(ShapeOptimizerTests, Shape15CannotFold) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -145,7 +145,7 @@ TEST(ShapeOptimizerTests, Shape15) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -218,7 +218,7 @@ TEST(ShapeOptimizerTests, Shape15TakesGraphInput) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -289,7 +289,7 @@ TEST(ShapeOptimizerTests, Shape15GeneratesGraphOutput) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -366,7 +366,7 @@ TEST(ShapeOptimizerTests, Slice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -446,7 +446,7 @@ TEST(ShapeOptimizerTests, SliceGeneratesGraphOutput) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -530,7 +530,7 @@ TEST(ShapeOptimizerTests, Gather) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& 
builder) { std::vector> shape_input_shape; @@ -639,7 +639,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedBySlice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> dropout_input_shape; @@ -810,7 +810,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedByGatherSlice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> reshape_input_shape; @@ -976,7 +976,7 @@ TEST(ShapeOptimizerTests, SymbolicDimUsedByGather_ConcreteDimUsedByGather) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> reshape_input_shape; diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 51aa1564cbfbe..365c2bb8ebe0e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -34,7 +34,7 @@ from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck -DEFAULT_OPSET = 15 +DEFAULT_OPSET = 17 # PyTorch model definitions for tests @@ -5280,7 +5280,7 @@ def run_step(model, x): assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13 -@pytest.mark.parametrize("opset_version", [12, 13, 14, 15]) +@pytest.mark.parametrize("opset_version", [12, 13, 14, 15, 17]) def test_opset_version_change(opset_version): original_env = None if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 4f0925c5c855b..2f240406b25b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -79,7 +79,7 @@ def run_step(model, x): for onnx_model in [onnx_graph_inf, onnx_graph_train]: for oimp in onnx_model.opset_import: if oimp.domain == "": - self.assertEqual(oimp.version, 15) + self.assertEqual(oimp.version, 17) # Needs to match latest default ORTModule opset if op_grad_type is not None: if isinstance(op_grad_type, tuple): text = str(onnx_graph_train) diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml index d9ab85ee80ce3..47b1e0933417e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '2.0.0' - opset_version: '15' + opset_version: '17' cuda_version: '11.8' cmake_cuda_architectures: 60;61;70;75;80;86 docker_file: Dockerfile.manylinux2_28_training_cuda11_8 diff --git 
a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml index 422fb33eec5de..86dce7ae465fc 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '2.1.0' - opset_version: '15' + opset_version: '17' cuda_version: '12.2' cmake_cuda_architectures: 70;75;80;86;90 docker_file: Dockerfile.manylinux2_28_training_cuda12_2 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index dd7c669c37885..e1914d5fe2f06 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -178,7 +178,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.8 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 index a6a75afb0f4c3..fed29689fbe5e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 @@ -161,7 +161,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=2.0.0 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 index d29157daef611..e1caa141ef317 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 @@ -161,7 +161,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=2.1.0 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 4767c74afd28f..64710a982a29d 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -131,7 +131,7 @@ RUN pip install \ # Install migraphx RUN apt update && apt install -y migraphx -ENV ORTMODULE_ONNX_OPSET_VERSION=15 +ENV ORTMODULE_ONNX_OPSET_VERSION=17 ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev From 944d8f85135e0caf836ae7f6ad1bfac8dcba2f21 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Wed, 14 Feb 2024 12:49:34 -0800 Subject: [PATCH 082/207] Update the default std flag used during torch extensions compilation (#19516) --- .../torch_cpp_extensions/cpu/torch_interop_utils/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index fa72f3b134917..898c242bb3c32 100644 --- 
a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -23,7 +23,7 @@ cur_file_dir, ] -extra_compile_args = {"cxx": ["-O3"]} +extra_compile_args = {"cxx": ["-O3", "-std=c++17"]} setup( name="torch_interop_utils", ext_modules=[ From 4e5119760d8cf1c2e751f4264f23ab3e5a25aebc Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 15 Feb 2024 08:46:03 +1000 Subject: [PATCH 083/207] Add initial support for CoreML ML Program to the CoreML EP. (#19347) ### Description Adds infrastructure to create an ML Package containing the Model using ML Program. Updated coremltools files to v7.1 to bring in new protobuf definitions along with the tools to write the weight.bin file and create an ML Package correctly. Enables building a CoreML Model on all platforms which means all the operator builder code can be debugged anywhere. Execution of the generated CoreML model is obviously limited to Apple platforms. The Conv operator builder has been updated to be able to generate an ML Program Operation. ### Motivation and Context NeuralNetwork is no longer being developed and ML Program is the replacement going forward. --- cmake/onnxruntime_providers.cmake | 6 +- cmake/onnxruntime_providers_coreml.cmake | 127 +++- cmake/onnxruntime_unittests.cmake | 18 +- .../coreml/coreml_provider_factory.h | 5 +- .../include/ort_coreml_execution_provider.h | 11 + objectivec/ort_coreml_execution_provider.mm | 5 +- .../providers/coreml/builders/coreml_spec.h | 24 +- .../core/providers/coreml/builders/helper.cc | 48 +- .../core/providers/coreml/builders/helper.h | 10 +- .../coreml/builders/impl/LRN_op_builder.cc | 22 +- .../builders/impl/activation_op_builder.cc | 28 +- .../coreml/builders/impl/argmax_op_builder.cc | 20 +- .../coreml/builders/impl/base_op_builder.cc | 96 +-- .../coreml/builders/impl/base_op_builder.h | 51 +- .../builders/impl/batch_norm_op_builder.cc | 23 +- .../coreml/builders/impl/binary_op_builder.cc | 47 +- .../coreml/builders/impl/builder_utils.cc | 184 ++++- .../coreml/builders/impl/builder_utils.h | 102 ++- .../coreml/builders/impl/cast_op_builder.cc | 25 +- .../coreml/builders/impl/clip_op_builder.cc | 31 +- .../coreml/builders/impl/concat_op_builder.cc | 17 +- .../coreml/builders/impl/conv_op_builder.cc | 435 +++++++---- .../builders/impl/depthtospace_op_builder.cc | 18 +- .../builders/impl/flatten_op_builder.cc | 22 +- .../coreml/builders/impl/gather_op_builder.cc | 27 +- .../coreml/builders/impl/gemm_op_builder.cc | 22 +- .../coreml/builders/impl/pad_op_builder.cc | 20 +- .../coreml/builders/impl/pool_op_builder.cc | 19 +- .../builders/impl/reduction_op_builder.cc | 24 +- .../builders/impl/reshape_op_builder.cc | 20 +- .../coreml/builders/impl/resize_op_builder.cc | 23 +- .../coreml/builders/impl/shape_op_builder.cc | 20 +- .../coreml/builders/impl/slice_op_builder.cc | 29 +- .../builders/impl/softmax_op_builder.cc | 34 +- .../coreml/builders/impl/split_op_builder.cc | 36 +- .../builders/impl/squeeze_op_builder.cc | 44 +- .../builders/impl/transpose_op_builder.cc | 15 +- .../coreml/builders/impl/unary_op_builder.cc | 18 +- .../coreml/builders/model_builder.cc | 717 ++++++++++++++++-- .../providers/coreml/builders/model_builder.h | 195 ++++- .../providers/coreml/builders/op_builder.h | 21 +- .../coreml/builders/op_builder_factory.h | 2 +- .../coreml/coreml_execution_provider.cc | 74 +- .../coreml/coreml_execution_provider.h | 15 +- 
.../core/providers/coreml/model/host_utils.h | 49 +- .../core/providers/coreml/model/host_utils.mm | 28 +- .../providers/coreml/model/host_utils_stub.cc | 40 + .../core/providers/coreml/model/model.h | 48 +- .../core/providers/coreml/model/model.mm | 46 +- .../core/providers/coreml/model/model_stub.cc | 38 + .../builders/impl/split_op_builder.cc | 4 +- .../core/providers/shared/utils/utils.cc | 136 ++-- .../core/providers/shared/utils/utils.h | 19 +- onnxruntime/test/util/default_providers.cc | 10 +- 54 files changed, 2131 insertions(+), 1037 deletions(-) create mode 100644 onnxruntime/core/providers/coreml/model/host_utils_stub.cc create mode 100644 onnxruntime/core/providers/coreml/model/model_stub.cc diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index c6c9d8f4894c5..7e7819ac31a19 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -66,11 +66,7 @@ if(onnxruntime_USE_CUDA) set(PROVIDERS_CUDA onnxruntime_providers_cuda) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) - else() - set(PROVIDERS_COREML onnxruntime_providers_coreml) - endif() + set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_NNAPI_BUILTIN) set(PROVIDERS_NNAPI onnxruntime_providers_nnapi) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 2ca4a22aca7d2..c9f35e5337f9b 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -7,6 +7,27 @@ endif() add_compile_definitions(USE_COREML=1) +# Check if we can build the coremltools code for creating an mlpackage with an mlprogram. +# The coremltools source requires std::filesystem::path which is only available from iOS 13 on. +set(_enable_ML_PROGRAM ON) +if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0) + message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.") + set(_enable_ML_PROGRAM OFF) +elseif(LINUX) + # uuid-dev is required. we don't bother installing on CIs as it's really for manual developer testing. + find_library(LibUUID_LIBRARY NAMES uuid) + find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h) + if (NOT LibUUID_INCLUDE_DIR) + message(STATUS "uuid/uuid.h was not found as is required for ML Program support. " + "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. 
") + set(_enable_ML_PROGRAM OFF) + endif() +endif() + +if (_enable_ML_PROGRAM) + add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) +endif() + # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") @@ -19,8 +40,8 @@ target_compile_definitions(coreml_proto PUBLIC $) set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") -set(_src_sub_dir "coreml_proto/") +set(_src_sub_dir "coreml_proto/") onnxruntime_protobuf_generate( APPEND_PATH GEN_SRC_SUB_DIR ${_src_sub_dir} @@ -55,6 +76,10 @@ file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" ) +file(GLOB onnxruntime_providers_coreml_public_headers CONFIGURE_DEPENDS + "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml/*.h" +) + file(GLOB onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" @@ -67,15 +92,38 @@ file(GLOB_RECURSE "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" ) -if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" + +if(_enable_ML_PROGRAM) + # Add helpers to create mlpackage weights. limit to just the files we need to minimize the changes to make them + # build on Windows and Linux. + file(GLOB + onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Util/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/BlobDataType.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp" + ) + + # Add helpers to create mlpackage + file(GLOB + onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/Utils/JsonMap.?pp" ) + + set(coremltools_srcs + ${onnxruntime_providers_coreml_milblob_cc_srcs} + ${onnxruntime_providers_coreml_modelpackage_cc_srcs} + ) + + source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) endif() # Add CoreML objective c++ source code -if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +if (APPLE) file(GLOB onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" @@ -83,26 +131,79 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ) +else() + # add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies + # by using stub implementations on non-Apple platforms. 
+ file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils_stub.cc" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model_stub.cc" + ) endif() set(onnxruntime_providers_coreml_cc_srcs ${onnxruntime_providers_coreml_cc_srcs_top} ${onnxruntime_providers_coreml_cc_srcs_nested} ${onnxruntime_providers_shared_utils_cc_srcs} + ${onnxruntime_providers_coreml_objcc_srcs} ) -source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_INCLUDE_DIR} FILES ${onnxruntime_providers_coreml_public_headers}) + onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} + ${onnxruntime_providers_coreml_public_headers} + ${onnxruntime_providers_coreml_cc_srcs} + ${coremltools_srcs} ) + onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 + safeint_interface ) -if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml coreml_proto) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) +target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto) +add_dependencies(onnxruntime_providers_coreml coreml_proto) + +if (APPLE) + target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__) endif() + +if (_enable_ML_PROGRAM) + # Setup coremltools fp16 and json dependencies for creating an mlpackage. + # + # These are also used by external/xnnpack.cmake. 
fp16 depends on psimd + FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) + onnxruntime_fetchcontent_makeavailable(psimd) + set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) + FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) + set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") + set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + onnxruntime_fetchcontent_makeavailable(fp16) + + # need to tweak the include paths to match what the coreml source code expects + target_include_directories(onnxruntime_providers_coreml PRIVATE + ${fp16_SOURCE_DIR}/include + ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann + ${coremltools_SOURCE_DIR} + ${coremltools_SOURCE_DIR}/mlmodel/src/ + ${coremltools_SOURCE_DIR}/modelpackage/src/ + ) + + add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) + + if (LINUX) + target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) + endif() +endif() + +if (APPLE) + target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML") +endif() + add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 308caad296831..3ed695327c183 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -567,11 +567,7 @@ if(onnxruntime_USE_ROCM) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) - else() - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_ACL) @@ -676,15 +672,9 @@ endif() if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) - else() - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_XNNPACK) diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 03715eb5b78b2..55abb90b981f5 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -28,9 +28,12 @@ enum COREMLFlags { // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. 
COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, + // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + COREML_FLAG_CREATE_MLPROGRAM = 0x010, + // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it - COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES, + COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, }; #ifdef __cplusplus diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index a015b6fd60c8f..6ff18176ebeb2 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -41,6 +41,17 @@ NS_ASSUME_NONNULL_BEGIN */ @property BOOL onlyEnableForDevicesWithANE; +/** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with + * dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. + */ +@property BOOL onlyAllowStaticInputShapes; + +/** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + */ +@property BOOL createMLProgram; + @end @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP) diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 6340fdea1c3a7..58b47d68eea63 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -26,7 +26,10 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti const uint32_t flags = (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | - (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0); + (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | + (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | + (options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML( [self CXXAPIOrtSessionOptions], flags)); return YES; diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index e9cd4af94e5fd..c9adba9e579d0 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -3,12 +3,28 @@ #pragma once -// TODO come up with a more intuitive way of limiting this to Apple platform builds -// E.g., putting CoreML EP files that should be enabled iff `defined(__APPLE__)` in a separate directory. -#if !defined(__APPLE__) -#error "This file should only be included when building on Apple platforms." +#include "onnxruntime_config.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push + +// Disable warning from protobuf code. 
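A hedged usage sketch for the new flag (an editor's example, not part of the patch): the flag is passed through the existing C API entry point OrtSessionOptionsAppendExecutionProvider_CoreML, mirroring the Objective-C wrapper change above. The include paths and flag value are illustrative, and the ML Program path requires Core ML 5 or later per the comment on the flag.

#include <onnxruntime_cxx_api.h>
#include <coreml_provider_factory.h>  // header name as shipped in the ORT package; path may vary by build

Ort::SessionOptions MakeCoreMLSessionOptions() {
  Ort::SessionOptions session_options;
  // Request an ML Program model instead of the default NeuralNetwork model.
  const uint32_t coreml_flags = COREML_FLAG_CREATE_MLPROGRAM;
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags));
  return session_options;
}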
+// +// In file included from coreml_proto/Model.pb.h:30: +// In file included from _deps/protobuf-src/src/google/protobuf/extension_set.h:53: +// _deps/protobuf-src/src/google/protobuf/parse_context.h:328:47: +// error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] +#ifdef HAS_SHORTEN_64_TO_32 +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif #endif +// Model.pb.h is generated in the build output directory from the CoreML protobuf files in +// onnxruntime/core/providers/coreml/coremltools/mlmodel/format #include "coreml_proto/Model.pb.h" +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index 897856256cc79..bc3ba4432e66d 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -22,22 +22,35 @@ namespace onnxruntime { namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags) { +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags) { return OpBuilderInputParams{graph_viewer, - (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0}; + coreml_version, + (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0, + (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0}; } -bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { +const IOpBuilder* GetOpBuilder(const Node& node) { const auto& op_builders = GetOpBuilders(); - if (Contains(op_builders, node.OpType())) { - const auto* op_builder = op_builders.at(node.OpType()); + const auto it = op_builders.find(node.OpType()); + if (it != op_builders.cend()) { + return it->second; + } + + return nullptr; +} + +bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { + const auto* op_builder = GetOpBuilder(node); + if (op_builder) { return op_builder->IsOpSupported(node, input_params, logger); } else { return false; } } -bool IsInputSupported(const NodeArg& input, const std::string& parent_name, +bool IsInputSupported(const Node& node, const NodeArg& input, const OpBuilderInputParams& input_params, const logging::Logger& logger) { if (!input.Exists()) { // optional input that is not provided @@ -48,8 +61,8 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, std::vector shape; // We do not support input with no shape if (!GetShape(input, shape, logger)) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name - << "] has no shape"; + LOGS(logger, VERBOSE) << MakeString("Input [", input_name, "] of Node [", node.Name(), "] type [", node.OpType(), + "] has no shape"); return false; } @@ -63,11 +76,19 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, // For some undocumented reason, Apple CoreML framework will fail loading the model if the model // input has dimension > 16384 // See this issue, https://github.com/apple/coremltools/issues/1003 + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf has maximum texture widths which may be the + // root cause. if (dim > 16384) { LOGS(logger, WARNING) << "CoreML does not support input dim > 16384. 
Input:" << input_name << ", shape: " << Shape2String(shape); return false; } + + if (dim == 0) { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; + } } // Limit input shape rank to 5. @@ -87,13 +108,6 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const logging::Logger& logger) { std::unordered_set supported_nodes{}; -#ifdef __APPLE__ - if (!util::HasRequiredBaseOS()) { - LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS"; - return supported_nodes; - } -#endif - for (const auto& node : graph_viewer.Nodes()) { const bool supported = IsNodeSupported(node, input_params, logger); LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() @@ -149,7 +163,9 @@ bool HasNeuralEngine(const logging::Logger& logger) { #else // In this case, we are running the EP on non-apple platform, which means we are running the model // conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine - LOGS(logger, VERBOSE) << "HasNeuralEngine running on non-Apple hardware for model conversion only"; + LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. " + "Returning true to enable model conversion and local testing of CoreML EP implementation. " + "No CoreML model will be compiled or run."; has_neural_engine = true; #endif // #ifdef __APPLE__ diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index d8b27ac76ae73..300de2dedd122 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -23,10 +23,14 @@ class Logger; namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags); +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags); -bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, - const OpBuilderInputParams& input_params, const logging::Logger& logger); +const IOpBuilder* GetOpBuilder(const Node& node); + +bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params, + const logging::Logger& logger); bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc index 53f18b205880c..e9e520156576e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class LRNOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator 
support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_lrn = layer->mutable_lrn(); @@ -56,9 +43,6 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 88d6616b4e097..dee87ce3632a8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -2,44 +2,32 @@ // Licensed under the MIT License. #include "core/common/narrow.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ActivationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& node) const override; }; -// Add operator related - -#ifdef __APPLE__ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -86,7 +74,7 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type(node.OpType()); if (op_type == "Sigmoid") { @@ -115,14 +103,10 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related namespace { // assumes that node.OpType() == "PRelu" 
-bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) { +bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); // X input rank must be 3 or 4 diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index 7a5d4a5af673b..e9a8176c8349b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -1,37 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ArgMaxOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); @@ -67,9 +56,6 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 25d5bad14ceb6..2570e6d88ae0d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,21 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { -// Shared functions - +namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, const logging::Logger& logger) { @@ -37,93 +34,78 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node return false; } +} // namespace -// Add operator related -#ifdef __APPLE__ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - ORT_RETURN_IF_NOT( - IsOpSupported(node, input_params, logger), - "Unsupported operator ", - node.OpType()); - - ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); - LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() - << "] type: [" << node.OpType() << "] was added"; - return Status::OK(); -} + Status status = AddToModelBuilderImpl(model_builder, node, logger); -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(ModelBuilder& model_builder, const Node& node) { - auto layer_name = node.Name(); - if (layer_name.empty()) { - // CoreML requires layer has a name, while the node name is optional in ONNX - // In this case, create a unique name for the layer - layer_name = model_builder.GetUniqueName(MakeString("Node_", node.Index(), "_type_", node.OpType())); + if (status.IsOK()) { + LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; } - return CreateNNLayer(layer_name); -} -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(const std::string& layer_name) { - std::unique_ptr layer = std::make_unique(); - layer->set_name(layer_name); - return layer; + return status; } -#endif - -// Operator support related bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, input_params, logger)) + if (input_params.create_mlprogram && !SupportsMLProgram()) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support MLProgram"; return false; + } - // We do not support external initializers for now - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (HasExternalInitializer(initializers, node, logger)) + if (!HasSupportedOpSet(node, logger)) { + return false; + } + + if (!HasSupportedInputs(node, input_params, logger)) { return false; + } - if (!HasSupportedOpSet(node, logger)) + // We do not support external initializers for now + const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); + if (HasExternalInitializer(initializers, node, logger)) { return false; + } return IsOpSupportedImpl(node, input_params, logger); } bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if 
(!IsInputSupported(*input, node_name, input_params, logger)) { + if (!IsInputSupported(node, *input, input_params, logger)) { return false; } } - return HasSupportedInputsImpl(node, logger); + return HasSupportedInputsImpl(node, input_params, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - // We only check the type of input 0 by default - // specific op builder can override this +/* static */ +bool BaseOpBuilder::IsInput0Supported(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) { const auto& input = *node.InputDefs()[0]; - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; + int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + // currently only float is supported + if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; } return true; } -bool BaseOpBuilder::HasSupportedOpSet(const Node& node, - const logging::Logger& logger) const { +bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // We only check the type of input 0 by default + // specific op builder can override this + return IsInput0Supported(node, input_params, logger); +} + +bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { auto since_version = node.SinceVersion(); if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset [" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b4132d3b770ec..06c4dd94ea30d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -3,11 +3,9 @@ #pragma once -#include "core/providers/coreml/builders/op_builder.h" - -#ifdef __APPLE__ +#include "core/common/span_utils.h" #include "core/providers/coreml/builders/coreml_spec.h" -#endif +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { @@ -18,45 +16,40 @@ class BaseOpBuilder : public IOpBuilder { public: virtual ~BaseOpBuilder() = default; - // Add operator related + // does the operator implementation support creating an ML Program + bool SupportsMLProgram() const override { return false; } + + bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override final; -#ifdef __APPLE__ - public: - virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {} Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const override final; - protected: - virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const = 0; - - static std::unique_ptr - CreateNNLayer(ModelBuilder& model_builder, const Node& 
node); - - static std::unique_ptr CreateNNLayer(const std::string& layer_name); -#endif - - // Operator support related - public: - bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) const override final; + void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: - virtual bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const { + // check if the first input's data type is supported. + static bool IsInput0Supported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger); + + private: + virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& /*logger*/) const { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; + virtual bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const; - virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } + virtual int GetMinSupportedOpSet(const Node& /*node*/) const { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 20; } - private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const; + + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 391b02eaec497..8da58f659acf1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -5,30 +5,20 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class BatchNormalizationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -36,9 +26,6 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } }; -// Add operator related - -#ifdef __APPLE__ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // skip everything 
except input0 for BatchNormalization const auto& input_defs = node.InputDefs(); @@ -48,10 +35,9 @@ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_buil model_builder.AddInitializerToSkip(input_defs[4]->Name()); // var } -Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -81,9 +67,6 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 10c9b32d03f37..6074fba1433d9 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -1,35 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - -#include "base_op_builder.h" namespace onnxruntime { namespace coreml { - class BinaryOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related + int GetMinSupportedOpSet(const Node& node) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; }; -#ifdef __APPLE__ -static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { +namespace { +bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); const auto* x_shape_proto = input_defs[0]->Shape(); @@ -57,15 +50,14 @@ static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& y_shape_proto->dim().begin(), y_shape_proto->dim().end(), dim_eq); } - -// Add operator related +} // namespace Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == 
"Add") { // original mutable_add() has limited broadcasting support @@ -99,31 +91,28 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now return 7; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - bool is_pow = node.OpType() == "Pow"; - if (!is_pow) { - return BaseOpBuilder::HasSupportedInputsImpl(node, logger); +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (node.OpType() != "Pow") { + return IsInput0Supported(node, input_params, logger); } const auto& input_1 = *node.InputDefs()[0]; const auto& input_2 = *node.InputDefs()[1]; + // Pow we only support both inputs as fp32 for now int32_t input_type_1; - if (!GetType(input_1, input_type_1, logger)) - return false; - int32_t input_type_2; - if (!GetType(input_2, input_type_2, logger)) + if (!GetType(input_1, input_type_1, logger) || + !GetType(input_2, input_type_2, logger)) { return false; + } if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type" diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index ef66e6b877a1f..710f596b2a562 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -1,17 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __APPLE__ - #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/common/narrow.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" -#include "coreml_proto/NeuralNetwork.pb.h" +using namespace COREML_SPEC; namespace onnxruntime { namespace coreml { @@ -133,7 +132,182 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span> shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->size()); + for (const auto& dim : *shape) { + if (dim >= 0) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim)); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type, + const ONNX_NAMESPACE::TensorShapeProto* shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->dim_size()); + for (const auto& dim : shape->dim()) { + if (dim.has_dim_value()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value())); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +template +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + // need a 'false' that is dependent on the template types to make gcc happy and give a meaningful error message. 
+ static_assert(false_for_T && false_for_T, "Unsupported data type"); // add specializations below as needed +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_strings()->mutable_values()->Add(data.begin(), data.end()); +} + +// copy int64_t (used by ONNX for strides/indexes/etc.) to int32_t (used by CoreML) +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + auto& int32_out = *tensor_value.mutable_ints()->mutable_values(); + int32_out.Reserve(narrow(data.size())); + for (const int64_t v : data) { + int32_out.AddAlreadyReserved(narrow(v)); + } +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_bools()->mutable_values()->Add(data.begin(), data.end()); +} + +} // namespace + +MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) { + switch (static_cast(onnx_type)) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return MILSpec::DataType::FLOAT32; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return MILSpec::DataType::FLOAT64; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return MILSpec::DataType::BFLOAT16; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return MILSpec::DataType::FLOAT16; + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + return MILSpec::DataType::INT8; + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + return MILSpec::DataType::INT16; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return MILSpec::DataType::INT32; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return MILSpec::DataType::INT64; + + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + return MILSpec::DataType::UINT8; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + return MILSpec::DataType::UINT16; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + return MILSpec::DataType::UINT32; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + return MILSpec::DataType::UINT64; + + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + return MILSpec::DataType::BOOL; + case ONNX_NAMESPACE::TensorProto_DataType_STRING: + return MILSpec::DataType::STRING; + default: + ORT_THROW("Unsupported data type: ", onnx_type); + } +} + +template +MILSpec::Value CreateTensorValue(const gsl::span data, + std::optional> shape) { + MILSpec::Value value; + MILSpec::TensorType& tensor_type = *value.mutable_type()->mutable_tensortype(); + + if (shape) { + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), *shape); + } else { + // infer as 1D shape + std::vector coreml_shape{narrow(data.size())}; + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), coreml_shape); + } + + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyDataToTensorValue(tensor_value, data); + + return value; +} + +template +MILSpec::Value CreateScalarTensorValue(const T& data) { + gsl::span data_span{&data, 1}; + std::vector shape = {}; // empty for scalar + return CreateTensorValue(data_span, shape); +} + +// explicit specializations for types we handle so the implementation can be in the .cc file +template MILSpec::Value 
CreateTensorValue(gsl::span data, + std::optional> shape); + +template MILSpec::Value CreateScalarTensorValue(const float& data); +template MILSpec::Value CreateScalarTensorValue(const int32_t& data); +template MILSpec::Value CreateScalarTensorValue(const std::string& data); +template MILSpec::Value CreateScalarTensorValue(const bool& data); + +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) { + MILSpec::NamedValueType nvt; + nvt.set_name(node_arg.Name()); + MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()), + node_arg.Shape()); + + return nvt; +} + +void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) { + MILSpec::Argument arg; + arg.mutable_arguments()->Add()->set_name(std::string(value_name)); + + (*op.mutable_inputs())[input_name] = std::move(arg); +} + +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) { + auto& outputs = *op.mutable_outputs(); + auto& output_arg = *outputs.Add(); + output_arg.set_name(output.Name()); + + MILSpec::ValueType& value = *output_arg.mutable_type(); + MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), + output.Shape()); +} + } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 23b11928f7dc2..8126f0c126914 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -5,22 +5,19 @@ #pragma once -#ifdef __APPLE__ +#include #include "core/common/gsl.h" #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" -namespace CoreML { -namespace Specification { -class WeightParams; -} -} // namespace CoreML +#include "core/providers/coreml/builders/coreml_spec.h" namespace onnxruntime { -namespace coreml { +class NodeArg; +namespace coreml { // Try to see if we can map explicit padding to auto padding for Conv/Pool // Since usually use auto padding is more efficient Status HandleAutoPad(const std::vector input_shape, @@ -32,6 +29,10 @@ Status HandleAutoPad(const std::vector input_shape, AutoPadType auto_pad_type, AutoPadType& auto_pad_type_out); +// +// NeuralNetwork utils +// + // Copy an onnx initializer data to a coreml weight Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONNX_NAMESPACE::TensorProto& tensor); @@ -44,7 +45,90 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +// +// MLProgram utils +// + +// helper for static_assert where the value needs to be dependent on a template parameter +template +constexpr bool false_for_T = false; + +template +COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() { + if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT64; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BFLOAT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT16; + + } else if constexpr (std::is_same_v) { + return 
COREML_SPEC::MILSpec::DataType::INT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BOOL; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::STRING; + } else { + static_assert(false_for_T, "Unsupported type."); + } +} + +// The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value. +// Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally +COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type); + +/// +/// Create a CoreML MILSpec::TensorValue for the given input data. +/// +/// Original C++ data type +/// CoreML C++ data type +/// ONNX data +/// ONNX data shape. Inferred to be a 1D shape of `{data.size()}` if not specified. +/// TensorValue containing data. +template +COREML_SPEC::MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape = std::nullopt); + +template +COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data); + +/// Create a NamedValueType from an ONNX tensor NodeArg. +/// Used to create inputs for the 'main' function in an ML Program. +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg); + +/// +/// Add an input argument to a MILSpec::Operation +/// +/// Operation to update. +/// The input name defined by the spec for the operation. +/// The name of the value that is providing the input. +/// "https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html" +void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, + std::string_view input_name, std::string_view value_name); + +/// +/// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg. +/// +/// Operation to update. +/// NodeArg with details of output to add. +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 15ee1f0fc7284..70053c2c606a0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -1,34 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
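Before the Cast builder changes continue, a usage-oriented sketch of how the MLProgram helpers declared in builder_utils.h above are meant to compose. The ModelBuilder::CreateOperation/AddOperation calls follow the pattern the Conv builder uses later in this patch, and "relu" is only an example op type taken from the coremltools MIL op definitions, so read this as illustrative rather than as the provider's actual code:

    // Sketch: wiring a MILSpec "relu" operation for an ONNX node using the
    // AddOperationInput/AddOperationOutput helpers from builder_utils.h.
    Status AddReluOperation(ModelBuilder& model_builder, const Node& node) {
      auto op = model_builder.CreateOperation(node, "relu");

      // input 'x' is just the name of the ONNX input NodeArg
      AddOperationInput(*op, "x", node.InputDefs()[0]->Name());

      // scalar attributes become constant operations referenced by name, e.g.:
      // AddOperationInput(*op, "alpha", model_builder.AddScalarConstant(op->type(), "alpha", 0.01f));

      // output name, data type and shape are copied from the ONNX output NodeArg
      AddOperationOutput(*op, *node.OutputDefs()[0]);

      model_builder.AddOperation(std::move(op));
      return Status::OK();
    }

The key difference from the NeuralNetwork path is that attributes are not fields on a layer message; they are separate constant values that get referenced by name as operation inputs.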
-#include "core/providers/shared/utils/utils.h" #include "core/providers/coreml/builders/helper.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class CastOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; -}; -// Add operator related + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; -#ifdef __APPLE__ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, const Node& /* node */, const logging::Logger& /* logger */) const { @@ -37,9 +28,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. return Status::OK(); } -#endif - -// Operator support related bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -84,7 +72,8 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index a298a8d12c741..9aca172abec98 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -1,37 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class ClipOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // Both min and max values will be injected into the layer, no need to add to the model if (node.SinceVersion() >= 11) { @@ -58,7 +45,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (!has_min && !has_max) { // Clip without min/max is an identity node // In CoreML we don't have identity, use ActivationLinear instead - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); *layer->mutable_input()->Add() = input_name; *layer->mutable_output()->Add() = output_name; @@ -83,8 +70,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Handle clipping at min first if (has_min) { - const auto clip_min_layer_name = model_builder.GetUniqueName(MakeString(node_name, "_Clip_min")); - std::unique_ptr min_layer = CreateNNLayer(clip_min_layer_name); + std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); if (min == 0.0f) { // If min is 0. 
then this min will be handled by relu min_layer->mutable_activation()->mutable_relu(); } else { // otherwise, min will be handled by unary->threshold @@ -101,9 +87,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (has_max) { const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); { // Add threshold layer, which is actually max( -1 * min_output, -max) - const auto clip_max_threshold_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_threshold")); - auto threshold_layer = CreateNNLayer(clip_max_threshold_layer_name); + auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); threshold_layer->mutable_unary()->set_alpha(-max); threshold_layer->mutable_unary()->set_scale(-1.0f); threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); @@ -112,9 +96,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(threshold_layer)); } { // Add linear activation layer -1 * threshold_output - const auto clip_max_linear_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_linear")); - auto linear_layer = CreateNNLayer(clip_max_linear_layer_name); + auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); *linear_layer->mutable_input()->Add() = threshold_output_name; *linear_layer->mutable_output()->Add() = output_name; @@ -125,9 +107,6 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index b1e761024f5c9..34193318a0264 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -4,37 +4,26 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ConcatOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_concat()->set_sequenceconcat(false); @@ -48,9 +37,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, 
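The Clip lowering above leans on the identity min(x, max) = -max(-x, -max): the unary THRESHOLD layer computes max(scale * x, alpha) with scale = -1 and alpha = -max, and the following ActivationLinear layer multiplies by -1 to undo the negation. A small sketch of the equivalent arithmetic, using plain floats rather than CoreML layers:

    #include <algorithm>

    // Emulates the layer pair used for the max side of Clip above:
    //   threshold: t = max(-1.0f * x, -max_value)   (unary THRESHOLD, scale = -1, alpha = -max)
    //   linear:    y = -1.0f * t                    (ActivationLinear, alpha = -1)
    // which is exactly y = min(x, max_value).
    float ClipAtMax(float x, float max_value) {
      const float t = std::max(-1.0f * x, -max_value);
      return -1.0f * t;  // == std::min(x, max_value)
    }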
model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc index ff9dcbd9f8874..05e43dbbd16af 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc @@ -4,39 +4,35 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" -#include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" -#endif +#include "core/providers/shared/utils/utils.h" + +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { class ConvOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (model_builder.CreateMLProgram()) { + // we add the initializers as 'const' operations via ModelBuilder::RegisterInitializers + return; + } + const auto& input_defs = node.InputDefs(); // skip the weight and bias (if has it) for conv as we will directly set those as part of the NN layer @@ -49,136 +45,251 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); - const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); const auto& output_name = output_defs[0]->Name(); - const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - const bool is_1d_conv = (weight_shape.size() == 3); + // https://github.com/apple/coremltools/blob/7.1/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py - if (is_1d_conv) { - // weight_shape needs to be expanded from MXCXH->MXCXHx1 - weight_shape.push_back(1); - } + std::unique_ptr conv_op = model_builder.CreateOperation(node, "conv"); - NodeAttrHelper helper(node); - auto strides = 
helper.Get("strides", std::vector{1, 1}); - auto dilations = helper.Get("dilations", std::vector{1, 1}); - auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 - // to meet the required length 2 (for 2d conv it's normally 2) - // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. - if (is_1d_conv) { - if (strides.size() < 2) { - ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); - strides.push_back(1); + AddOperationInput(*conv_op, "x", input_name); + AddOperationInput(*conv_op, "weight", input_defs[1]->Name()); + + if (input_defs.size() > 2) { + AddOperationInput(*conv_op, "bias", input_defs[2]->Name()); } - if (dilations.size() < 2) { - ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); - dilations.push_back(1); + + // ONNX attributes. Add as inputs if specified/required + auto strides = helper.GetInt64s("strides"); + auto dilations = helper.GetInt64s("dilations"); + auto groups = helper.GetInt64("group"); + + // we know this input has a valid shape due to the check in IsOpSupportedImpl. ignore N and C dims. + const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; + const auto& op_type = conv_op->type(); + + if (strides) { + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", *strides)); + } else { + // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) + static const auto default_value = std::vector(num_spatial_dims, 1); + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", default_value)); } - if (onnx_pads.size() < 4) { - ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); - onnx_pads.insert(onnx_pads.begin() + 1, 0); - onnx_pads.push_back(0); + + if (dilations) { + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", *dilations)); + } else { + // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) + static const auto default_value = std::vector(num_spatial_dims, 1); + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", default_value)); } - } - const auto group = helper.Get("group", static_cast(1)); - - auto* coreml_conv = layer->mutable_convolution(); - - std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); - - if (is_1d_conv) { - const auto expand_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_expand")); - std::unique_ptr expand_layer = CreateNNLayer(expand_layer_name); - // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case - // we need to add an additional dimension here to the input to make it "2d Conv" like. 
- // NxCxH -> NxCxHx1 - expand_layer->mutable_expanddims()->add_axes(-1); - *expand_layer->mutable_input()->Add() = input_name; - *expand_layer->mutable_output()->Add() = expand_output_name; - model_builder.AddLayer(std::move(expand_layer)); - } - coreml_conv->set_outputchannels(weight_shape[0]); // M - coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group - coreml_conv->add_kernelsize(weight_shape[2]); // H - coreml_conv->add_kernelsize(weight_shape[3]); // W - coreml_conv->set_ngroups(group); - *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; - *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; - - coreml_conv->set_isdeconvolution(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - onnx_pads, strides, dilations, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_conv->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + + if (groups) { + AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); } - } else { - auto* padding_type = coreml_conv->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - width_border->set_endedgesize(onnx_pads[3]); + + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + + // pad type (string) + // valid - no pads (ONNX auto_pad VALID) + // custom - pads input (ONNX NOTSET) + // same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])` (assuming == ONNX SAME_UPPER) + // same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER) + // + // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value + // can be used. TBD if that provides any performance benefit with ML Program though as CoreML could + // potentially do that for us. + switch (auto_pad_type) { + case AutoPadType::NOTSET: { + // use `pads` attribute. + auto onnx_pads = helper.GetInt64s("pads"); // 'pads' must be provided if auto_pad is NOTSET + if (onnx_pads) { + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom"))); + + // need to re-order from x1_start, x2_start..., x1_end, x2_end... to + // x1_start, x1_end, x2_start, x2_end,... 
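The loop in the hunk below implements the re-ordering described in the comment above: ONNX stores pads as all begin values followed by all end values ({x1_begin, x2_begin, ..., x1_end, x2_end, ...}), while the CoreML conv 'pad' input expects them interleaved per spatial dimension. A standalone sketch with a worked example; the function name is illustrative:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Re-order ONNX pads {x1_begin, x2_begin, ..., x1_end, x2_end, ...}
    // into CoreML order {x1_begin, x1_end, x2_begin, x2_end, ...}.
    // e.g. 2D: {1, 2, 3, 4} -> {1, 3, 2, 4}
    std::vector<int64_t> ReorderOnnxPadsForCoreML(const std::vector<int64_t>& onnx_pads) {
      const size_t num_pads = onnx_pads.size();
      const size_t num_dims = num_pads / 2;
      std::vector<int64_t> reordered(num_pads, 0);
      for (size_t i = 0; i < num_pads; ++i) {
        const size_t cur_dim = i % num_dims;
        if (i < num_dims) {
          reordered[cur_dim * 2] = onnx_pads[i];      // begin value for this dim
        } else {
          reordered[cur_dim * 2 + 1] = onnx_pads[i];  // end value for this dim
        }
      }
      return reordered;
    }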
+ size_t num_pads = onnx_pads->size(); + size_t num_dims = num_pads / 2; + std::vector reordered_pads(num_pads, 0); + for (size_t i = 0; i < num_pads; ++i) { + auto cur_dim = i % num_dims; + if (i < num_dims) { // start values + reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; + } else { // end values + reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; + } + } + + AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); + + break; + } + + // in theory the pads may not be provided and in that case the default is no padding. + // as that is the same as 'valid', fall through + [[fallthrough]]; + } + case AutoPadType::VALID: + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + + break; + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: { + const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); + + // despite what the spec says, a 'pad' input seems to be required. + // https://github.com/apple/coremltools/issues/2127 + // provide the default value. passing in an empty vector also works. TBD what's better. + std::vector ignored_pads(num_spatial_dims * 2, 0); + AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); + + break; + } } - } - // Add weight - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + // set output + AddOperationOutput(*conv_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(conv_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto strides = helper.Get("strides", std::vector{1, 1}); + auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto group = helper.Get("group", static_cast(1)); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + + const bool is_1d_conv = (weight_shape.size() == 3); + + // add dummy 'W' dim with value of 1 so we can use 2D conv. + if (is_1d_conv) { + input_shape.push_back(1); + weight_shape.push_back(1); + + // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 + // to meet the required length 2 (for 2d conv it's normally 2) + if (strides.size() < 2) { + ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); + strides.push_back(1); + } + + if (dilations.size() < 2) { + ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); + dilations.push_back(1); + } + + // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. 
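For the 1D case described above, the pads expansion in the hunk below turns the length-2 ONNX value {h_begin, h_end} into the length-4 2D form {h_begin, w_begin, h_end, w_end} with zero padding on the dummy W dimension. A minimal sketch of the same operation; the function name is illustrative:

    #include <cstdint>
    #include <vector>

    // Expand 1D Conv pads {h_begin, h_end} to the 2D form {h_begin, 0, h_end, 0},
    // i.e. zero padding on the dummy W dim added for the 2D conv.
    std::vector<int64_t> ExpandConv1dPadsTo2d(std::vector<int64_t> onnx_pads) {
      // expects exactly 2 entries for a 1D Conv
      onnx_pads.insert(onnx_pads.begin() + 1, 0);  // w_begin = 0
      onnx_pads.push_back(0);                      // w_end = 0
      return onnx_pads;                            // {h_begin, 0, h_end, 0}
    }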
+ if (onnx_pads.size() < 4) { + ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); + onnx_pads.insert(onnx_pads.begin() + 1, 0); + onnx_pads.push_back(0); + } + } - // Add bias if present - if (input_defs.size() > 2) { - coreml_conv->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); - } + auto* coreml_conv = layer->mutable_convolution(); - if (is_1d_conv) { - std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); - *layer->mutable_input()->Add() = expand_output_name; - *layer->mutable_output()->Add() = conv_output_name; - model_builder.AddLayer(std::move(layer)); - - // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, - // we need to squeeze it back from NxCxHx1->NxCxH. - const auto squeeze_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_squeeze")); - std::unique_ptr squeeze_layer = CreateNNLayer(squeeze_layer_name); - squeeze_layer->mutable_squeeze()->add_axes(-1); - *squeeze_layer->mutable_input()->Add() = conv_output_name; - *squeeze_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(squeeze_layer)); - } else { - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(layer)); + std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); + + if (is_1d_conv) { + // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case + // we need to add an additional dimension here to the input to make it "2d Conv" like. 
+ // NxCxH -> NxCxHx1 + auto expand_layer = model_builder.CreateNNLayer(node, "_Conv_expand"); + expand_layer->mutable_expanddims()->add_axes(-1); + *expand_layer->mutable_input()->Add() = input_name; + *expand_layer->mutable_output()->Add() = expand_output_name; + model_builder.AddLayer(std::move(expand_layer)); + } + + coreml_conv->set_outputchannels(weight_shape[0]); // M + coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group + coreml_conv->add_kernelsize(weight_shape[2]); // H + coreml_conv->add_kernelsize(weight_shape[3]); // W + coreml_conv->set_ngroups(group); + *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; + *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; + + coreml_conv->set_isdeconvolution(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + onnx_pads, strides, dilations, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_conv->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_conv->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } + } + + // Add weight + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + + // Add bias if present + if (input_defs.size() > 2) { + coreml_conv->set_hasbias(true); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); + } + + if (is_1d_conv) { + std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); + *layer->mutable_input()->Add() = expand_output_name; + *layer->mutable_output()->Add() = conv_output_name; + model_builder.AddLayer(std::move(layer)); + + // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, + // we need to squeeze it back from NxCxHx1->NxCxH. 
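The ExpandDims/Squeeze pair described above is how a 1D convolution is run through CoreML's 2D-only conv layer: shape-wise the data flows NxCxH -> NxCxHx1 -> NxC'xH'x1 -> NxC'xH'. As a condensed view of the layer wiring that the surrounding hunks spread across the builder (the conv configuration itself is elided; CreateNNLayer, GetUniqueName and AddLayer are the ModelBuilder calls used in this patch):

    // Sketch: 1D Conv lowered as ExpandDims(-1) -> 2D Conv -> Squeeze(-1).
    // Intermediate tensor names must be unique within the model.
    const std::string expand_output = model_builder.GetUniqueName(node.Name() + "_expandDims");
    const std::string conv_output = model_builder.GetUniqueName(node.Name() + "_conv_output");

    auto expand = model_builder.CreateNNLayer(node, "_Conv_expand");
    expand->mutable_expanddims()->add_axes(-1);  // NxCxH -> NxCxHx1
    *expand->mutable_input()->Add() = input_name;
    *expand->mutable_output()->Add() = expand_output;
    model_builder.AddLayer(std::move(expand));

    // ... configure the 2D convolution layer reading expand_output, writing conv_output ...

    auto squeeze = model_builder.CreateNNLayer(node, "_Conv_squeeze");
    squeeze->mutable_squeeze()->add_axes(-1);    // NxC'xH'x1 -> NxC'xH'
    *squeeze->mutable_input()->Add() = conv_output;
    *squeeze->mutable_output()->Add() = output_name;
    model_builder.AddLayer(std::move(squeeze));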
+ auto squeeze_layer = model_builder.CreateNNLayer(node, "_Conv_squeeze"); + squeeze_layer->mutable_squeeze()->add_axes(-1); + *squeeze_layer->mutable_input()->Add() = conv_output_name; + *squeeze_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(squeeze_layer)); + } else { + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(layer)); + } } return Status::OK(); } -#endif - -// Operator support related bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -186,23 +297,73 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4 && tensor.dims().size() != 3) { - LOGS(logger, VERBOSE) << "Conv [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d and conv 1d are supported."; + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name, true); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + // ML Program supports non-const weight, 1D, 2D and 3D. + // keep to 1D and 2D for consistency with the NeuralNetwork implementation for now. + // add 3D support as/when needed. + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + if (!weight) { + LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be a constant initializer"; return false; } - } else { - LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be known"; + } + + // use the weight for the shape as it should always be known + const auto* weight_shape = input_defs[1]->Shape(); + int64_t num_dims = weight_shape ? weight_shape->dim_size() : -1; + + // ONNX spec requires N and C as first 2 dims + if (num_dims != 3 && num_dims != 4) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] is " << num_dims - 2 << "D. " + << "Only 1D and 2D Conv are supported currently."; return false; } - if (input_defs.size() > 2) { - const auto& bias_name = input_defs[2]->Name(); - if (!Contains(initializers, bias_name)) { - LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name(), true)) { + LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + return false; + } + + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + // spec says same_lower is supported in CoreML 5. it lies. CoreML 6 is required otherwise you get + // `Unexpected value for parameter pad_type[0] "same_lower" not in ("custom", "same", "valid").` + // We _could_ manually calculate the pads, but not implementing that until we have a real use case to justify + // the effort as it's not clear how common usage of same_lower is. + if (input_params.create_mlprogram && input_params.coreml_version < 6) { + if (StringToAutoPadType(helper.Get("auto_pad", "NOTSET")) == AutoPadType::SAME_LOWER) { + LOGS(logger, VERBOSE) << "Pad type of SAME_LOWER [" << name << "] is not supported until CoreML 6." 
+ << "Available version is CoreML " << input_params.coreml_version; + return false; + } + } +#endif + + // there's no equivalent to allow a manual kernel shape in CoreML. + // it's OK if a specified kernel_shape matches kH and kW dims of the weight input. + auto kernel_shape = helper.GetInt64s("kernel_shape"); + if (kernel_shape) { + bool valid = true; + if (static_cast(kernel_shape->size()) == num_dims - 2) { + for (int i = 0; i < num_dims - 2; ++i) { + // check the specified kernel shape matches the weight shape. skip the initial N and C dims in the latter. + if ((*kernel_shape)[i] != weight_shape->dim()[i + 2].dim_value()) { + valid = false; + break; + } + } + } else { + valid = false; + } + + if (!valid) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] kernel_shape attribute does not match the weight shape"; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index a4ad1c31b5027..1eba312b2577b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -4,37 +4,26 @@ #include "core/common/safeint.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class DepthToSpaceOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); @@ -54,9 +43,6 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc index b303fe7884cb1..f0adb70587bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include 
"core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class FlattenOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); // Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams auto* coreml_flatten = layer->mutable_flattento2d(); @@ -51,9 +38,6 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 9c7ec306ca093..7d32675e3e510 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -2,34 +2,24 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class GatherOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) namespace { int64_t GetAxisAttribute(const Node& node) { NodeAttrHelper node_attr_helper{node}; @@ -38,8 +28,8 @@ int64_t GetAxisAttribute(const Node& node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_gather()->set_axis(GetAxisAttribute(node)); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices @@ -47,10 +37,9 @@ Status 
GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 71b08db6d44d8..48f77354d7c30 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -7,38 +7,25 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class GemmOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; }; -// Add operator related - -#ifdef __APPLE__ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op = node.OpType(); const auto& input_defs(node.InputDefs()); @@ -71,7 +58,7 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -120,9 +107,6 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc index ba12600e8bc40..99d6f01cb8c5b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc @@ -7,30 +7,20 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include 
"core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PadOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -64,9 +54,6 @@ static InlinedVector GetPaddingAxesData(const InitializedTensorSet& ini return axes_tensor_data; } -// Add operator related - -#ifdef __APPLE__ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // pads model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // constant_value @@ -78,7 +65,7 @@ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pad = layer->mutable_padding(); auto* constant_padding_type = coreml_pad->mutable_constant(); // CoreML::Specification::PaddingLayerParams_PaddingConstant @@ -122,9 +109,6 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index fd1c77c851e6f..01aced739b36d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -4,38 +4,27 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PoolOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, 
const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pool = layer->mutable_pooling(); const auto& op_type = node.OpType(); @@ -108,9 +97,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index 6a2014e7952a2..32378b1f654d8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -1,36 +1,27 @@ // Copyright (c) Shukant Pal. // Licensed under the MIT License. +#include "core/optimizer/initializer.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/optimizer/initializer.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ReductionOpBuilder : public BaseOpBuilder { -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -#ifdef __APPLE__ namespace { template void AddReductionParams(T* params, const std::vector& axes, bool keepdims, bool noop_with_empty_axes) { @@ -76,7 +67,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const bool keepdims = helper.Get("keepdims", 1) != 0; const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "ReduceSum") { AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); @@ -93,7 +84,6 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -124,4 +114,4 @@ void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index 67aee73630cdb..7ae1746be3122 100644 --- 
a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -6,31 +6,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ReshapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -38,9 +28,6 @@ class ReshapeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } }; -// Add operator related - -#ifdef __APPLE__ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } @@ -48,7 +35,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -69,9 +56,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 5f963dc30dd8f..35dcde41a6bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -8,31 +8,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ResizeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, 
const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -41,7 +31,7 @@ class ResizeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } }; -// Helper functions +namespace { bool GetResizeScales(const InitializedTensorSet& initializers, const Node& node, std::vector& scales, const logging::Logger&) { @@ -73,10 +63,8 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, sizes = std::vector(sizes_data.begin(), sizes_data.end()); return true; } +} // namespace -// Add operator related - -#ifdef __APPLE__ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // We don't really use ROI here, so add it to skipped list if it's an initializer tensor model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI @@ -96,7 +84,7 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_upsample = layer->mutable_upsample(); NodeAttrHelper helper(node); @@ -131,9 +119,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index fd64153ffd283..a86e3d9538d87 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -2,44 +2,30 @@ // Licensed under the MIT License. 
#include "core/providers/coreml/builders/impl/base_op_builder.h" - +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" // for NodeAttrHelper -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class ShapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_getshape(); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { NodeAttrHelper node_attr_helper{node}; diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 2c250b3cc9f5a..b716af738e1b1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -1,39 +1,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/slice_helper.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class SliceOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: int GetMinSupportedOpSet(const Node& /* node */) const override { // Before Slice-10, some inputs were attributes instead. We don't support that for now. 
return 10; } - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params, const logging::Logger& logger) const override; }; @@ -107,9 +99,6 @@ bool ValidateSliceComputeMetadataForCoreML(const SliceOp::PrepareForComputeMetad } } // namespace -// Add operator related -#if defined(__APPLE__) - void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -132,7 +121,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR(PrepareSliceComputeMetadataFromConstantInitializers(node, model_builder.GetGraphViewer(), compute_metadata)); - auto layer = CreateNNLayer(model_builder, node); + auto layer = model_builder.CreateNNLayer(node); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); auto* slice_static = layer->mutable_slicestatic(); @@ -163,10 +152,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -#endif // defined(__APPLE__) - -// Operator support related -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index c454a2a779f6e..266396a0fe90e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -1,43 +1,29 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
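Editor's note: the Gather and Slice builders above pick up the updated HasSupportedInputsImpl signature, which now threads OpBuilderInputParams through alongside the logger. A hypothetical override sketching why that is useful — gating the accepted input types on whether an ML Program is being created. "MyOpBuilder" and the fp16 branch are illustrative assumptions, not something this patch adds; GetType and input_params.create_mlprogram are taken from the surrounding code.

bool MyOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params,
                                         const logging::Logger& logger) const {
  int32_t input_type;
  if (!GetType(*node.InputDefs()[0], input_type, logger))
    return false;

  if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
    return true;

  // with input_params available, a builder could accept more types when targeting ML Program
  return input_params.create_mlprogram &&
         input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
}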
-#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/coreml/shape_utils.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class SoftmaxOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_name = node.InputDefs()[0]->Name(); const auto& output_name = node.OutputDefs()[0]->Name(); @@ -68,9 +54,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); { // Add reshape layer - const auto softmax_reshape1_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1")); - auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; *reshape_layer->mutable_input()->Add() = input_name; *reshape_layer->mutable_output()->Add() = reshape1_output_name; @@ -86,9 +70,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } { // Add reshape back layer - const auto softmax_reshape2_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2")); - auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; *reshape_layer->mutable_input()->Add() = softmax_output_name; *reshape_layer->mutable_output()->Add() = output_name; @@ -99,10 +81,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 56c87c883156b..0497357c45c54 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -1,35 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class SplitOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -37,10 +26,6 @@ class SplitOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } }; -// Add operator related - -#ifdef __APPLE__ - void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -63,7 +48,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // attribute introduced since opset 18 uint64_t num_outputs; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_splitnd = layer->mutable_splitnd(); coreml_splitnd->set_axis(axis); @@ -82,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt("num_outputs").value()); + num_outputs = narrow(helper.GetInt64("num_outputs").value()); // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); @@ -111,10 +96,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -159,7 +140,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } } else { if (node.SinceVersion() >= 18) { - const auto num_outputs = helper.GetInt("num_outputs"); + const auto num_outputs = helper.GetInt64("num_outputs"); if (!num_outputs.has_value()) { LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; @@ -169,9 +150,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); return false; } - if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { - LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." 
- << "The value should be smaller or equal to the size of dimension being split. num_outputs: " + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || + num_outputs.value() > split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. num_outputs: " << num_outputs.value(); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index 2e14c85ce69c1..e9cc1c2dbf638 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -1,48 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include + +#include "core/common/safeint.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/optimizer/initializer.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/optimizer/initializer.h" namespace onnxruntime { namespace coreml { class SqueezeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ -void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); - } -} - -/* static */ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { +namespace { +Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { // Squeeze opset 13 use input as axes if (node.SinceVersion() > 12) { // If axes is not provided, return an empty axes as default to squeeze all @@ -62,11 +44,18 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const return Status::OK(); } +} // namespace + +void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + } +} Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_squeeze = layer->mutable_squeeze(); std::vector axes; @@ -84,9 +73,6 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool SqueezeOpBuilder::IsOpSupportedImpl(const 
Node& node, const OpBuilderInputParams& input_params, const logging::Logger& /*logger*/) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc index 7d5018a19f74c..f6a61d55a3d63 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc @@ -3,33 +3,23 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class TransposeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -// Add operator related - -#ifdef __APPLE__ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); NodeAttrHelper helper(node); std::vector perm = helper.Get("perm", std::vector()); @@ -51,7 +41,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 660755b43c043..3403378d59114 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,32 +3,25 @@ #include "core/providers/common.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -#ifdef __APPLE__ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Sqrt") { layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); @@ -45,9 +38,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related 
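Editor's note: the unary builder above follows the pattern every NeuralNetwork op builder in this change converges on — ask the ModelBuilder for a uniquely named layer, fill in the layer params and I/O names, then hand ownership back via AddLayer. A minimal sketch of that pattern; "SomeOpBuilder" and the ReLU activation chosen here are placeholders, not part of the patch.

Status SomeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                            const logging::Logger& /*logger*/) const {
  std::unique_ptr<COREML_SPEC::NeuralNetworkLayer> layer = model_builder.CreateNNLayer(node);

  layer->mutable_activation()->mutable_relu();  // select the CoreML layer params for the op being mapped

  *layer->mutable_input()->Add() = node.InputDefs()[0]->Name();
  *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name();

  model_builder.AddLayer(std::move(layer));  // ModelBuilder takes ownership and appends the layer
  return Status::OK();
}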
void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); @@ -55,4 +45,4 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9c8b7bce507e4..daab36f7b933d 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -2,56 +2,555 @@ // Licensed under the MIT License. #include -#include - -#include "model_builder.h" -#include "helper.h" -#include "op_builder_factory.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/platform/env.h" #include "core/providers/common.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" -#include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" +#if defined(COREML_ENABLE_MLPROGRAM) +// includes from coremltools-src in _deps +#include "modelpackage/src/ModelPackage.hpp" +#include "mlmodel/src/MILBlob/Blob/StorageWriter.hpp" +using MILBlob::Blob::StorageWriter; +#endif + +using namespace CoreML::Specification; + namespace onnxruntime { namespace coreml { -ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags) - : graph_viewer_(graph_viewer), - logger_(logger), - coreml_flags_(coreml_flags) { +namespace { +#if defined(COREML_ENABLE_MLPROGRAM) +// Should the initializer be written to file or kept as an immediate value +bool ShouldWriteInitializerToWeightsFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/load.py#L51-L57 + + bool use_weight_file = false; + + switch (tensor_proto.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + auto num_elements = TensorShape(utils::GetTensorShapeFromTensorProto(tensor_proto)).Size(); + use_weight_file = num_elements >= 10; + break; + } + default: + break; + } + + return use_weight_file; +} + +// copy from the ONNX TensorProto to a CoreML field. +// T1 is the source type. T2 is the target type. If the types differ, T1 must be smaller than T2. +// e.g. uint32_t data can be written to RepeatedField +template +void CopyRawDataToRepeatedField(const ONNX_NAMESPACE::TensorProto& tensor_proto, + google::protobuf::RepeatedField& repeated_field) { + const auto& raw_data = tensor_proto.raw_data(); + const T1* data = reinterpret_cast(raw_data.data()); + const T1* data_end = data + (raw_data.size() / sizeof(T1)); + if constexpr (sizeof(T1) == sizeof(T2)) { + repeated_field.Add(data, data_end); + } else { + static_assert(sizeof(T1) < sizeof(T2)); + // we need to iterate over the data and copy to the repeated field, converting to T2 as we go. 
+ repeated_field.Resize(data_end - data, T2(0)); + for (int i = 0; data != data_end; ++data, ++i) { + repeated_field[i] = static_cast(*data); + } + } +} + +// copy T data from the TensorProto.int32_t field to TensorValue.bytes +template +void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.int32_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// copy T data from the TensorProto.uint64_data field to TensorValue.bytes +template +void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.uint64_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const uint64_t* in = tensor_proto.uint64_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// NOTE: This supports all the ONNX data types. Weights in CoreML may not need all these +void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILSpec::TensorValue& tensor_value) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + // handling based on + // ONNX TensorProto field usage + // https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/onnx/onnx.proto#L544-L572 + // CoreMLTools conversion implementation that maps data types to fields + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L98 + // along with some special cased types that are stored in bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L23 + // IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32) + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_floats()->mutable_values()); + } else { + tensor_value.mutable_floats()->mutable_values()->CopyFrom(tensor_proto.float_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { + // from: double_data/raw, to: doubles + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_doubles()->mutable_values()); + } else { + tensor_value.mutable_doubles()->mutable_values()->CopyFrom(tensor_proto.double_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + // from: int32_data/raw, to: ints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + // from: int64_data/raw, to: longints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + + } else { + 
tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // iterate the int32_data, taking the 16-bits from each entry, and copying to the bytes. + // we use uint16_t as only the size of the data type matters + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy from int32_data to bytes. uint8_t for both as only the size of the data type matters when copying + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + // from: uint64_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy uint32_t values from TensorProto.uint64_data + CopyUInt64DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { + // from: uint64_data/raw, to: longints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + } else { + // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this + // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each + // individual value. + tensor_value.mutable_longints()->mutable_values()->CopyFrom( + reinterpret_cast&>(tensor_proto.uint64_data())); + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: { + // from: int32_data/raw, to: bools + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_bools()->mutable_values()); + } else { + const auto& int32s = tensor_proto.int32_data(); + auto& bools = *tensor_value.mutable_bools()->mutable_values(); + const int num_entries = int32s.size(); + bools.Reserve(num_entries); + const int32_t* in = int32s.data(); + for (int i = 0; i < num_entries; ++i) { + *bools.AddAlreadyReserved() = *in++; + } + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_STRING: { + // from: string_data (which is protobuf type bytes), to: strings (protobuf type string) + // due to the protobuf type mismatch we need to iterate and copy + auto& in = tensor_proto.string_data(); + auto& out = *tensor_value.mutable_strings()->mutable_values(); + out.Reserve(in.size()); + for (const auto& iter : in) { + *out.Add() = iter; + } + + break; + } + /* Not clear if there's an actual use-case for 16-bit int data currently, so leaving commented out + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + // from: int32_data/raw, to: ints + // WARNING: This may change to write to mutable_bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L113-L115 + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } */ + default: + 
ORT_THROW("AddTensorProtoDataToMILSpecTensorValue: Unsupported data type: ", data_type); + } +} + +template +uint64_t WriteRawDataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + MILBlob::Util::Span data(reinterpret_cast(tensor_proto.raw_data().data()), + tensor_proto.raw_data().size() / sizeof(T)); + return writer.WriteData(data); +} + +// Write T1 data from the TensorProto.int32_data field using StorageWriter. +// Currently int32_data can have any of these data types: +// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, +// FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ +// T1 provides the size of the ONNX data type. T2 is the CoreML type. +// The sizes and layout of T1 and T2 must match as we simply cast the bytes to T2. +template +uint64_t WriteFromInt32DataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + static_assert(sizeof(T1) == sizeof(T2), "Data sizes must match"); + + // need to copy to temporary data as we have to extract a subset of bytes from each int32_t entry. + // works better to extract the ONNX type first with static_cast, and reinterpret_cast to the CoreML type at the end. + std::vector values; + const int num_values = tensor_proto.int32_data_size(); + values.resize(num_values); // resize so we're not updating the length inside the copy loop + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_values; ++i) { + values[i] = static_cast(in[i]); + } + + MILBlob::Util::Span data(reinterpret_cast(values.data()), + num_values); + return writer.WriteData(data); +} + +// write the initializer to weight.bin and return the offset +// StorageWriter is currently limited to fp32, fp16, bfloat16, uint8/int8, uint16/int16. 
+// AFAIK we don't use bfloat16/int16/uint16 for weights in ONNX, so limit handling to fp32, fp16, uint8/int8 +uint64_t CopyOnnxTensorToCoreMLWeightsFile(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + uint64_t offset = 0; + + // See AddTensorProtoDataToMILSpecTensorValue for links to sources for info on where the different typed data is + // stored for ONNX and CoreML + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + MILBlob::Util::Span data(tensor_proto.float_data().data(), tensor_proto.float_data().size()); + offset = writer.WriteData(data); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + + break; + } + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + default: + ORT_THROW("AddWeightToFile: Unsupported data type: ", data_type); + } + + return offset; +} + +MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& weights_file_writer) { + MILSpec::Value value; + + // populate ValueType with tensor data type, dims and rank + MILSpec::ValueType& value_type = *value.mutable_type(); + MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype(); + tensor_type.set_datatype(OnnxDataTypeToMILSpec(tensor_proto.data_type())); + + tensor_type.set_rank(tensor_proto.dims().size()); + for (const auto& dim : tensor_proto.dims()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(dim); + } + + // add data to either weights.bin or as an immediate value + if (ShouldWriteInitializerToWeightsFile(tensor_proto)) { + uint64_t offset = CopyOnnxTensorToCoreMLWeightsFile(tensor_proto, weights_file_writer); + + auto* file_value = value.mutable_blobfilevalue(); + // Filename copied from + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L329 + file_value->set_filename("@model_path/weights/weight.bin"); + file_value->set_offset(offset); + } else { + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyOnnxTensorToCoreMLTensor(tensor_proto, tensor_value); + } + + return value; +} + +void CreateEmptyFile(const std::string& filename) { + std::ofstream file(filename, std::ofstream::out | std::ofstream::binary); + ORT_ENFORCE(file.is_open(), "Failed to open file ", filename); } -Status ModelBuilder::Initialize() { - coreml_model_ = std::make_unique(); - { // initialize CoreML model +#endif // defined(COREML_ENABLE_MLPROGRAM) + +std::string GetModelOutputPath(bool create_ml_program) { + // path 
is used to create the ML Package directory for ML Program, and for the model directly otherwise. + auto path = util::GetTemporaryFilePath(); + if (!create_ml_program) { + path += ".model.mlmodel"; + } + + return path; +} +} // namespace + +ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags) + : graph_viewer_(graph_viewer), + logger_(logger), + coreml_version_(coreml_version), + coreml_flags_(coreml_flags), + create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), + model_output_path_(GetModelOutputPath(create_ml_program_)), + coreml_model_(std::make_unique()) { + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + coreml_model_->set_specificationversion(CoreMLSpecVersion()); + MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram(); + MILSpec::Function& main = (*mlprogram.mutable_functions())["main"]; + + const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion()); + *main.mutable_opset() = coreml_opset; + mlprogram_main_ = &(*main.mutable_block_specializations())[coreml_opset]; + + // create the ModelPackage. this creates the output directory. + mlpackage_ = std::make_unique(model_output_path_, /* create */ true); + + // ModelPackage::addItem does a copy of the file. Due to this we 'add' an empty file first, + // and do the actual writes to the file created in the package. + // We can't use ModelPackage::createFile as we have to add a directory for the weights. + std::string tmp_dir = model_output_path_ + "/tmp"; + ORT_THROW_IF_ERROR(Env::Default().CreateFolder(ToPathString(tmp_dir))); + CreateEmptyFile(tmp_dir + "/weight.bin"); + + std::string weights_id = mlpackage_->addItem(tmp_dir, "weights", "com.microsoft.OnnxRuntime", + "CoreML Model Weights"); + auto weights_info = mlpackage_->findItem(weights_id); + weights_file_writer_ = std::make_unique(weights_info->path() + "/weight.bin"); +#else + // should never happen due to handling in coreml_execution_provider.cc + ORT_THROW("ML Program is not enabled in this build"); +#endif + } else { // We support CorelML Specification Version 4 (Core ML 3) coreml_model_->set_specificationversion(4); auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); + neural_network->set_arrayinputshapemapping( + CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); } +} - PreprocessInitializers(); - ORT_RETURN_IF_ERROR(RegisterInitializers()); - ORT_RETURN_IF_ERROR(RegisterModelInputs()); - ORT_RETURN_IF_ERROR(AddOperations()); - ORT_RETURN_IF_ERROR(RegisterModelOutputs()); +ModelBuilder::~ModelBuilder() = default; - return Status::OK(); +/* + * NeuralNetwork related helpers + */ +std::unique_ptr ModelBuilder::CreateNNLayer(const Node& node, std::string_view suffix) { + auto layer_name = GetUniqueName(node, suffix); + + std::unique_ptr layer = std::make_unique(); + layer->set_name(layer_name); + return layer; +} + +void ModelBuilder::AddLayer(std::unique_ptr layer) { + auto* neural_network = coreml_model_->mutable_neuralnetwork(); + neural_network->mutable_layers()->AddAllocated(layer.release()); } -/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { - const auto& op_builders = GetOpBuilders(); - const auto it = op_builders.find(node.OpType()); - if (it != op_builders.cend()) - return it->second; +#if 
defined(COREML_ENABLE_MLPROGRAM) + +/* + * ML Program related helpers + */ +std::unique_ptr ModelBuilder::CreateOperation(const Node& node, + std::string_view op_type, + std::string_view suffix) { + std::string operation_name = GetUniqueName(node, suffix); + + std::unique_ptr op = std::make_unique(); + op->set_type(std::string(op_type)); + (*op->mutable_attributes())["name"] = CreateScalarTensorValue(operation_name); + + return op; +} + +void ModelBuilder::AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer) { + MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(initializer, *weights_file_writer_); + AddConstantOperation(name, std::move(coreml_tensor)); +} + +void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { + // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic + MILSpec::Operation& const_op = *mlprogram_main_->mutable_operations()->Add(); + const_op.set_type("const"); + + MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add(); + output.set_name(std::string(name)); + *output.mutable_type() = coreml_tensor.type(); + + auto& attr_map = *const_op.mutable_attributes(); + attr_map["name"] = CreateScalarTensorValue(std::string(name)); + attr_map["val"] = std::move(coreml_tensor); +} + +// Add operation to the Block for the main function in the ML Program +void ModelBuilder::AddOperation(std::unique_ptr operation) { + mlprogram_main_->mutable_operations()->AddAllocated(operation.release()); +} + +std::string ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + MILSpec::Value&& input_value) { + auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type)); + AddConstantOperation(unique_value_name, std::move(input_value)); + return unique_value_name; +} + +template +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape) { + // add specialization below + static_assert(false_for_T, "Missing specialization for value type"); + return ""; // unreachable +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); // CoreML uses int32 + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} - return nullptr; +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } +#endif // defined(COREML_ENABLE_MLPROGRAM) + +/* + * General implementation + */ void ModelBuilder::PreprocessInitializers() { - // TODO: We should 
be using GetConstantInitializer not GetAllInitializedTensors in all places + // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places. + // non-constant initializers need to be passed in as model inputs in case they're overridden at runtime. const auto& initializers = graph_viewer_.GetAllInitializedTensors(); const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); @@ -64,6 +563,7 @@ void ModelBuilder::PreprocessInitializers() { initializer_usage_[input->Name()]++; } } + if (const auto* op_builder = GetOpBuilder(node)) { op_builder->AddInitializersToSkip(*this, node); } @@ -77,27 +577,34 @@ Status ModelBuilder::RegisterInitializers() { // skip initializer if there is no remaining usage auto usage_count = initializer_usage_[name]; - if (usage_count == 0) + if (usage_count == 0) { continue; + } - std::unique_ptr layer = std::make_unique(); - layer->set_name(GetUniqueName("initializer_" + name)); - - // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer - auto* constant_tensor = layer->mutable_loadconstantnd(); - const auto& shape = tensor.dims(); - if (shape.empty()) { - // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor - constant_tensor->mutable_shape()->Add(1); + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + AddConstant(name, tensor); +#endif } else { - std::transform(shape.cbegin(), shape.cend(), - google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), - [](int64_t dim) -> uint64_t { return SafeInt(dim); }); - } + std::unique_ptr layer = std::make_unique(); + layer->set_name(GetUniqueName("initializer_" + name)); + + // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer + auto* constant_tensor = layer->mutable_loadconstantnd(); + const auto& shape = tensor.dims(); + if (shape.empty()) { + // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor + constant_tensor->mutable_shape()->Add(1); + } else { + std::transform(shape.cbegin(), shape.cend(), + google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), + [](int64_t dim) -> uint64_t { return SafeInt(dim); }); + } - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); - *layer->mutable_output()->Add() = name; - AddLayer(std::move(layer)); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); + *layer->mutable_output()->Add() = name; + AddLayer(std::move(layer)); + } } return Status::OK(); @@ -179,15 +686,15 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i data_type = type_proto->tensor_type().elem_type(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::FLOAT32); + multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: // If we have an int64 input/output type, since COREML_SPEC:ArrayFeatureType does not support INT64 // we assign it to be INT32 here - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); if (!is_input) { // Record the output names and we need to change them back to Int64 when CoreML EP returns these values to ORT 
AddInt64Output(name); @@ -204,6 +711,19 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape}); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + MILSpec::Function& main = (*coreml_model_->mutable_mlprogram()->mutable_functions())["main"]; + if (is_input) { + // the model inputs need to be wired up as args to the 'main' function + main.mutable_inputs()->Add(CreateNamedTensorValueType(node_arg)); + } else { + // the model outputs need to be set as outputs of the Block for the 'main' function + *mlprogram_main_->mutable_outputs()->Add() = node_arg.Name(); + } + } +#endif // defined(COREML_ENABLE_MLPROGRAM) + return Status::OK(); } @@ -215,16 +735,16 @@ Status ModelBuilder::RegisterModelInputs() { return Status::OK(); } -Status ModelBuilder::AddOperations() { - const auto builder_params = MakeOpBuilderParams(graph_viewer_, coreml_flags_); - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - if (const auto* op_builder = GetOpBuilder(*node)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, builder_params, logger_)); +Status ModelBuilder::ProcessNodes() { + for (const auto node_idx : graph_viewer_.GetNodesInTopologicalOrder()) { + const auto& node = *graph_viewer_.GetNode(node_idx); + if (const auto* op_builder = GetOpBuilder(node)) { + ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node, logger_)); } else { + // This shouldn't happen as this is called from CoreMLExecutionProvider::Compile and should only be processing + // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability. 
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + "Node [", node.Name(), "], type [", node.OpType(), "] is not supported"); } } @@ -239,29 +759,72 @@ Status ModelBuilder::RegisterModelOutputs() { return Status::OK(); } -Status ModelBuilder::Compile(std::unique_ptr& model, const std::string& path) { - ORT_RETURN_IF_ERROR(SaveCoreMLModel(path)); - model.reset(new Model(path, logger_, coreml_flags_)); - model->SetScalarOutputs(std::move(scalar_outputs_)); - model->SetInt64Outputs(std::move(int64_outputs_)); - model->SetInputOutputInfo(std::move(input_output_info_)); - return model->LoadModel(); +Status ModelBuilder::CreateModel() { + PreprocessInitializers(); + + ORT_RETURN_IF_ERROR(RegisterInitializers()); + ORT_RETURN_IF_ERROR(RegisterModelInputs()); + ORT_RETURN_IF_ERROR(ProcessNodes()); + ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + + return Status::OK(); } -Status ModelBuilder::SaveCoreMLModel(const std::string& path) { - ORT_RETURN_IF_ERROR(Initialize()); - std::ofstream stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Save the CoreML model failed"); +Status ModelBuilder::SaveModel() { + std::string output_path = model_output_path_; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel"; + CreateEmptyFile(tmp_model_path); + + std::string model_id = mlpackage_->setRootModel(tmp_model_path, "model.mlmodel", "com.microsoft.OnnxRuntime", + "CoreML Model Specification"); + auto model_info = mlpackage_->findItem(model_id); + output_path = model_info->path(); + } +#endif - // TODO, Delete, debug only - if (const char* path = std::getenv("ORT_COREML_EP_CONVERTED_MODEL_PATH")) { - std::ofstream temp_stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&temp_stream), "Save the CoreML model failed"); + // scope this so the stream is closed and flushed by the ofstream dtor + { + LOGS(logger_, INFO) << "Writing CoreML Model to " << output_path; + std::ofstream stream(output_path, std::ofstream::out | std::ofstream::binary); + ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Saving the CoreML model failed. Path=", output_path); } +#if defined(COREML_ENABLE_MLPROGRAM) + // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program + // related types as well. 
+ mlprogram_main_ = nullptr; + mlpackage_.reset(); + weights_file_writer_.reset(); +#endif + return Status::OK(); } +Status ModelBuilder::LoadModel(std::unique_ptr& model) { + model = std::make_unique(model_output_path_, + std::move(input_output_info_), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + + return model->LoadModel(); // load using CoreML API, including compilation +} + +// static +Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::unique_ptr& model) { + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags); + + ORT_RETURN_IF_ERROR(builder.CreateModel()); + ORT_RETURN_IF_ERROR(builder.SaveModel()); + + return builder.LoadModel(model); +} + void ModelBuilder::AddScalarOutput(const std::string& output_name) { scalar_outputs_.insert(output_name); } @@ -270,11 +833,6 @@ void ModelBuilder::AddInt64Output(const std::string& output_name) { int64_outputs_.insert(output_name); } -void ModelBuilder::AddLayer(std::unique_ptr layer) { - auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->mutable_layers()->AddAllocated(layer.release()); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { // decrement usage count if this is a known initializer. // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names @@ -289,7 +847,7 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) { skipped_inputs_.insert(input_name); } -std::string ModelBuilder::GetUniqueName(const std::string& base_name) { +std::string ModelBuilder::GetUniqueName(std::string_view base_name) { std::string unique_name; do { std::ostringstream os; @@ -300,5 +858,12 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) { return unique_name; } +std::string ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { + if (node.Name().empty()) { + return GetUniqueName(MakeString("Node_", node.Index(), "_", node.OpType(), suffix)); + } else { + return GetUniqueName(node.Name() + std::string(suffix)); + } +} } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index af2d5437be8d1..961ba647257b5 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -3,57 +3,171 @@ #pragma once +#include "core/common/span_utils.h" #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/coreml/model/model.h" + +#if defined(COREML_ENABLE_MLPROGRAM) +// coremltools classes +namespace MPL { +class ModelPackage; +} + +namespace MILBlob { +namespace Blob { +class StorageWriter; +} +} // namespace MILBlob +#endif namespace onnxruntime { namespace coreml { class IOpBuilder; class Model; -struct OnnxTensorInfo; class ModelBuilder { + private: + ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags); + public: - ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags); - ~ModelBuilder() = default; + // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` + static Status Build(const GraphViewer& graph_viewer, const logging::Logger& 
logger, + int32_t coreml_version, uint32_t coreml_flags, + std::unique_ptr& model); - Status Compile(std::unique_ptr& model, const std::string& path); - Status SaveCoreMLModel(const std::string& path); + ~ModelBuilder(); - // Accessors for members const GraphViewer& GetGraphViewer() const { return graph_viewer_; } const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } - + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name) const { + return graph_viewer_.GetConstantInitializer(name, true); + } + + // Since CoreML 2 the spec version is +1 as CoreML 1.1 was spec version 2. + // We only support CoreML 3 and later so the spec version is always version + 1. + int32_t CoreMLVersion() const { return coreml_version_; } + int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; } + + // Returns true if we are creating an ML Program + bool CreateMLProgram() const { +#if defined(COREML_ENABLE_MLPROGRAM) + return create_ml_program_; +#else + return false; +#endif + } + + /* + * NeuralNetworkLayer helpers + */ + + // Create a NeuralNetwork layer using the node name and optional suffix for the name. + // If Node has no name a unique name will be generated from the node index and operator. + std::unique_ptr CreateNNLayer(const Node& node, std::string_view suffix = ""); + + // Add layer to the Core ML NeuralNetwork model void AddLayer(std::unique_ptr layer); - // The initializer will be processed separately, skip it as an initializer +#if defined(COREML_ENABLE_MLPROGRAM) + /* + * MLProgram helpers + */ + + // Create Operation, set type and the unique name attribute. + std::unique_ptr CreateOperation(const Node& node, std::string_view op_type, + std::string_view suffix = ""); + + // + // Helpers for adding attributes from ONNX nodes as inputs to an ML Program Operation + // + + /// + /// Add a value as a 'const' operation, generating a unique name for the value from op_type and value_type. + /// Use for values that were not initializers in the original ONNX model. e.g. attributes from ONNX nodes. + /// Add existing initializers using AddConstant with the TensorProto. + /// + /// e.g. adding the bias input of Gemm would have op_type='gemm' and value_type='bias'. + /// + /// Value type. + /// Typically MILSpec::Operation.type(). + /// Typically the input name of the operation that will consume the value. + /// Value to add. + /// Optional shape for the value. + /// If T is a primitive type `shape` is ignored and the value is treated as a scalar. + /// For a container type, if `shape` is not provided the shape is inferred to be 1-D of {value.size()}. + /// + /// Unique name generated for value. + template + std::string AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt) { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + // add specialization in AddConstantImpl for new types if needed + "AddConstant currently supports float, int64_t, std::string and bool."); + return AddConstantImpl(op_type, value_type, value, shape); + } + + template + std::string AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, + std::optional> shape = std::nullopt) { + return AddConstant(op_type, value_type, AsSpan(value), shape); + } + + /// + /// Add a scalar value as a 'const' operation. See AddConstant for details. 
+ /// + template + std::string AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { + return AddConstant(op_type, value_type, AsSpan({value}), AsSpan({})); + } + + /// + /// Add an existing a constant ONNX initializer to the ML Program as a 'const' operation + /// + /// Initializer name + /// Initializer data + void AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer); + + // add the operation to the main function + void AddOperation(std::unique_ptr operation); +#endif + + /* + * General helpers + */ + + // The initializer is processed separately (e.g. layout is transformed) by the operator builder, + // so we don't do a copy of the original initializer into the model. void AddInitializerToSkip(const std::string& tensor_name); // There are some input which will not be used, add it to a list which will not // be added to CoreML model, since CoreML does not like input unused void AddInputToSkip(const std::string& input_name); - std::string GetUniqueName(const std::string& base_name); + std::string GetUniqueName(std::string_view base_name); + std::string GetUniqueName(const Node& node, std::string_view suffix); private: - const GraphViewer& graph_viewer_; - const logging::Logger& logger_; - uint32_t coreml_flags_; - - std::unique_ptr coreml_model_; - std::unordered_set scalar_outputs_; - std::unordered_set int64_outputs_; - std::unordered_map input_output_info_; - - std::unordered_map initializer_usage_; - std::unordered_set skipped_inputs_; - - uint32_t name_token_{0}; - std::unordered_set unique_names_; - - // Convert the onnx model to CoreML::Specification::Model - Status Initialize(); +#if defined(COREML_ENABLE_MLPROGRAM) + template + std::string AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt); + + void AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); + std::string AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + COREML_SPEC::MILSpec::Value&& input_value); +#endif + + // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk. + // We then load it using CoreML in order compile it. + Status CreateModel(); + Status SaveModel(); + Status LoadModel(std::unique_ptr& model); // If a CoreML operation will use initializers directly, we will add the initializers to the skip list void PreprocessInitializers(); @@ -61,7 +175,7 @@ class ModelBuilder { // Copy and process all the initializers to CoreML model Status RegisterInitializers(); - Status AddOperations(); + Status ProcessNodes(); Status RegisterModelInputs(); Status RegisterModelOutputs(); Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input); @@ -72,7 +186,32 @@ class ModelBuilder { // Record the onnx int64 type output names void AddInt64Output(const std::string& output_name); - static const IOpBuilder* GetOpBuilder(const Node& node); + const GraphViewer& graph_viewer_; + const logging::Logger& logger_; + const int32_t coreml_version_; + const uint32_t coreml_flags_; + const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) + const std::string model_output_path_; // create_ml_program_ ? 
dir for mlpackage : filename for mlmodel + + std::unique_ptr coreml_model_; + std::unordered_set scalar_outputs_; + std::unordered_set int64_outputs_; + std::unordered_map input_output_info_; + + std::unordered_map initializer_usage_; + std::unordered_set skipped_inputs_; + + uint32_t name_token_{0}; + std::unordered_set unique_names_; + +#if defined(COREML_ENABLE_MLPROGRAM) + // mlprogram_main_ is the main block of the CoreML ML Program. + // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML'] + // entry we create. + COREML_SPEC::MILSpec::Block* mlprogram_main_{nullptr}; + std::unique_ptr mlpackage_; + std::unique_ptr weights_file_writer_; +#endif }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder.h b/onnxruntime/core/providers/coreml/builders/op_builder.h index 79de6438c9700..0bb7f280c33e6 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder.h @@ -11,36 +11,39 @@ namespace coreml { class ModelBuilder; struct OpBuilderInputParams { - OpBuilderInputParams(const GraphViewer& graph_viewer, bool only_allow_static_input_shapes) + OpBuilderInputParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + bool only_allow_static_input_shapes, + bool create_mlprogram) : graph_viewer(graph_viewer), - only_allow_static_input_shapes(only_allow_static_input_shapes) {} + coreml_version(coreml_version), + only_allow_static_input_shapes(only_allow_static_input_shapes), + create_mlprogram(create_mlprogram) {} const GraphViewer& graph_viewer; + const int32_t coreml_version; // required to determine which version of an operation can be used. const bool only_allow_static_input_shapes; + const bool create_mlprogram; // whether to create ML Program (Core ML 5+) or NeuralNetwork (Core ML 3+) }; class IOpBuilder { public: virtual ~IOpBuilder() = default; - // Add operator related -#ifdef __APPLE__ - public: // Check if the initializers of this operator need preprocess // which will not be copied virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0; // Add the operator to CoreML model virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; -#endif - // Operator support related - public: // Check if an operator is supported virtual bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; + + // Does the builder implementation support creating an ML Program? 
+ virtual bool SupportsMLProgram() const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index d72420bcfff88..6469b4cefa5ea 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -3,7 +3,7 @@ #pragma once -#include "op_builder.h" +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index c133f7b82aba4..8e718da07703c 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory.h" // defines flags #include +#include "core/common/logging/logging.h" #include "core/framework/compute_capability.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" @@ -12,12 +14,10 @@ #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_cxx_api.h" -#ifdef __APPLE__ #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" -#endif namespace onnxruntime { @@ -25,7 +25,24 @@ constexpr const char* COREML = "CoreML"; CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, - coreml_flags_(coreml_flags) { + coreml_flags_(coreml_flags), + coreml_version_(coreml::util::CoreMLVersion()) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { + LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION && + (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#else + if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#endif } CoreMLExecutionProvider::~CoreMLExecutionProvider() {} @@ -35,28 +52,34 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; - // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes - // TODO investigate whether we want to support subgraph using CoreML EP - if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { return result; } const auto& logger = *GetLogger(); + // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes + // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the + // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate. 
+ if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + return result; + } + const bool has_neural_engine = coreml::HasNeuralEngine(logger); if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { - LOGS(logger, VERBOSE) << "The current system does not have Apple Neural Engine"; + LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. CoreML EP will not be used."; return result; } - const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_flags_); + const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); - const auto gen_metadef_name = [&]() { - HashValue model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(COREML, "_", model_hash, "_", metadef_id); - }; + const auto gen_metadef_name = + [&]() { + HashValue model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(COREML, "_", model_hash, "_", metadef_id); + }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, gen_metadef_name, COREML, kCoreMLExecutionProvider); @@ -86,17 +109,16 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return result; } -#ifdef __APPLE__ +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - coreml::ModelBuilder builder(graph_viewer, *GetLogger(), coreml_flags_); std::unique_ptr coreml_model; - const std::string coreml_model_file_path = coreml::util::GetTemporaryFilePath(); - ORT_RETURN_IF_ERROR(builder.Compile(coreml_model, coreml_model_file_path)); + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + coreml_model)); { const auto& input_defs = fused_node.InputDefs(); @@ -241,22 +263,6 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { - ORT_UNUSED_PARAMETER(fused_node_and_graph); - NodeComputeInfo compute_info; - compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; }; - compute_info.release_state_func = [](FunctionState /*state*/) {}; - compute_info.compute_func = [](FunctionState /* state */, const OrtApi* /* api */, - OrtKernelContext* /* context */) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build."); - }; - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} -#endif //__APPLE__ +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 0201739547dd1..24a001280eef5 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -3,9 +3,9 @@ #pragma once +#include 
"core/common/inlined_containers.h" #include "core/framework/execution_provider.h" #include "core/framework/model_metadef_id_generator.h" -#include "core/providers/coreml/coreml_provider_factory.h" namespace onnxruntime { namespace coreml { @@ -26,15 +26,14 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; #endif + private: // The bit flags which define bool options for COREML EP, bits are defined as // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h - const uint32_t coreml_flags_; - - private: -// > -#ifdef __APPLE__ - std::unordered_map> coreml_models_; -#endif + uint32_t coreml_flags_; + const int32_t coreml_version_; ModelMetadefIdGenerator metadef_id_generator_; + + // map of fused_node_name to compiled_coreml_model + InlinedHashMap> coreml_models_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index f7f45bce087bc..4f9a014c4d885 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -8,10 +8,50 @@ #include -#define API_AVAILABLE_OS_VERSIONS API_AVAILABLE(macos(10.15), ios(13)) +#if defined(__APPLE__) +// See https://apple.github.io/coremltools/mlmodel/Format/Model.html for the info on each CoreML specification version. +// See https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html for the list of ops +// in each CoreML specification version. -// Base requireed OS to run CoreML Specification Version 4 (Core ML 3) -#define HAS_VALID_BASE_OS_VERSION @available(macOS 10.15, iOS 13, *) +// Specification Versions : OS Availability(Core ML Version) +// +// 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) +// - initial version of CoreML EP +// 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) +// - additional layers in NeuralNetwork but currently none are implemented by the CoreML EP +// 6 : iOS 15, macOS 12, tvOS 15, watchOS 8 (Core ML 5) +// - adds MLProgram (MILSpec.Program) +// - iOS 15 ops +// 7 : iOS 16, macOS 13, tvOS 16, watchOS 9 (Core ML 6) +// - iOS 16 ops +// 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) +// - iOS 17 ops +// +// **NOTE** We use the Core ML version not the spec version. +// +// e.g. iOS 13 has Core ML 3 (which is Core ML Specification version 4), and the related macros are +// API_AVAILABLE_COREML3, HAS_COREML3_OR_LATER and onnxruntime::coreml::util::CoreMLVersion() will return 3. 
+ +// https://developer.apple.com/documentation/swift/marking-api-availability-in-objective-c +// API_AVAILABLE is used to decorate Objective-C APIs +#define API_AVAILABLE_COREML3 API_AVAILABLE(macos(10.15), ios(13)) +#define API_AVAILABLE_COREML4 API_AVAILABLE(macos(11), ios(14)) +#define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) +#define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) +#define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) + +// @available is used in implementation code +// Base required OS to run CoreML Specification Version 4 (Core ML 3) +#define HAS_COREML3_OR_LATER @available(macOS 10.15, iOS 13, *) +#define HAS_COREML4_OR_LATER @available(macOS 11, iOS 14, *) +#define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) +#define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) +#define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) + +#endif + +#define MINIMUM_COREML_VERSION 3 // first version we support +#define MINIMUM_COREML_MLPROGRAM_VERSION 5 // first version where ML Program was available namespace onnxruntime { namespace coreml { @@ -21,6 +61,9 @@ namespace util { // This corresponds to [CoreML Specification Version 4 (Core ML 3)] bool HasRequiredBaseOS(); +// Return the CoreML version if 3 or higher. Otherwise returns -1. +int CoreMLVersion(); + // Get a temporary macOS/iOS temp file path std::string GetTemporaryFilePath(); diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 4c394386cd37a..0ae0cf8f0d207 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -10,19 +10,33 @@ namespace util { bool HasRequiredBaseOS() { - // This may look strange, but it is required "@available(macOS ....)" to safe-guard some code - // otherwise the compiler will spit -Wunsupported-availability-guard - if (HAS_VALID_BASE_OS_VERSION) - return true; - else - return false; + return CoreMLVersion() >= 3; +} + +int32_t CoreMLVersion() { + if (HAS_COREML7_OR_LATER) + return 7; + if (HAS_COREML6_OR_LATER) + return 6; + if (HAS_COREML5_OR_LATER) + return 5; + if (HAS_COREML4_OR_LATER) + return 4; + if (HAS_COREML3_OR_LATER) + return 3; + + return -1; } std::string GetTemporaryFilePath() { - // Get temporary directory. + // Get temporary directory for user. NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; // Generate a Unique file name to use. NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString]; + + // make it easy to see who generated it + temporary_filename = [@"onnxruntime-" stringByAppendingString:temporary_filename]; + // Create URL to that file. NSURL* temporary_file_url = [temporary_directory_url URLByAppendingPathComponent:temporary_filename]; diff --git a/onnxruntime/core/providers/coreml/model/host_utils_stub.cc b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc new file mode 100644 index 0000000000000..5c383b0274e8c --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/platform/env.h" +#include "core/providers/coreml/model/host_utils.h" + +namespace onnxruntime { +namespace coreml { +namespace util { + +bool HasRequiredBaseOS() { + return true; +} + +int CoreMLVersion() { + return 7; // CoreML 7 is the latest we support. 
+} + +std::string GetTemporaryFilePath() { + static std::atomic counter = 0; + + // we want to avoid creating endless directories/names whilst avoiding clashes if tests run in parallel so cycle + // through 20 potential output names. + auto dir_name = "coreml_ep_test_run." + std::to_string(counter++ % 20); + + // to replicate the iOS/macOS host_utils.mm behavior where the output is / + // we want to return the name of something that does not exist. this is required for ML Package creation. + auto& env = Env::Default(); + if (env.FolderExists(dir_name)) { + ORT_THROW_IF_ERROR(env.DeleteFolder(ToPathString(dir_name))); + } + + return dir_name; +} + +} // namespace util +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 105b6a0333b15..b940c4b768aec 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -33,19 +33,29 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; class Model { - friend class ModelBuilder; - public: + Model(const std::string& path, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, uint32_t coreml_flags); + ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); + Status LoadModel(); + Status Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); - bool IsScalarOutput(const std::string& output_name) const; + bool IsScalarOutput(const std::string& output_name) const { + return Contains(scalar_outputs_, output_name); + } - bool IsInt64Output(const std::string& output_name) const; + bool IsInt64Output(const std::string& output_name) const { + return Contains(int64_outputs_, output_name); + } // Mutex for exclusive lock to this model object OrtMutex& GetMutex() { return mutex_; } @@ -57,35 +67,27 @@ class Model { const std::vector& GetOnnxOutputs() const { return onnx_outputs_; } void SetOnnxOutputs(std::vector&& outputs) { onnx_outputs_ = std::move(outputs); } - const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const; - const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const; + const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const { + const auto info_it = input_output_info_.find(name); + return info_it != input_output_info_.end() ? 
&info_it->second : nullptr; + } + + const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const { + const auto* info = TryGetInputOutputInfo(name); + ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); + return *info; + } private: std::unique_ptr execution_; + std::unordered_map input_output_info_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; std::vector onnx_inputs_; std::vector onnx_outputs_; - std::unordered_map input_output_info_; - OrtMutex mutex_; - - Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - Status LoadModel(); - - void SetInputOutputInfo(std::unordered_map&& input_output_info) { - input_output_info_ = std::move(input_output_info); - } - - void SetScalarOutputs(std::unordered_set&& scalar_outputs) { - scalar_outputs_ = std::move(scalar_outputs); - } - - void SetInt64Outputs(std::unordered_set&& int64_outputs) { - int64_outputs_ = std::move(int64_outputs); - } }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 155201ad4c39c..d5cd70bff9479 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -252,14 +252,14 @@ - (instancetype)initWithPath:(const std::string&)path coreml_flags:(uint32_t)coreml_flags; - (void)cleanup; - (void)dealloc; -- (Status)loadModel API_AVAILABLE_OS_VERSIONS; +- (Status)loadModel API_AVAILABLE_COREML3; - (Status)predict:(const std::unordered_map&)inputs outputs:(const std::unordered_map&)outputs getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&) get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_OS_VERSIONS; + API_AVAILABLE_COREML3; -@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS; +@property(nullable) MLModel* model API_AVAILABLE_COREML3; @end @@ -308,6 +308,10 @@ - (Status)loadModel { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); } + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. 
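+  //
+  // A possible shape for the callback-based replacement (sketch only; availability guards and the completion
+  // bookkeeping that `predict` would need to wait on are omitted):
+  //   [MLModel compileModelAtURL:modelUrl
+  //            completionHandler:^(NSURL* compiledUrl, NSError* compileError) {
+  //              // record compiledUrl / compileError and mark compilation as finished
+  //            }];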
NSError* error = nil; NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; @@ -454,7 +458,7 @@ Status Predict(const std::unordered_map& inputs, return Status::OK(); } - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { Status status{}; @autoreleasepool { status = [execution_ loadModel]; @@ -471,7 +475,7 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { @autoreleasepool { return [execution_ predict:inputs outputs:outputs @@ -482,8 +486,16 @@ Status Predict(const std::unordered_map& inputs, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+"); } -Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)) { +Model::Model(const std::string& path, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, + uint32_t coreml_flags) + : execution_(std::make_unique(path, logger, coreml_flags)), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { } Model::~Model() {} @@ -497,25 +509,5 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { return execution_->Predict(inputs, outputs, get_output_tensor_mutable_raw_data_fn); } - -bool Model::IsScalarOutput(const std::string& output_name) const { - return Contains(scalar_outputs_, output_name); -} - -bool Model::IsInt64Output(const std::string& output_name) const { - return Contains(int64_outputs_, output_name); -} - -const OnnxTensorInfo* Model::TryGetInputOutputInfo(const std::string& name) const { - const auto info_it = input_output_info_.find(name); - return info_it != input_output_info_.end() ? &info_it->second : nullptr; -} - -const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { - const auto* info = TryGetInputOutputInfo(name); - ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); - return *info; -} - } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc new file mode 100644 index 0000000000000..087c9f8c05d5f --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/model/model.h" + +namespace onnxruntime { +namespace coreml { + +class Execution {}; + +Model::Model(const std::string& /*path*/, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& /*logger*/, + uint32_t /*coreml_flags*/) + : execution_(std::make_unique()), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { +} + +Model::~Model() { +} + +Status Model::LoadModel() { + // return OK so we hit more CoreML EP code. 
+ return Status::OK(); +} + +Status Model::Predict(const std::unordered_map& /*inputs*/, + const std::unordered_map& /*outputs*/, + const GetOutputTensorMutableRawDataFn& /*get_output_tensor_mutable_raw_data_fn*/) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Executing a CoreML model is not supported on this platform."); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc index b2225643b788e..edee298ad1ccf 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc @@ -67,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const int32_t num_outputs; if (node_unit.SinceVersion() >= 18) { - num_outputs = SafeInt(*helper.GetInt("num_outputs")); + num_outputs = SafeInt(*helper.GetInt64("num_outputs")); } else { num_outputs = SafeInt(node_unit.Outputs().size()); } @@ -127,7 +127,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No } else { uint32_t num_outputs; if (node_unit.SinceVersion() >= 18) { - auto num_outputs_attr = helper.GetInt("num_outputs"); + auto num_outputs_attr = helper.GetInt64("num_outputs"); if (!num_outputs_attr.has_value()) { LOGS_DEFAULT(VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 37ad14ac2e9b1..c07a0929353b1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -118,84 +118,134 @@ NodeAttrHelper::NodeAttrHelper(const NodeUnit& node_unit) : node_attributes_(node_unit.GetNode().GetAttributes()) {} float NodeAttrHelper::Get(const std::string& key, float def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.f(); + } - return node_attributes_.at(key).f(); + return def_val; } int32_t NodeAttrHelper::Get(const std::string& key, int32_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(entry->second.i()); + } - return SafeInt(node_attributes_.at(key).i()); + return def_val; } uint32_t NodeAttrHelper::Get(const std::string& key, uint32_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(entry->second.i()); + } - return SafeInt(node_attributes_.at(key).i()); + return def_val; } int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.i(); + } - return node_attributes_.at(key).i(); + return def_val; } const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.s(); + } - return node_attributes_.at(key).s(); + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if 
(!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& attr = entry->second; + std::vector v; + v.reserve(static_cast(attr.ints_size())); + std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), + [](int64_t val) -> int32_t { return narrow(val); }); + return v; + } - const auto& attr(node_attributes_.at(key)); - std::vector v; - v.reserve(static_cast(attr.ints_size())); - std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), - [](int64_t val) -> int32_t { return SafeInt(val); }); - return v; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& attr = entry->second; + std::vector v; + v.reserve(static_cast(attr.ints_size())); + std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), + [](int64_t val) -> uint32_t { return narrow(val); }); + return v; + } - const auto& attr(node_attributes_.at(key)); - std::vector v; - v.reserve(static_cast(attr.ints_size())); - std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), - [](int64_t val) -> uint32_t { return SafeInt(val); }); - return v; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.ints(); + return std::vector{values.cbegin(), values.cend()}; + } - const auto& source(node_attributes_.at(key).ints()); - return std::vector{source.cbegin(), source.cend()}; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.floats(); + return std::vector{values.cbegin(), values.cend()}; + } - const auto& source(node_attributes_.at(key).floats()); - return std::vector{source.cbegin(), source.cend()}; + return def_val; } -std::optional NodeAttrHelper::GetInt(const std::string& key) const { - if (!HasAttr(key)) - return std::nullopt; - return node_attributes_.at(key).i(); +std::optional NodeAttrHelper::GetFloat(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.f(); + } + + return result; +} + +std::optional NodeAttrHelper::GetInt64(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.i(); + } + + return result; +} + +std::optional> NodeAttrHelper::GetFloats(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.floats(); + result = std::vector(values.begin(), values.end()); + } + + return result; +} + +std::optional> NodeAttrHelper::GetInt64s(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.ints(); + result = std::vector(values.begin(), values.end()); + } + + return result; +} + +std::optional 
NodeAttrHelper::GetString(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.s(); + } + + return result; } bool NodeAttrHelper::HasAttr(const std::string& key) const { diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 31b1aba2e1a63..5813dcc48d72b 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -47,15 +47,17 @@ class NodeAttrHelper { // Get the attributes from the target node of the node_unit explicit NodeAttrHelper(const NodeUnit& node_unit); + /* + * Get with default + */ float Get(const std::string& key, float def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; int64_t Get(const std::string& key, int64_t def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; - std::vector Get(const std::string& key, const std::vector& def_val) const; - std::vector Get(const std::string& key, const std::vector& def_val) const; - // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; @@ -64,7 +66,16 @@ class NodeAttrHelper { uint32_t Get(const std::string& key, uint32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; - std::optional GetInt(const std::string& key) const; + /* + * Get without default. + */ + std::optional GetFloat(const std::string& key) const; + std::optional> GetFloats(const std::string& key) const; + + std::optional GetInt64(const std::string& key) const; + std::optional> GetInt64s(const std::string& key) const; + + std::optional GetString(const std::string& key) const; bool HasAttr(const std::string& key) const; diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index a94f7b5b707c7..40b40136af1af 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -208,12 +208,18 @@ std::unique_ptr DefaultRocmExecutionProvider(bool test_tunab } std::unique_ptr DefaultCoreMLExecutionProvider() { -// For any non - macOS system, CoreML will only be used for ort model converter -// Make it unavailable here, you can still manually append CoreML EP to session for model conversion + // To manually test CoreML model generation on a non-macOS platform, comment out the `&& defined(__APPLE__)` below. + // The test will create a model but execution of it will obviously fail. + // To test creating an ML Program, set the environment variable COREML_EP_TEST_MLPROGRAM to any value. 
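+  // For example (hypothetical invocation; the test binary name and filter are illustrative only):
+  //   COREML_EP_TEST_MLPROGRAM=1 ./onnxruntime_test_all --gtest_filter=*CoreML*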
#if defined(USE_COREML) && defined(__APPLE__) // We want to run UT on CPU only to get output value without losing precision uint32_t coreml_flags = 0; coreml_flags |= COREML_FLAG_USE_CPU_ONLY; + + if (!Env::Default().GetEnvironmentVar("COREML_EP_TEST_MLPROGRAM").empty()) { + coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + } + return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); #else return nullptr; From a67e6925468effd1897c2f541821d32a2860a037 Mon Sep 17 00:00:00 2001 From: rui-ren Date: Wed, 14 Feb 2024 15:07:56 -0800 Subject: [PATCH 084/207] add GatherSliceToSplitFusion and Unittest (#19218) ### Multi Query Attention Optimization in multi-query attention ``` batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :] ``` which can be optimized to ``` batch_size, seq_length, three_times_hidden_size = fused_qkv.shape fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim) (query, key, value) = fused_qkv.split([self.num_heads, 1, 1], dim=2) return query, key, value ``` this optimization can be validated from nsight profiling and perf benchmarking. image As such, This PR is to Optimize the `Gather/Gather/Slice` Ops to `Split` Kernel. ### Optimization Target As 2 `Gather` and 1 `Slice` Kernels are time consuming for backward prop, it would be efficient to use 1 `Split` Kernel ### Example - Before Fusion image - After Fusion image ### Perf Gain After the optimization, there will have **~7%** perf gain. > The `Transpose` Kernel can be fused too, will update it in next PR. However, after testing Transponse Ops fusion on Falcon model, there is no perf gain. Will not create a new PR. --------- Co-authored-by: ruiren --- .../core/optimizer/gather_slice_fusion.cc | 344 ++++++++++++++++++ .../core/optimizer/gather_slice_fusion.h | 32 ++ .../core/optimizer/graph_transformer_utils.cc | 2 + .../test/optimizer/graph_transform_test.cc | 139 +++++++ .../core/optimizer/graph_transformer_utils.cc | 2 + 5 files changed, 519 insertions(+) create mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.cc create mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.h diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc new file mode 100644 index 0000000000000..21266d356a020 --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/optimizer/gather_slice_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, + int64_t& axis, int64_t& indices_n_dims) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + const NodeArg& input_arg = *(node.InputDefs()[1]); + + if (!optimizer_utils::IsScalar(input_arg)) return false; + + const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + + if (!indices_init) return false; + + if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + + // get the index value + Initializer init_const(*indices_init, graph.ModelPath()); + index = *(init_const.data()); + + // get attributes value + axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + indices_n_dims = indices_init->dims_size(); + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const { + // check the version of Slice ops + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + // get the opset version + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + // If Slice op of opset version 1 + if (onnx_opset_version == 1) { + if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || + !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || + starts.size() != ends.size()) { + return false; + } + + if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { + return false; + } + } + + // If Slice op of opset version >= 10 + if (onnx_opset_version >= 10) { + // node inputs include: starts - ends - axes - steps + + // return a pointer to the corresponding NodeArg if input of the node at the index exists + auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + // return a pointer to the initializer if it is constant; otherwise, a nullptr + auto get_initializer_if_constant = + [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { + const NodeArg* input = get_input_if_exists(input_index); + return input ? 
graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; + }; + + // return the initialization data if it is constant + auto get_initializer_data = + [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { + Initializer init(*slice_initializer, graph.ModelPath()); + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { + int32_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { + int64_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + return {}; + }; + + // starts and ends inputs have to exist, be constants and be of the same size. + const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); + const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); + const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); + const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); + + if (!starts_init || !ends_init || !axes_init || !steps_init) { + return false; + } + + starts = get_initializer_data(starts_init); + ends = get_initializer_data(ends_init); + axes = get_initializer_data(axes_init); + steps = get_initializer_data(steps_init); + + if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { + return false; + } + + if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { + return false; + } + + // if steps exists, it should be constant and all value should be 1 + if (steps.size() != starts.size()) { + return false; + } + + for (int64_t step : steps) { + if (step != 1) { + return false; + } + } + } + + return true; +} + +/* +GatherToSplitFusion is to fuse: + Node + |-> Gather(index=0, axis=axis) + |-> Gather(index=1, axis=axis) + |-> Slice(index=2, axis=axis) +To + Node + |-> Split(index=0) +So that we can use one kernel to finish the job. +*/ + +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + InlinedVector output_args; + + // Iterate the topological order and get Reshape ops + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + + if (p_node == nullptr) continue; + + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + // Currently only catch after Reshape ops, optimize in the future + if (node.OpType() != "Reshape") continue; + + size_t output_count = node.GetOutputEdgesCount(); + + // We only catch 1 scenario for Multi Query Attention for now. + // |---> Gather + // Reshape |---> Gather + // |---> Slice + // |... 
or (other ops) + + // Get the output into node args + if (output_count < 3) continue; + + output_args.push_back(node.OutputDefs()[0]); + } + + // iterate the children of Reshape node + for (const NodeArg* node_arg : output_args) { + auto shape = node_arg->Shape(); + if (!shape) continue; + + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + + // get the tensor rank + int64_t rank = static_cast(shape->dim_size()); + + bool can_fuse = true; + bool first_edge = true; + int64_t split_axis = 0; + int64_t indices_n_dims = -1; + + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs(3); + + InlinedVector> nodes_to_fuse; + size_t gather_node_count = 2, slice_node_count = 0; + + // find the nodes to be merged + for (auto consumer : consumers) { + int64_t index, axis, dims; + InlinedVector starts, ends, axes, steps; + + bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); + bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); + + if ((!consumer || consumer->InputDefs()[0] != node_arg) || + (!IsSupportedGatherOps && !IsSupportedSliceOps)) { + break; + } + + if (IsSupportedGatherOps) { + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. + can_fuse = false; + break; + } + + if (axis < 0) axis += rank; + + if (first_edge) { + auto dim = shape->dim(static_cast(axis)); + // dim.dim_value() = 73 + if (!utils::HasDimValue(dim)) { + can_fuse = false; + break; + } + split_axis = axis; + first_edge = false; + } else if (axis != split_axis) { + can_fuse = false; + break; + } + + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count)) { + can_fuse = false; + break; + } + + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.push_back(gather_node); + NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; + split_outputs[gather_node_count--] = gather_output_args; + } + + // check the Slice Ops + if (IsSupportedSliceOps) { + if (axes[0] != axis && !first_edge) { + can_fuse = false; + break; + } + + Node& slice_node = *graph.GetNode(consumer->Index()); + NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; + nodes_to_fuse.push_back(slice_node); + split_outputs[slice_node_count++] = slice_output_args; + } + } + + // condition check + if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; + + // generate the split node and merge the kernel + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node_arg->TypeAsProto()->tensor_type().elem_type()); + + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + + for (int64_t i = 0; i < rank; i++) { + if (i == split_axis) + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + else + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } + + InlinedVector split_output_types; + + for (size_t i = 0; i < consumer_count; ++i) { + split_output_types.push_back( + &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); + } + + // Generate the Split Node + ONNX_NAMESPACE::TensorProto split_initializer_proto; + 
split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); + split_initializer_proto.add_dims(static_cast(3)); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); + // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 + int64_t slice_dim = static_cast(dim_value - 2); + InlinedVector split_value{{slice_dim, 1, 1}}; + split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); + NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", + {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); + + split_node.AddAttribute("axis", split_axis); + + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + } + + for (Node& node_to_fuse : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); + graph.RemoveNode(node_to_fuse.Index()); + } + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h new file mode 100644 index 0000000000000..1c5c307efed7f --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class GatherSliceToSplitFusion +Fuse (2 Gather nodes + 1 Slice) to 1 split node. 
+*/ + +class GatherSliceToSplitFusion : public GraphTransformer { + private: + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const; + + public: + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..4e939fe3c7b6b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -308,6 +309,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bf02c1741725f..e1fcf835c6043 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,6 +42,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7642,5 +7643,143 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Gather-2 Ops + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); + auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, 
gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose-2 Ops + auto* transpose_out_2 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + 
builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 894fe3b052fb2..0b68dc65e41cd 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -140,6 +141,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); From 775c774f4bdcdd57c107030e1341809b4b5ba35e Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Wed, 14 Feb 2024 18:07:51 -0800 Subject: [PATCH 085/207] Add BF16 to Sqrt (#19363) ### Description Sqrt does not have BF16 support yet. 
Adding that with this PR ### Motivation and Context --- docs/OperatorKernels.md | 2 +- .../core/providers/cuda/cuda_execution_provider.cc | 2 ++ .../providers/cuda/math/unary_elementwise_ops.cc | 2 +- .../cuda/math/unary_elementwise_ops_impl.cu | 2 +- .../providers/cpu/math/element_wise_ops_test.cc | 13 ++++++++++++- 5 files changed, 17 insertions(+), 4 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2c00fe3d26752..8ff2135c6b1f6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -784,7 +784,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 77e682e05a2a4..48a952e6dd98f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Log); @@ -1882,6 +1883,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 655877f425054..fd8b69d7bd2f5 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -160,7 +160,7 @@ UNARY_OP_CSILHFD(Neg, 13) UNARY_OP_HFD(Floor, 13) UNARY_OP_HFD(Ceil, 13) UNARY_OP_HFD(Reciprocal, 13) -UNARY_OP_HFD(Sqrt, 13) +UNARY_OP_HFDX(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 5c3db4a499972..73c5ac80756be 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -83,7 +83,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 5e746ed0c62d4..d35e5c78cfd69 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -5,6 +5,7 @@ #include "test/providers/provider_test_utils.h" #include 
"test/util/include/default_providers.h" #include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "core/util/math.h" #include #include @@ -786,13 +787,20 @@ TEST(MathOpTest, Sqrt_Float) { test.Run(); } -#if defined(USE_DNNL) +#if defined(USE_DNNL) || defined(USE_CUDA) TEST(MathOpTest, Sqrt_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; } +#endif +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support BFP16"; + return; + } #endif OpTester test_bf16("Sqrt", 13); // only version 13 support bf16 for sqrt test_bf16.AddInput("X", {2, 3}, @@ -804,6 +812,9 @@ TEST(MathOpTest, Sqrt_bfloat16) { std::vector> execution_providers; #if defined(USE_DNNL) execution_providers.push_back(DefaultDnnlExecutionProvider()); +#endif +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); #endif test_bf16.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } From 660f39aca5d47888804163405c64ee67eec6eed5 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 14 Feb 2024 18:35:56 -0800 Subject: [PATCH 086/207] Perf improvement for Intel MTL CPUs (#19524) ### Description See the comments inside of the changed files for more detailed information. The file onnxruntime/core/platform/windows/hardware_core_enumerator.cc and onnxruntime/core/platform/windows/hardware_core_enumerator.h were copied from WinML source folder in this repo, with minor coding style changes. I had an offline discussion with Sheil. We agree that given the lack of a future proof solution, we may check-in this temp fix first, and rework it later. I will have a meeting with @ivberg for discussing the issue deeply, and seeking for a long term solution. Thanks for offering help, @ivberg ! ### Motivation and Context With this change, we will see about 2x perf improvement on some Intel CPUs. --- onnxruntime/core/platform/windows/env.cc | 46 +++++++++- .../windows/hardware_core_enumerator.cc | 89 +++++++++++++++++++ .../windows/hardware_core_enumerator.h | 12 +++ onnxruntime/core/util/thread_utils.cc | 19 ++-- tools/ci_build/build.py | 3 +- 5 files changed, 162 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/core/platform/windows/hardware_core_enumerator.cc create mode 100644 onnxruntime/core/platform/windows/hardware_core_enumerator.h diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1a0713db43db8..0eb34cbfbc9eb 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -32,6 +32,9 @@ limitations under the License. #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" +#if defined(_M_X64) && !defined(_M_ARM64EC) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +#include "core/platform/windows/hardware_core_enumerator.h" +#endif #include #include @@ -248,12 +251,53 @@ void WindowsEnv::SleepForMicroseconds(int64_t micros) const { Sleep(static_cast(micros) / 1000); } +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. 
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" +#endif int WindowsEnv::DefaultNumCores() { return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); } int WindowsEnv::GetNumPhysicalCpuCores() const { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) + // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has + // a hybrid architecture that some CPU cores runs significant slower than the others. If we distribute our compute work + // evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number + // of threads to exclude the slowest cores out. + // The following code is based on assumptions that: + // 1. All Intel hybrid CPUs should have 3 levels of cache. + // 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should + // not be used. + // Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code. + // However, no matter what the code should not cause any crash. The worst is it might return 1 that + // thread pools will not be created, which is just a perf issue and does not impact usability. + // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability + int regs[4]; + __cpuid(regs, 0); + bool bIsIntel = + (kVendorID_Intel[0] == regs[1]) && + (kVendorID_Intel[1] == regs[2]) && + (kVendorID_Intel[2] == regs[3]); + if (bIsIntel && regs[0] >= 7) { + // Query Structured Extended Feature Flags Enumeration Leaf + __cpuid(regs, 0x7); + // The bit 15 of EDX indicates if the processor is identified as a hybrid part. + bool ishybrid = regs[3] & (1 << 15); + if (ishybrid) { + // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores. + // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail. + // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines. + return std::max(static_cast(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads()); + } else { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } + } else +#endif + { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } } std::vector WindowsEnv::GetDefaultThreadAffinities() const { diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc new file mode 100644 index 0000000000000..121c59808ae59 --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "hardware_core_enumerator.h" +#include +#include +#include + +namespace onnxruntime { + +struct LogicalProcessorInformation { + std::unique_ptr Buffer; + size_t Length; +}; + +struct CoreCounter { + uint32_t PhysicalCores = 0; + uint32_t SocDieCores = 0; +}; + +static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { + DWORD length = 0; + DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length); + + assert(rc == FALSE); + + auto processorInformationBytes = std::make_unique(length); + + rc = GetLogicalProcessorInformationEx( + relationship, reinterpret_cast(processorInformationBytes.get()), &length); + + assert(rc == TRUE); + + return {std::move(processorInformationBytes), length}; +} + +uint32_t CountSetBits(DWORD input) { + uint32_t c; + for (c = 0; input; c++) { + input &= input - 1; + } + return c; +} + +static CoreCounter GetNumberOPhysicalAndEngineeringCores() { + auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll); + + CoreCounter cores; + DWORD dwLevel2GroupMask = 0; + DWORD dwLevel3GroupMask = 0; + size_t read = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL; + + while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) { + currentProcessorInfo = + reinterpret_cast(logicalProcessorInformation.Buffer.get() + read); + if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) { + break; + } + + switch (currentProcessorInfo->Relationship) { + case RelationProcessorCore: + cores.PhysicalCores++; + break; + case RelationCache: + if (currentProcessorInfo->Cache.Level == 2) { + dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } else if (currentProcessorInfo->Cache.Level == 3) { + dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } + break; + } + + read += currentProcessorInfo->Size; + } + + cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + return cores; +} + +uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { + // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. + // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. + auto cores = GetNumberOPhysicalAndEngineeringCores(); + // We want to use the number of physical cores, but exclude soc cores + return cores.PhysicalCores - cores.SocDieCores; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.h b/onnxruntime/core/platform/windows/hardware_core_enumerator.h new file mode 100644 index 0000000000000..93b50f452afcd --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include + +namespace onnxruntime { +struct HardwareCoreEnumerator { + HardwareCoreEnumerator() = delete; + static uint32_t DefaultIntraOpNumThreads(); +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index a5a165e150cf1..2a6c14ff1b058 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -93,22 +93,31 @@ static std::unique_ptr CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { ThreadOptions to; if (options.thread_pool_size <= 0) { // default - auto default_affinities = Env::Default().GetDefaultThreadAffinities(); - if (default_affinities.size() <= 1) { - return nullptr; - } - options.thread_pool_size = static_cast(default_affinities.size()); if (options.auto_set_affinity) { #ifdef _WIN32 // Only set thread affinity on Server with auto affinity. // On client best to let OS scheduler handle. // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage if (IsWindowsServer()) { + auto default_affinities = Env::Default().GetDefaultThreadAffinities(); + if (default_affinities.size() <= 1) { + return nullptr; + } + options.thread_pool_size = static_cast(default_affinities.size()); to.affinities = std::move(default_affinities); + } else { + options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); } #else + auto default_affinities = Env::Default().GetDefaultThreadAffinities(); + if (default_affinities.size() <= 1) { + return nullptr; + } + options.thread_pool_size = static_cast(default_affinities.size()); to.affinities = std::move(default_affinities); #endif + } else { + options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); } } if (options.thread_pool_size <= 1) { diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 96567c8767a82..244bebd81474d 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1526,7 +1526,8 @@ def generate_build_tree( ldflags = ["/profile", "/DYNAMICBASE"] # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. if not args.enable_address_sanitizer: - cflags += ["/Qspectre"] + # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs + cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"] if config == "Release": cflags += ["/O2", "/Ob2", "/DNDEBUG"] elif config == "RelWithDebInfo": From d63c664ca0021fbac31cee57ff1eaa8bce3d1903 Mon Sep 17 00:00:00 2001 From: rui-ren Date: Thu, 15 Feb 2024 00:02:08 -0800 Subject: [PATCH 087/207] fix rocm ci pipeline (#19525) ### Description ROCm CI pipeline issue. ``` Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.91 MiB, post-processed: Unknown size, total: 17.41 MiB) to /home/onnxruntimedev/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20... 
main() File "/stage/huggingface-transformers/examples/pytorch/language-modeling/run_mlm.py", line 242, in main datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/load.py", line 856, in load_dataset builder_instance.download_and_prepare( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/builder.py", line 583, in download_and_prepare self._download_and_prepare( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/builder.py", line 639, in _download_and_prepare split_generators = self._split_generators(dl_manager, **split_generators_kwargs) File "/home/onnxruntimedev/.cache/huggingface/modules/datasets_modules/datasets/wikitext/aa5e094000ec7afeb74c3be92c88313cd6f132d564c7effd961c10fd47c76f20/wikitext.py", line 138, in _split_generators data_file = dl_manager.download_and_extract(self.config.data_url) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/download_manager.py", line 289, in download_and_extract return self.extract(self.download(url_or_urls)) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/download_manager.py", line 197, in download downloaded_path_or_paths = map_nested( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/py_utils.py", line 195, in map_nested return function(data_struct) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/download_manager.py", line 220, in _download return cached_path(url_or_filename, download_config=download_config) File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/file_utils.py", line 281, in cached_path output_path = get_from_cache( File "/opt/miniconda/envs/rocm-ci/lib/python3.9/site-packages/datasets/utils/file_utils.py", line 634, in get_from_cache raise ConnectionError("Couldn't reach {}".format(url)) ConnectionError: Couldn't reach https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip ``` ### Motivation and Context Update the `datasets` pipeline to latest version `2.17.0`. --- tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 64710a982a29d..496b57b417fbd 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -112,7 +112,7 @@ RUN pip install \ cerberus \ sympy \ h5py \ - datasets==1.9.0 \ + datasets==2.17.0 \ requests \ sacrebleu==1.5.1 \ sacremoses \ From d0061d6fb15d40eeb35fa1b40a414cd231d51db9 Mon Sep 17 00:00:00 2001 From: sophies927 <107952697+sophies927@users.noreply.github.com> Date: Thu, 15 Feb 2024 17:03:11 -0800 Subject: [PATCH 088/207] Update stale.yml to use old version as a bug fix (#19532) ### Description Changed the actions/stale version back to v8 from v9. ### Motivation and Context There is a well-documented issue w/ the new actions/stale version (v9.0.0) that causes the following error: "Error delete _state: [403] Resource not accessible by integration". See https://github.com/actions/stale/issues/1133 for more context. This issue is preventing the stale bot from labeling stale issues since the version was updated b/c the action can no longer access the cache and cannot apply labels to all issues due to GH API rate limiting. 
There are two potential fixes if we continue to use the new version: (1) run the action on all PRs/issues to avoid using the cache or (2) give write access to the endpoints listed in https://docs.github.com/en/rest/authentication/permissions-required-for-fine-grained-personal-access-tokens?apiVersion=2022-11-28#repository-permissions-for-actions. Neither of these options is preferable, so I am going to wait until the bug is fixed. Note: The old version (v8.0.0) uses Node 16, which will be deprecated in Spring 2024, instead of Node 20, so we should keep an eye on [this issue](https://github.com/actions/stale/issues/1133) to see when they make the fix and we can switch back to the new version. --- .github/workflows/stale.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c94e3fa5bcb8c..181f3fb17d332 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9.0.0 + - uses: actions/stale@v8 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression From 4bfa69def85476b33ccfaf68cf070f3fb65d39f7 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Thu, 15 Feb 2024 20:22:36 -0800 Subject: [PATCH 089/207] Speed Up DecoderMaskedSelfAttentionTest (#19531) ### Description The unit tests take 19 minutes to run (in debug build) because of too many combinations. I reduced the combinations while retaining good test coverage. After the change, the tests finish in 51 seconds. Before: [----------] 2 tests from DecoderMaskedSelfAttentionTest [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp32 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp32 (394086 ms) [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp16 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp16 (747035 ms) [----------] 2 tests from DecoderMaskedSelfAttentionTest (1141122 ms total) After: [----------] 2 tests from DecoderMaskedSelfAttentionTest [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp32 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp32 (21057 ms) [ RUN ] DecoderMaskedSelfAttentionTest.Test_fp16 [ OK ] DecoderMaskedSelfAttentionTest.Test_fp16 (30653 ms) [----------] 2 tests from DecoderMaskedSelfAttentionTest (51710 ms total) ### Motivation and Context Reduce test time and improve build pipeline efficiency.
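For scale, a rough count based on the loop bounds removed in the diff below: the old tests swept batch_size over {1, 3, 5}, past_sequence_length over 1..3000 in steps of 150 (20 values), and 3 hidden sizes, i.e. 3 x 20 x 3 = 180 combinations per test, or 360 across the fp32 and fp16 tests. The new hand-picked buckets run 9 cases per test (18 in total), roughly a 20x reduction in combinations, which lines up with the drop from ~19 minutes to under a minute shown above.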
--- ...oder_masked_multihead_attention_op_test.cc | 451 ++++++++++-------- 1 file changed, 242 insertions(+), 209 deletions(-) diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 6afb61bd1f0a1..8ea37ad054ed0 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -640,122 +640,139 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); - - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); - - tester.AddInput("past", past_dims, reordered_kv_cache); - - // Rel - tester.AddOptionalInputEdge(); - - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); - - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); - - auto k_merged = pair.first; - auto k_transpose = pair.second; - - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), 
batch_size, number_of_heads, - total_sequence_length, head_size); - - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); - - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 384}, + {2, 30, 768}, + {3, 31, 1536}, + {4, 512, 384}, + {1, 1024, 768}, + {1, 2046, 1536}, + {2, 2047, 384}, + {3, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); + + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(float); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); + + tester.AddInput("past", past_dims, reordered_kv_cache); + + // Rel + tester.AddOptionalInputEdge(); + + // Past sequence length + std::vector arr_past_sequence_len(1, 
past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); + + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); + + auto k_merged = pair.first; + auto k_transpose = pair.second; + + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); + + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); + + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); - - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, 
&execution_providers); - } - } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -766,122 +783,138 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 768}, + {3, 30, 384}, + {8, 31, 1536}, + {4, 256, 384}, + {3, 1024, 768}, + {2, 2046, 1536}, + {1, 2047, 384}, + {2, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); 
+ + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(MLFloat16); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); - // Rel - tester.AddOptionalInputEdge(); + // Rel + tester.AddOptionalInputEdge(); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); - auto k_merged = pair.first; - auto k_transpose = pair.second; + auto k_merged = pair.first; + auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + 
auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -889,4 +922,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime From ef0b71308c0e2395d3ea63e627515ff8e624ad45 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 16 Feb 2024 
05:34:55 -0800 Subject: [PATCH 090/207] Optimize KahnsTopologicalSort and PriorityNodeCompare (#19475) **Description** 1) During SessionInitialization, KahnsTopologicalSort is a major cause of perf degradation. The main cause of slow down is that the TopologicalSort needs to keep track of nodes to visit in order, and reorder them based on priority (as informed by a comparator). The existing implementation uses a priority_queue that is backed by a std::vector container. However, vectors are not good for insertion and reordering. The appropriate data type for this operation is a linked list. However, linked lists like std::list are not usable as a container for std::priority_queue. This is because std::priority_queue requires random access, which linked lists do not have. However, for this simple implementation, we can leverage a std::list under the hood and perform insertions manually using std::upper_bound. This drastically reduces the time taken by the method, which currently instead causes numerous recopies and a lot of movement inside the graph nodes to visit list. 2) In the comparator, I hide forward and backward attribute checking behind the #ifdef ENABLE_TRAINING macro, as I believe it should only be valid in the training scenario. 3) In noopelimination transformer, I prevent the creation of Initializer (which unpacks tensorproto data) in every node and only create initializers when Add/Sub/Mul/Div op nodes are detected. **Motivation and Context** Session creation time of many models is quite slow. --------- Co-authored-by: Sheil Kumar --- onnxruntime/core/graph/graph.cc | 37 ++++++++-- onnxruntime/core/graph/graph_viewer.cc | 18 +++-- .../core/optimizer/noop_elimination.cc | 73 +++++++++++-------- .../ort_optimizer_api_impl.cc | 2 +- 4 files changed, 85 insertions(+), 45 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 902839bee04ba..305122c56b865 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { 
+ if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2843,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index acf7b3a16541f..119d420066a84 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 
1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { + return false; + } } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } From b84712151c06f0f59359916be572f71bd36721a4 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 16 Feb 2024 14:36:05 -0800 Subject: [PATCH 091/207] QNN EP: Fuse DQ -> Q sequences into a QNN Convert op (#19511) ### Description Fuses DQ -> Q sequences into a QNN Convert operator if: - Converting from one qtype to another. Ex: Dequantize(uint8 to float) -> Quantize(float to uint16) - The DQ and Q operators are not part of another node unit (i.e., standalone) - The Q operator is the only consumer for the DQ operator. 
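To illustrate the check, here is a minimal sketch (the struct and function names are hypothetical; the real implementation is `QDQ::IsDQQConversion` in `qdq_util.cc` below): the fusion applies when both Q and DQ carry constant scalar scale/zero-point inputs, the scale element types match, and the zero-point element types differ.

```cpp
#include <cstdint>

// Hypothetical, simplified view of a Q or DQ node's quantization parameters.
// This only mirrors the decision described above; the actual check operates
// on NodeArg/TensorProto data inside the optimizer utilities.
struct QuantParams {
  int32_t scale_elem_type;       // element type of the scale initializer
  int32_t zero_point_elem_type;  // element type of the zero-point initializer
  bool scale_is_constant_scalar;
  bool zero_point_is_constant_scalar;
};

// True when a standalone DQ -> Q pair is a pure requantization (type
// conversion) that can be replaced by a single QNN Convert op.
bool IsDqToQConversion(const QuantParams& dq, const QuantParams& q) {
  const bool all_constant_scalars =
      dq.scale_is_constant_scalar && dq.zero_point_is_constant_scalar &&
      q.scale_is_constant_scalar && q.zero_point_is_constant_scalar;
  return all_constant_scalars &&
         dq.scale_elem_type == q.scale_elem_type &&          // same scale type
         dq.zero_point_elem_type != q.zero_point_elem_type;  // different zero-point type
}
```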
### Motivation and Context Allows faster execution of QDQ models with mixed activation types by leveraging the QNN Convert operator, which converts between quantization types. For certain models, this results in inference latency speed-ups of up to 2x (depends on the number of DQ -> Q sequences). #### Example for Add node unit with 16-bit I/O: Original: ``` u8 ----> DQ ---> Q ---u16--> Add ---u16--> ^ | u16 --------------------------+ ``` After fusing DQ -> Q: ``` u8 ----> Convert ---u16--> Add ---u16--> ^ | u16 ------------------------+ ``` --- .../optimizer/qdq_transformer/qdq_util.cc | 43 ++++++++ .../core/optimizer/qdq_transformer/qdq_util.h | 12 ++ .../qnn/builder/op_builder_factory.h | 23 ++++ .../builder/opbuilder/convert_op_builder.cc | 103 ++++++++++++++++++ .../core/providers/qnn/builder/qnn_model.cc | 35 +++++- .../providers/qnn/qnn_execution_provider.cc | 88 +++++++++------ .../providers/qnn/qnn_execution_provider.h | 1 - .../test/providers/qnn/simple_op_htp_test.cc | 55 ++++++++++ 8 files changed, 319 insertions(+), 41 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index b1ab641a23256..4e3dff705bd41 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -76,6 +76,49 @@ bool IsQDQPairSupported( } } +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != InputIndex::TOTAL_COUNT || + q_input_defs.size() != InputIndex::TOTAL_COUNT || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + Initializer q_zp(*q_zp_tensor_proto, model_path); + Initializer q_scale(*q_scale_tensor_proto, model_path); + Initializer dq_zp(*dq_zp_tensor_proto, model_path); + Initializer dq_scale(*dq_scale_tensor_proto, model_path); + + return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); +} + bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) { 
bool zero_point_exists = false; if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index bb0bf9438cfcb..8333168b0093f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -38,6 +38,18 @@ bool IsQDQPairSupported( const GetConstantInitializerFn& get_const_initializer, const Path& model_path); +// Check if a DQ -> Q sequence represents a conversion in quantization data type. +// Example of uint8 to uint16: +// Dequantize (uint8 to float) -> Quantize (float to uint16) +// Requires: +// 1. Q/DQ doesn't have optional input. +// 2. scale and zero-point are constant scalars. +// 3. Q and DQ have the same scale *type* and different zero-point *types*. +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path); + // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. // 2. scale and zero point is constant scalar diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index d95e2baa9457f..4a9106f0c06af 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +struct HandleConvertResult { + Status status; // Indicates an unexpected error. Check if q_node_unit != nullptr to determine + // whether a DQ -> Q sequence was successfully merged into a Convert. + const NodeUnit* q_node_unit; // Non-null if successfully merged DQ -> Q sequence. + // Set to nullptr if this node unit could not be merged into a Convert. +}; + +/** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param qnn_model_wrapper The QNN model that is being built. + * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return An qnn::HandleConvertResult object that indicates success/failure and provides a pointer + * to the Q node unit that was successfully merged with the provided DQ node unit. + */ +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc new file mode 100644 index 0000000000000..977a9e0b3d9d0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" + +#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values). + +namespace onnxruntime { +namespace qnn { + +class ConvertOpBuilder : public BaseOpBuilder { + public: + ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder); + + Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const ORT_MUST_USE_RESULT; +}; + +Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const { + std::vector input_names; + + // Process the input from the DQ node + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names)); + + // Process the output from the Q node. Override the QNN operator type to "Convert". + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {}, + logger, do_op_validation, QNN_OP_CONVERT)); + return Status::OK(); +} + +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Looking for a standalone DQ to start the sequence. + if (maybe_dq_node_unit.OpType() != QDQ::DQOpName || maybe_dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + const Node& dq_node = maybe_dq_node_unit.GetNode(); + + // DQ must have a single Q child. DQ must not produce a graph output. + auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); + if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return {}; + } + + const Node& q_node = *children[0]; + const auto q_node_unit_it = node_unit_map.find(&q_node); + + if (q_node_unit_it == node_unit_map.end()) { + return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr}; + } + + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + return {}; + } + + ConvertOpBuilder op_builder; + + LOGS(logger, VERBOSE) << " Adding QNN Convert. 
dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; + + auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger, + do_op_validation); + return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr}; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 314cab4a36ca9..dc91b9dfa199e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } + std::unordered_set handled_node_units; + // Op builer const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { @@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // Check whether it's part of NodeUnit const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node) + // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) const std::string& op_type = node_unit.OpType(); + + if (node != &node_unit.GetNode()) { + continue; + } + + if (handled_node_units.count(&node_unit) != 0) { + continue; // Already handled. + } + + // Try to convert particular DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + node_unit, + node_unit_map, + logger_, + false /*do_op_validation*/); + ORT_RETURN_IF_ERROR(convert_result.status); + + if (convert_result.q_node_unit) { + // Successfully merged DQ -> Q sequence into a QNN Convert op. + // Mark both of these node units as handled. 
+ handled_node_units.insert(&node_unit); + handled_node_units.insert(convert_result.q_node_unit); + continue; + } + LOGS(logger_, VERBOSE) << " node name: [" << node->Name() << "] node optype: [" << op_type << "] as part of the NodeUnit type: [" << node_unit.OpType() << "] name: [" << node_unit.Name() << "]"; - if (node != &node_unit.GetNode()) { - continue; - } - if (const auto* op_builder = GetOpBuilder(op_type)) { ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); } + + handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b58f6e10df94c..f5a166d36b15a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const { - // If we have visited one of the nodes in the node_unit, use the result directly - const auto it = node_unit_supported_result.find(&node_unit); - if (it != node_unit_supported_result.cend()) { - return it->second; + const std::string& op_type = node_unit.OpType(); + bool supported = false; + const auto* op_builder = qnn::GetOpBuilder(op_type); + if (op_builder == nullptr) { + LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." + << node_unit.OpType() << " node `" << node_unit.Name() + << "` will not be assigned to QNN EP."; } else { - const std::string& op_type = node_unit.OpType(); - - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." 
- << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - supported = (Status::OK() == status); + auto status = op_builder->IsOpSupported(qnn_model_wrapper, + node_unit, logger); + if (Status::OK() != status) { + LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() + << "` is not supported: " << status.ErrorMessage(); } - node_unit_supported_result[&node_unit] = supported; - return supported; + supported = (Status::OK() == status); } + return supported; } std::unordered_set @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, if (node != &node_unit->GetNode()) { continue; } - const bool supported = IsNodeSupported(qnn_model_wrapper, - *node_unit, - node_unit_supported_result, - logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; + + if (node_unit_supported_result.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + // Try to convert certain standalone DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + *node_unit, + node_unit_map, + logger, + true /*do_op_validation*/); + if (!convert_result.status.IsOK()) { + LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. " + << "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", " + << "Message: " << convert_result.status.ErrorMessage(); + } + + bool supported = false; + + if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op + supported = true; + + // Mark the Q node unit as handled and supported here so that we don't try to process it again. + node_unit_supported_result.insert({convert_result.q_node_unit, true}); + supported_nodes.insert(&convert_result.q_node_unit->GetNode()); + } else { + supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); + LOGS(logger, VERBOSE) << "Node supported: [" << supported + << "] index: [" << node->Index() + << "] name: [" << node->Name() + << "] Operator type: [" << node->OpType() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + } + if (supported) { // If the node_unit is supported, add all of its nodes to the supported list. 
for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { supported_nodes.insert(node_in_group); } } + + node_unit_supported_result.insert({node_unit, supported}); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 09bcb24db4dc2..0bcaa39b22f6d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const; std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2f3b0e84a123e..a6422407d79fd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1110,6 +1110,61 @@ TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) { kOnnxDomain, true); } + +static GetTestQDQModelFn BuildQDQConvertAddTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def) { + return [input0_def, input1_def](ModelTestBuilder& builder, std::vector>& output_qparams) { + constexpr bool use_contrib_qdq = true; + + // Input0 -> Quantize(u8) -> Dequantize(u8 to float) -> input0_after_qdq + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_u8_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_u8_qparams.scale, + input0_u8_qparams.zero_point, use_contrib_qdq); + + // input0_after_qdq -> Quantize(u16) -> Dequantize(u16 to float) + QuantParams input0_u16_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_convert = AddQDQNodePair(builder, input0_after_qdq, input0_u16_qparams.scale, + input0_u16_qparams.zero_point, use_contrib_qdq); + + // Input1 -> Quantize(u16) -> Dequantize(u16 to float) -> input1_after_qdq + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + + // Add op -> op_output + auto* op_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_convert, input1_after_qdq}, {op_output}); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Test quantization type conversion (mixed precision) with Add. +// First input is converted from uint8_t to uint16_t. 
+TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + std::vector input1_data = GetFloatDataInRange(-20.0f, 20.0f, 8); + TestInputDef input0_def({1, 2, 2, 2}, false, input0_data); + TestInputDef input1_def({1, 2, 2, 2}, false, input1_data); + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildOpTestCase("Add", {input0_def, input1_def}, {}, {}, kOnnxDomain), + BuildQDQConvertAddTestCase(input0_def, input1_def), + provider_options, + 18, + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test From 1dce5e17321d50bf345022b525a937933473415a Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Fri, 16 Feb 2024 14:41:11 -0800 Subject: [PATCH 092/207] Disable TF32 in Linux_Test stage of Linux GPU CI Pipeline (#19541) ### Description Some test thresholds that previously worked in T4 GPU does not work anymore. The reason is current pipeline uses A10, and TF32 is enabled by default. Disable TF32 in Linux GPU CI Pipeline in testing to avoid such random test failure. ### Motivation and Context Linux Test has random failure at tests: ProviderOptionsTest > testCUDAOptions() FAILED org.opentest4j.AssertionFailedError: array contents differ at index [446], expected: <0.0419757> but was: <0.041948937> at app//org.junit.jupiter.api.AssertionFailureBuilder.build(AssertionFailureBuilder.java:151) at app//org.junit.jupiter.api.AssertionFailureBuilder.buildAndThrow(AssertionFailureBuilder.java:132) at app//org.junit.jupiter.api.AssertArrayEquals.failArraysNotEqual(AssertArrayEquals.java:440) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:290) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:123) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:119) at app//org.junit.jupiter.api.Assertions.assertArrayEquals(Assertions.java:1360) at app//ai.onnxruntime.providers.ProviderOptionsTest.runProvider(ProviderOptionsTest.java:99) at app//ai.onnxruntime.providers.ProviderOptionsTest.testCUDAOptions(ProviderOptionsTest.java:43) org.opentest4j.AssertionFailedError: array contents differ at index [6], expected: <0.0225981> but was: <0.022587791> at app//org.junit.jupiter.api.AssertionFailureBuilder.build(AssertionFailureBuilder.java:151) at app//org.junit.jupiter.api.AssertionFailureBuilder.buildAndThrow(AssertionFailureBuilder.java:132) at app//org.junit.jupiter.api.AssertArrayEquals.failArraysNotEqual(AssertArrayEquals.java:440) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:290) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:123) at app//org.junit.jupiter.api.AssertArrayEquals.assertArrayEquals(AssertArrayEquals.java:119) at app//org.junit.jupiter.api.Assertions.assertArrayEquals(Assertions.java:1360) at app//ai.onnxruntime.InferenceTest.runProvider(InferenceTest.java:676) at app//ai.onnxruntime.InferenceTest.testCUDA(InferenceTest.java:615) --- tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index b19a8b11db265..24319184dd0b8 100644 --- 
a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -204,6 +204,7 @@ jobs: --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ $(Repository) \ /bin/bash -c " set -ex; \ From 44d8ad93b20efdba921ca80f23485c084b5174d0 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Fri, 16 Feb 2024 15:21:43 -0800 Subject: [PATCH 093/207] Whisper Timestamps and Temperature (#19509) ### Description This PR updates exporting and running the Whisper model with beam search by adding the following. - Adds temperature as a graph input to the exported model - Fixes the token ids by adding them as attributes to `WhisperBeamSearch` - Fixes the timestamps test cases so they pass now - Fixes a bug with invoking `torch.onnx.export` - Cleans up the Whisper scripts and groups the arguments in `convert_to_onnx.py` - Adds a `requirements.txt` file to specify package dependencies - Adds `whisper-large-v3` to list of pretrained models - Fixes a bug with missing cross-attention KV cache inputs in the decoder subgraph ### Motivation and Context - This is a follow-up to [this PR](https://github.com/microsoft/onnxruntime/pull/19188). - The incorrect token ids in the timestamps processor were first noticed during [this PR review](https://github.com/microsoft/onnxruntime/pull/17500#discussion_r1333520007). When they were originally added in [this PR](https://github.com/microsoft/onnxruntime/pull/15853), the offsets were previously constant across the Whisper model sizes. When comparing the new `whisper-large-v3` variant, the English-only variants (e.g. `whisper-tiny.en`), and the original variants (e.g. `whisper-tiny`), both the values and the offsets differ. Therefore, it is easier to set the token ids as attributes to `WhisperBeamSearch` when exporting to ensure the right values are used in the timestamps processor. - The Hugging Face API for returning timestamps and the expected outputs from the PyTorch model have both changed. - The fix for `torch.onnx.export` is a follow-up to [this PR review](https://github.com/microsoft/onnxruntime/pull/17179#issuecomment-1683001470). - The argument grouping is a follow-up to [this PR review](https://github.com/microsoft/onnxruntime/pull/17500#discussion_r1333521721). - Specific package versions are needed to run the Whisper scripts and the `requirements.txt` file ensures that these versions are installed. - The `whisper-large-v3` variant is released and should be in the list of official pretrained models. - After the changes from [this PR](https://github.com/microsoft/onnxruntime/pull/17316), the exported model is not loading in an ORT inference session because the cross-attention KV cache inputs are missing in the decoder subgraph. 
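As additional context for the timestamps fix, the sketch below isolates the pairing rule that the timestamps logits processor enforces once the token ids are supplied as `WhisperBeamSearch` attributes instead of being derived as fixed offsets from `eos_token_id`. It is a simplified stand-in for `TimestampLogitsProcessor::Process` (see the `logits_processor.h` changes below); the function and parameter names are illustrative only.

```cpp
#include <limits>
#include <vector>

// Simplified sketch of the timestamp pairing rule (not the full ORT processor):
// timestamp tokens must come in pairs, except possibly the first one.
// The token ids are passed in as parameters, mirroring the new attributes,
// instead of being computed as hard-coded offsets from the end-of-text id.
void MaskTimestampPairs(std::vector<float>& scores,        // next-token scores, size == vocab_size
                        const std::vector<int>& sequence,  // tokens decoded so far
                        int end_of_text_token_id,
                        int beginning_timestamp_token_id) {
  const int vocab_size = static_cast<int>(scores.size());
  const size_t len = sequence.size();
  const bool last_was_timestamp =
      len > 0 && sequence.back() >= beginning_timestamp_token_id;
  const bool penultimate_was_timestamp =
      len < 2 || sequence[len - 2] >= beginning_timestamp_token_id;

  if (!last_was_timestamp) {
    return;  // nothing to enforce
  }
  if (penultimate_was_timestamp) {
    // The pair is already closed (or this is the very first timestamp):
    // disallow emitting another timestamp right away.
    for (int j = beginning_timestamp_token_id; j < vocab_size; ++j) {
      scores[j] = std::numeric_limits<float>::lowest();
    }
  } else {
    // The pair is still open: force the next token to be a timestamp
    // by masking out all text tokens.
    for (int j = 0; j < end_of_text_token_id; ++j) {
      scores[j] = std::numeric_limits<float>::lowest();
    }
  }
}
```

Passing `beginning_timestamp_token_id` explicitly is what allows the same rule to work across the multilingual, English-only, and large-v3 vocabularies, whose special and timestamp token ids start at different offsets.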
--- docs/ContribOperators.md | 32 +- .../transformers/beam_search_impl_whisper.h | 4 +- .../transformers/beam_search_parameters.cc | 8 +- .../cpu/transformers/generation_shared.h | 9 +- .../cpu/transformers/logits_processor.h | 81 +++-- .../transformers/generation_device_helper.cc | 12 +- .../core/graph/contrib_ops/contrib_defs.cc | 40 +-- .../transformers/models/whisper/README.md | 46 ++- .../transformers/models/whisper/benchmark.py | 22 +- .../models/whisper/benchmark_all.py | 6 + .../models/whisper/convert_to_onnx.py | 277 ++++++++++-------- .../models/whisper/requirements-cpu.txt | 2 + .../models/whisper/requirements-cuda.txt | 4 + .../models/whisper/requirements.txt | 11 + .../models/whisper/whisper_chain.py | 272 +++++++++-------- .../models/whisper/whisper_decoder.py | 2 +- .../whisper/whisper_encoder_decoder_init.py | 6 +- .../models/whisper/whisper_helper.py | 79 ++--- .../transformers/torch_onnx_export_helper.py | 3 +- .../python/transformers/test_generation.py | 19 +- .../test_whisper_timestamp_processor.py | 4 +- 21 files changed, 560 insertions(+), 379 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt create mode 100644 onnxruntime/python/tools/transformers/models/whisper/requirements.txt diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e7b537d6894c8..f523e97293427 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+beginning_timestamp_token_id : int
+The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-The id of the token that indicates decoding starts.
+The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-no_speech_token : int
+no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+no_timestamps_token_id : int
+The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+start_of_lm_token_id : int
+The id of the token that indicates LM starts
+transcribe_token_id : int
+The id of the transcribe task
+translate_token_id : int
+The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
temperature (optional) : T
@@ -5812,11 +5822,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..af0904b7d6e4b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..93837e785b4a4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..b1dd55eb20f34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,7 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h 
b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 03d4e89ac20fe..231eb17d1a947 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. 
- const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = 
*std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -268,7 +283,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -330,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index bba30805ae1be..7adc2fe0a67ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed 
below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..e33ce20737f80 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. 
Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. 
Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1235,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1322,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. 
Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1363,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 02100266200f8..7a678f2734ade 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,5 +1,22 @@ # Whisper +## Prerequisites + +Please note the package versions needed for using Whisper in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running Whisper on CPU +- `requirements-cuda.txt` + - For running Whisper on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements.txt` + - Package versions needed in each of the above files + +In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers. + +- Linux: `sudo apt-get install ffmpeg` +- MacOS: `sudo brew install ffmpeg` +- Windows: Download from website + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search (using Whisper tiny as an example). 
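Before running any of the export commands that follow, it can help to confirm that the prerequisites listed above are actually installed. The optional Python sketch below checks for the packages named in `requirements.txt` (where `whisper` is the module installed by `openai-whisper`) and for `ffmpeg` on `PATH`; adjust the module list to match your scenario.

```
# Optional sanity check before exporting: verify prerequisite packages and ffmpeg.
import importlib.util
import shutil

required_modules = ["torch", "transformers", "onnxruntime", "whisper", "datasets", "soundfile", "librosa"]
missing = [m for m in required_modules if importlib.util.find_spec(m) is None]
if missing:
    print("Missing Python packages:", ", ".join(missing))

if shutil.which("ffmpeg") is None:
    print("ffmpeg was not found on PATH; install it before working with audio files.")
```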
@@ -10,10 +27,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as # From source $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers/ -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format # From wheel -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format ``` ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) @@ -39,40 +56,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx") Here are some additional examples for exporting Whisper with beam search. +To see all available options +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx --help + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +``` + Export with Forced Decoder Input Ids ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids ``` Export + Optimize for FP32 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 ``` Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision 
fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer ``` ## Benchmark Whisper diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 759ae6d14f184..e57385aa6db8f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import ast import datetime @@ -54,6 +60,8 @@ def load_via_numpy(): inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) if args.has_logits_processor: inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + if args.has_temperature: + inputs["temperature"] = np.array([args.temperature], dtype=np.float32) # Measure time taken to load audio file logger.info(f"Load audio: {args.audio_path}") @@ -163,6 +171,7 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): warmup_inputs = inputs[0] if type(inputs) is tuple else inputs benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + torch_device = torch.device(args.target_device) # Warm up warmup_range = ( @@ -180,7 +189,7 @@ def time_fn(args, fn, inputs): # Benchmark if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) start_time = time.time() bench_range = ( @@ -192,7 +201,7 @@ def time_fn(args, fn, inputs): fn(benchmark_inputs) if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line @@ -500,7 +509,13 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. 
See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Whether to use timestamps logits processor or not (0 for false, 1 for true).", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature value for generation.", ) # Args for accessing detailed info @@ -581,6 +596,7 @@ def main(): args.has_audio_stream = "audio_stream" in ort_model_inputs setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010 if args.decoder_input_ids == []: args.decoder_input_ids = [config.decoder_start_token_id] diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d205a2d340721..814b0dd1ef6ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import datetime import json diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index bb697fe1e1506..35211aab272e4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -28,17 +28,25 @@ def parse_arguments(argv=None): parser = argparse.ArgumentParser() - pretrained_models = PRETRAINED_WHISPER_MODELS - parser.add_argument( + conversion_args = parser.add_argument_group("Conversion Process Args") + optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)") + optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)") + quant_args = parser.add_argument_group("INT8 Quantization Args") + + ################################# + # Conversion options for Whisper + ################################# + + conversion_args.add_argument( "-m", "--model_name_or_path", required=False, default=PRETRAINED_WHISPER_MODELS[0], type=str, - help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS), ) - parser.add_argument( + conversion_args.add_argument( "--model_impl", required=False, default="hf", @@ -47,7 +55,7 @@ def parse_arguments(argv=None): help="Select implementation for export of encoder and decoder subgraphs", ) - parser.add_argument( + conversion_args.add_argument( "--cache_dir", required=False, type=str, @@ -55,7 +63,7 @@ def parse_arguments(argv=None): help="Directory to cache pre-trained models", ) - parser.add_argument( + conversion_args.add_argument( "--output", required=False, type=str, @@ -63,19 +71,24 @@ def parse_arguments(argv=None): help="Output directory", ) - parser.add_argument( + conversion_args.add_argument( "-o", "--optimize_onnx", required=False, action="store_true", help="Use 
optimizer.py to optimize onnx model", ) - parser.set_defaults(optimize_onnx=False) + conversion_args.set_defaults(optimize_onnx=False) - parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") - parser.set_defaults(use_gpu=False) + conversion_args.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for model inference", + ) + conversion_args.set_defaults(use_gpu=False) - parser.add_argument( + conversion_args.add_argument( "-p", "--precision", required=False, @@ -85,221 +98,226 @@ def parse_arguments(argv=None): help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", ) - parser.add_argument("--verbose", required=False, action="store_true") - parser.set_defaults(verbose=False) - - parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") - parser.set_defaults(use_external_data_format=False) - - parser.add_argument( - "-s", - "--use_decoder_start_token", + conversion_args.add_argument( + "--use_int64_inputs", required=False, action="store_true", - help="Use config.decoder_start_token_id. Otherwise, add an extra graph input to \ - the encoder-decoder-init subgraph for decoder_input_ids.", + help="Use int64 instead of int32 for input_ids and attention_mask.", ) - parser.set_defaults(use_decoder_start_token=False) + conversion_args.set_defaults(use_int64_inputs=False) - parser.add_argument( - "-f", - "--use_forced_decoder_ids", + conversion_args.add_argument( + "--disable_auto_mixed_precision", required=False, action="store_true", - help="Use decoder_input_ids as an extra graph input to the beam search op", + help="Use pure fp16 instead of mixed precision", ) - parser.set_defaults(use_forced_decoder_ids=False) + conversion_args.set_defaults(disable_auto_mixed_precision=False) - parser.add_argument( - "-l", - "--use_logits_processor", + conversion_args.add_argument( + "-r", + "--provider", required=False, - action="store_true", - help="Use logits_processor as an extra graph input to enable specific logits processing", + type=str, + default="cpu", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CPUExecutionProvider.", ) - parser.set_defaults(use_specific_logits_processor=False) - parser.add_argument( - "-v", - "--use_vocab_mask", + conversion_args.add_argument( + "--verbose", required=False, action="store_true", - help="Use vocab_mask as an extra graph input to enable specific logits processing", + help="Enable verbose logging", ) - parser.set_defaults(use_vocab_mask=False) + conversion_args.set_defaults(verbose=False) - parser.add_argument( - "-u", - "--use_prefix_vocab_mask", + conversion_args.add_argument( + "-e", + "--use_external_data_format", required=False, action="store_true", - help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", + help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. 
Optional for 'tiny' and 'base' models.", ) - parser.set_defaults(use_prefix_vocab_mask=False) + conversion_args.set_defaults(use_external_data_format=False) - parser.add_argument( + conversion_args.add_argument( "-w", "--overwrite", required=False, action="store_true", - help="overwrite existing ONNX model", + help="Overwrite existing ONNX model", ) - parser.set_defaults(overwrite=False) + conversion_args.set_defaults(overwrite=False) - parser.add_argument( - "--disable_auto_mixed_precision", + conversion_args.add_argument( + "--separate_encoder_and_decoder_init", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.", ) - parser.set_defaults(disable_auto_mixed_precision=False) + conversion_args.set_defaults(separate_encoder_and_decoder_init=False) - parser.add_argument( - "--separate_encoder_and_decoder_init", + conversion_args.add_argument( + "--no_beam_search_op", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + conversion_args.set_defaults(no_beam_search_op=False) - parser.add_argument( - "--use_int64_inputs", + conversion_args.add_argument( + "--state_dict_path", + type=str, + default="", + help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", + ) + + ############################################################# + # Optional inputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_inputs.add_argument( + "-v", + "--use_vocab_mask", required=False, action="store_true", - help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + help="Use vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(use_int64_inputs=False) + optional_inputs.set_defaults(use_vocab_mask=False) - parser.add_argument( - "--chain_model", + optional_inputs.add_argument( + "-u", + "--use_prefix_vocab_mask", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(chain_model=True) + optional_inputs.set_defaults(use_prefix_vocab_mask=False) - parser.add_argument( - "--use_whisper_beamsearch", + optional_inputs.add_argument( + "-f", + "--use_forced_decoder_ids", required=False, action="store_true", - help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \ - It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.", + help="Use decoder_input_ids as an extra graph input to the beam search op", ) - parser.set_defaults(use_whisper_beamsearch=False) + optional_inputs.set_defaults(use_forced_decoder_ids=False) - parser.add_argument( - "--extra_decoding_ids", + optional_inputs.add_argument( + "-l", + "--use_logits_processor", required=False, action="store_true", - help="Need extra starting decoding ids for some feature like cross qk. 
Default if false.", + help="Use logits_processor as an extra graph input to enable specific logits processing", ) - parser.set_defaults(extra_decoding_ids=False) + optional_inputs.set_defaults(use_specific_logits_processor=False) - parser.add_argument( + optional_inputs.add_argument( "--collect_cross_qk", required=False, action="store_true", help="Beam search model collect stacked cross QK.", ) - parser.set_defaults(collect_cross_qk=False) + optional_inputs.set_defaults(collect_cross_qk=False) - parser.add_argument( - "--output_cross_qk", + optional_inputs.add_argument( + "--extra_decoding_ids", required=False, action="store_true", - help="Beam search model output collected qk as output. Also hint collect_cross_qk", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", ) - parser.set_defaults(output_cross_qk=False) + optional_inputs.set_defaults(extra_decoding_ids=False) - parser.add_argument( - "--no_speech_token_id", - default=50362, + optional_inputs.add_argument( + "-t", + "--use_temperature", + required=False, + action="store_true", + help="Use temperature as an extra graph input for the WhisperBeamSearch op", + ) + optional_inputs.set_defaults(use_temperature=False) + + optional_inputs.add_argument( + "--no_repeat_ngram_size", type=int, - help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. \ - Note that default value maybe different between the multilingual and English-only models.", + default=0, + help="default to 0", ) - parser.add_argument( - "--output_no_speech_probs", + ############################################################# + # Optional outputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_outputs.add_argument( + "--output_sequence_scores", required=False, action="store_true", - help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + help="Beam search model output scores for each generated sequence.", ) - parser.set_defaults(output_no_speech_probs=False) + optional_outputs.set_defaults(output_sequence_scores=False) - parser.add_argument( + optional_outputs.add_argument( "--output_scores", required=False, action="store_true", help="Beam search model output scores over vocab per generated token.", ) - parser.set_defaults(output_scores=False) + optional_outputs.set_defaults(output_scores=False) - parser.add_argument( - "--output_sequence_scores", + optional_outputs.add_argument( + "--output_cross_qk", required=False, action="store_true", - help="Beam search model output scores for each generated sequence.", + help="Beam search model output collected qk as output. 
Also hint collect_cross_qk", ) - parser.set_defaults(output_sequence_scores=False) + optional_outputs.set_defaults(output_cross_qk=False) - parser.add_argument( + optional_outputs.add_argument( "--cross_qk_onnx_model", required=False, type=str, default=None, - help="the model which consume cross_qk.", + help="The model which consumes cross_qk outputs.", ) - parser.add_argument( - "--beam_output_model", - type=str, - default="whisper_beamsearch.onnx", - help="default name is whisper_beamsearch.onnx.", + optional_outputs.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) + optional_outputs.set_defaults(output_no_speech_probs=False) - parser.add_argument( + ################################### + # Quantization options for Whisper + ################################### + + quant_args.add_argument( "--quantize_embedding_layer", required=False, action="store_true", help="Quantize MatMul, GEMM, and Gather.", ) - parser.set_defaults(quantize_embedding_layer=False) + quant_args.set_defaults(quantize_embedding_layer=False) - parser.add_argument( + quant_args.add_argument( "--quantize_per_channel", required=False, action="store_true", help="Quantize weights per each channel.", ) - parser.set_defaults(quantize_per_channel=False) + quant_args.set_defaults(quantize_per_channel=False) - parser.add_argument( + quant_args.add_argument( "--quantize_reduce_range", required=False, action="store_true", help="Quantize weights with 7 bits.", ) - parser.set_defaults(quantize_reduce_range=False) - - parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") - - parser.add_argument( - "--state_dict_path", - type=str, - default="", - help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", - ) - - parser.add_argument( - "-r", - "--provider", - required=False, - type=str, - default="cpu", - choices=list(PROVIDERS.keys()), - help="Provider to benchmark. Default is CPUExecutionProvider.", - ) + quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk @@ -317,7 +335,7 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = False, + use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, @@ -362,7 +380,6 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token, use_int32_inputs=use_int32_inputs, ) else: @@ -406,7 +423,7 @@ def export_onnx_models( extra_options={"MatMulConstBOnly": True}, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") else: output_path = onnx_path @@ -449,7 +466,7 @@ def main(argv=None): args.optimize_onnx, args.precision, args.verbose, - args.use_decoder_start_token, + args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, @@ -462,7 +479,7 @@ def main(argv=None): ) max_diff = 0 - if args.chain_model: + if not args.no_beam_search_op: logger.info("Chaining model ... 
:") args.beam_model_output_dir = WhisperHelper.get_onnx_path( output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt new file mode 100644 index 0000000000000..db2cd95324328 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt @@ -0,0 +1,2 @@ +-r requirements.txt +onnxruntime>=1.17.1 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt new file mode 100644 index 0000000000000..9bd215de9bc09 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt @@ -0,0 +1,4 @@ +-r requirements.txt +# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system. +# Instructions can be found here: https://pytorch.org/get-started/locally/ +onnxruntime-gpu>=1.17.1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt new file mode 100644 index 0000000000000..c307a3665f8a0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -0,0 +1,11 @@ +torch>=1.13.0 +transformers>=4.24.0 +openai-whisper +ffmpeg-python +datasets +soundfile +librosa +optimum +onnxruntime-extensions>=0.9.0 +protobuf==3.20.2 +numpy==1.23.3 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index a74666b7af297..14691da4ad643 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- + import logging import os @@ -9,7 +15,7 @@ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) from onnx import TensorProto, helper -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperTokenizer logger = logging.getLogger(__name__) @@ -23,11 +29,22 @@ def verify_inputs(beam_inputs, graph_inputs): assert graph_input.name in beam_input +def clean_list(arr, remove_all_strings=True): + if remove_all_strings: + # Remove all empty strings in list + return list(filter(lambda elm: elm != "", arr)) + + # Remove empty strings at end of list + while len(arr) > 0: + if arr[-1] == "": + arr.pop() + else: + break + return arr + + def chain_model(args): - # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op - args.use_whisper_beamsearch = ( - args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids - ) + # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op encoder_model = onnx.load_model(args.encoder_path, load_external_data=True) encoder_model.graph.name = "encoderdecoderinit subgraph" @@ -35,7 +52,10 @@ def chain_model(args): decoder_model.graph.name = "decoder subgraph" config = WhisperConfig.from_pretrained(args.model_name_or_path) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + # Create inputs/outputs for WhisperBeamSearch op + temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" beam_inputs = [ "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", "max_length", @@ -44,38 +64,27 @@ def chain_model(args): "num_return_sequences", "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", - "vocab_mask" if args.use_prefix_vocab_mask else "", + "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", + temperature_name if args.use_temperature else "", ] - beam_outputs = ["sequences"] - if args.output_sequence_scores: - beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") - if args.output_scores: - beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") - - if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 - beam_inputs.extend( - [ - "cross_qk_layer_head" if args.collect_cross_qk else "", - "extra_decoding_ids" if args.extra_decoding_ids else "", - ] - ) - if args.collect_cross_qk: - while len(beam_outputs) < 3: - beam_outputs.extend([""]) - beam_outputs.extend(["cross_qk"]) - if args.output_no_speech_probs: - while len(beam_outputs) < 4: - beam_outputs.extend([""]) - beam_outputs.extend(["no_speech_probs_beam"]) - - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None - output_scores_cast_node = output_sequence_scores_cast_node = None + sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" + scores_name = "scores_fp16" if 
args.precision == Precision.FLOAT16 else "scores" + beam_outputs = [ + "sequences", + sequence_scores_name if args.output_sequence_scores else "", + scores_name if args.output_scores else "", + "cross_qk" if args.collect_cross_qk else "", + "no_speech_probs_beam" if args.output_no_speech_probs else "", + ] + + graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -98,6 +107,18 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) + graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) + + if args.use_temperature: + temp_cast_node = helper.make_node( + "Cast", + inputs=["temperature"], + outputs=["temperature_fp16"], + name="temperature_to_fp16", + to=TensorProto.FLOAT16, + ) + graph_nodes.append(temp_cast_node) + if args.output_sequence_scores: output_sequence_scores_cast_node = helper.make_node( "Cast", @@ -106,6 +127,8 @@ def chain_model(args): name="CastOutputSequenceScoresToFp32", to=TensorProto.FLOAT, ) + graph_nodes.append(output_sequence_scores_cast_node) + if args.output_scores: output_scores_cast_node = helper.make_node( "Cast", @@ -114,26 +137,38 @@ def chain_model(args): name="CastScoresToFp32", to=TensorProto.FLOAT, ) - - operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" - node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") - node.domain = "com.microsoft" - node.attribute.extend( - [ - helper.make_attribute("eos_token_id", config.eos_token_id), - helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", True), - helper.make_attribute("model_type", 2), - ] + graph_nodes.append(output_scores_cast_node) + + # Create WhisperBeamSearch op + beam_search_attrs = [ + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id + ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]), + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0]) + if args.output_no_speech_probs + else "", + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", True), + helper.make_attribute("model_type", 2), + helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", + ] + node = helper.make_node( + "WhisperBeamSearch", + inputs=clean_list(beam_inputs, remove_all_strings=False), + outputs=clean_list(beam_outputs, remove_all_strings=False), + name="BeamSearch", + domain="com.microsoft", ) - if args.use_whisper_beamsearch: - if args.collect_cross_qk: - 
node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) - if args.no_speech_token_id >= 0: - node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) + node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True)) + # Graph inputs input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] ) @@ -143,73 +178,63 @@ def chain_model(args): num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] + ) + decoder_input_ids = helper.make_tensor_value_info( + "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) + cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]) + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - graph_inputs = [ - input_features, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - ] - if args.use_vocab_mask: - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) - graph_inputs.append(vocab_mask) - - if args.use_prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.use_forced_decoder_ids: - decoder_input_ids = helper.make_tensor_value_info( - "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] - ) - graph_inputs.append(decoder_input_ids) - - if args.use_logits_processor: - logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - graph_inputs.append(logits_processor) - - if args.collect_cross_qk: - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) - graph_inputs.append(cross_qk_layer_head) - - if args.extra_decoding_ids: - extra_decoding_ids = helper.make_tensor_value_info( - "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] - ) - graph_inputs.append(extra_decoding_ids) + graph_inputs = clean_list( + [ + input_features, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ] + ) - # graph outputs + # Graph outputs sequences = helper.make_tensor_value_info( "sequences", 
TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) - graph_outputs = [sequences] - if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): - cross_qk = helper.make_tensor_value_info( - "cross_qk", - TensorProto.FLOAT, - ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], - ) - graph_outputs.extend([cross_qk]) - - if args.output_no_speech_probs: - no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([no_speech_probs]) - - if args.output_sequence_scores: - sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([sequence_scores]) + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - if args.output_scores: - scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([scores]) + graph_outputs = clean_list( + [ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ] + ) + # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") @@ -230,19 +255,7 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] - graph_nodes = ( - [ - input_features_cast_node, - len_pen_cast_node, - rep_pen_cast_node, - node, - output_sequence_scores_cast_node, - output_scores_cast_node, - ] - if args.precision == Precision.FLOAT16 - else [node] - ) - graph_nodes = [node for node in graph_nodes if node is not None] + graph_nodes.append(node) if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", @@ -251,9 +264,16 @@ def chain_model(args): name="no_speech_probs_cast_to_fp32", to=TensorProto.FLOAT, ) - graph_nodes.extend([prob_cast_node]) - - beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + graph_nodes.append(prob_cast_node) + + # Make graph with WhisperBeamSearch op + beam_graph = helper.make_graph( + graph_nodes, + name="WhisperBeamSearch Graph", + inputs=graph_inputs, + outputs=graph_outputs, + initializer=initializers, + ) beam_graph_input_names = [gi.name for gi in graph_inputs] beam_graph_output_names = [go.name for go in graph_outputs] @@ -287,10 +307,12 @@ def chain_model(args): ir_version=decoder_model.ir_version, ) + # Save WhisperBeamSearch graph and external data if os.path.isfile(args.beam_model_output_dir): logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") os.remove(args.beam_model_output_dir) os.remove(args.beam_model_output_dir + 
".data") + onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 0d69960a095ac..93fd64c9eb7d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -170,7 +170,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - past_decode_sequence_length, + encode_sequence_length, head_size, ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 351173f525727..832f692e9980d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -75,7 +75,7 @@ def create_dummy( config: WhisperConfig, batch_size: int, encode_sequence_length: int, - use_decoder_input_ids: int, + use_decoder_input_ids: bool, device: torch.device, use_int32_inputs: bool = False, ): # -> WhisperEncoderDecoderInitInputs: @@ -125,7 +125,7 @@ def export_onnx( model.config, batch_size=2, encode_sequence_length=3000, - use_decoder_input_ids=use_decoder_input_ids, + use_decoder_input_ids=True, device=device, use_int32_inputs=use_int32_inputs, ) @@ -159,7 +159,7 @@ def export_onnx( hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.encoder_attention_heads) dynamic_axes = { - "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_input_ids": {0: "batch_size", 1: "feature_size"}, "encoder_hidden_states": { 0: "batch_size", 1: "encode_sequence_length", diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e2dc79ca247ce..1b47b9426d983 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,12 +6,14 @@ import logging import os -import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -21,24 +23,20 @@ from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model - logger = logging.getLogger(__name__) PRETRAINED_WHISPER_MODELS = [ "whisper-tiny", "whisper-tiny.en", + "whisper-base", + "whisper-base.en", "whisper-small", "whisper-small.en", "whisper-medium", "whisper-medium.en", - "whisper-base", - "whisper-base.en", "whisper-large", "whisper-large-v2", + "whisper-large-v3", ] @@ -346,7 +344,12 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + start_id = 
[config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -383,43 +386,51 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 50359, 50363]] - ) + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + parity = ( + pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options + ) + max_diff = 0 - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." 
- ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if max_diff != 0: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index f3e67930adbff..66f24c47f6cdb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import torch +from torch._C._onnx import OperatorExportTypes TrainingMode = torch.onnx.TrainingMode from packaging.version import Version # noqa: E402 @@ -18,7 +19,7 @@ def torch_onnx_export( training=TrainingMode.EVAL, input_names=None, output_names=None, - operator_export_type=None, + operator_export_type=OperatorExportTypes.ONNX, opset_version=None, _retain_param_name=None, do_constant_folding=True, diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index 40ea8cf774918..33ec1bd7728fe 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -381,22 +381,23 @@ def test_logits_processor(self): @pytest.mark.slow def test_cross_qk_overall(self): - decoder_input_ids = [ - "--chain_model", - "--collect_cross_qk", - "--output_cross_qk", - "--use_forced_decoder_ids", - "--extra_decoding_ids", - "--output_no_speech_probs", + cross_qk_input_args = [ "--use_vocab_mask", "--use_prefix_vocab_mask", + "--use_forced_decoder_ids", "--use_logits_processor", + "--collect_cross_qk", + "--extra_decoding_ids", ] - self.run_configs(decoder_input_ids) + cross_qk_output_args = [ + "--output_cross_qk", + "--output_no_speech_probs", + ] + self.run_configs(cross_qk_input_args + cross_qk_output_args) @pytest.mark.slow def test_openai_impl_whisper(self): - optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"] + optional_args = ["--model_impl", "openai"] self.run_configs(optional_args) diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 77ce09d7e793b..7892000ae45a0 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -50,7 +50,7 @@ def run_timestamp(self, provider: str): ort_out = sess.run(None, ort_inputs) ort_out_tensor = torch.from_numpy(ort_out[0]) ort_transcription = processor.batch_decode( - ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True ) print(ort_transcription) expected_transcription = [ @@ -58,7 +58,7 @@ def run_timestamp(self, provider: str): "text": "<|0.00|> Mr. 
Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", "timestamp": (0.0, 5.44), } ], From 4874a41008138ecc1f26e9cd17e5d9d7febb29aa Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 16 Feb 2024 16:59:43 -0800 Subject: [PATCH 094/207] [QNN EP] Update default QNN SDK to 2.19.2.240210 (#19546) ### Description Updates the default QNN SDK version to 2.19.2.240210. ### Motivation and Context Build and test the latest version of QNN SDK in our pipelines. --- .../android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml | 2 +- .../github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml | 2 +- tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 2b181810b0788..d37266a8e96d8 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,7 +31,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101 + default: qnn-v2.19.2.240210 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 0312b70d2b1d5..8fa5bdbf90931 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101 + default: qnn-v2.19.2.240210 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index b0509467e1689..9a38513d04a79 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 13d4589a67cdc..dc861f7f1ed79 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 6246bb83566e5..534d5c6d6135b 100644 --- 
a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win jobs: - job: 'build' From 06269a3952fb1759d93235b9d66f9beb10ae8663 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 16 Feb 2024 18:28:27 -0800 Subject: [PATCH 095/207] [js/webgpu] allow uint8 tensors for webgpu (#19545) ### Description allow uint8 tensors for webgpu --- js/common/lib/tensor-impl.ts | 2 +- js/common/lib/tensor.ts | 2 +- js/web/lib/wasm/wasm-common.ts | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c728556..de18126a9d0ae 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e057..d5da33640dc7d 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c4..93910af1f1bf0 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -169,7 +169,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value From dfeda9019cfed2d6df5bcacc54269c7de481bdee Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Sat, 17 Feb 2024 09:19:17 -0800 Subject: [PATCH 096/207] [JS/WebGPU] Add MatMulNBits (#19446) ### Description Add MatMulNBits to support MatMul using 4-bit quantized weights ### Motivation and Context --- js/web/docs/webgpu-operators.md | 1 + js/web/lib/wasm/jsep/util.ts | 28 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 2 + .../lib/wasm/jsep/webgpu/ops/matmulnbits.ts | 184 ++ js/web/test/data/ops/matmulnbits.jsonc | 1527 +++++++++++++++++ js/web/test/suite-test-list.jsonc | 1 + .../contrib_ops/js/js_contrib_kernels.cc | 16 +- .../js/quantization/matmul_nbits.cc | 25 + .../js/quantization/matmul_nbits.h | 48 + 9 files changed, 1825 insertions(+), 7 deletions(-) create mode 100644 js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts create mode 100644 js/web/test/data/ops/matmulnbits.jsonc create mode 100644 
onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc create mode 100644 onnxruntime/contrib_ops/js/quantization/matmul_nbits.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index b21af8e715db3..4a8c92bb97bfd 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6e..c0517ce363644 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -92,6 +92,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ac08c5fb1f7ab..ba874c8dd0f80 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; import * as pool from './ops/pool'; @@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 0000000000000..ead7635cf3ac4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + const a = inputs[0]; + const b = inputs[1]; + const scales = inputs[2]; + const aRank = a.dims.length; + const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); + const outputSize = ShapeUtil.size(outputShape); + + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, + {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} + ]; + programUniforms.push(...createTensorShapeVariables(a.dims)); + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); + programUniforms.push(...createTensorShapeVariables(scales.dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); + const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? 
inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} + ]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const wordPerBlob = blobSize / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ + var result = array<${dataType}, 8>(); + var offset: u32 = 0; + let count: u32 = 4; + for (var i: u32 = 0; i < 8u; i++) { + result[i] = ${dataType}(extractBits(value, offset, count)); + offset += count; + } + return result; + } + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var value: ${dataType} = 0.0; + let output_indices = ${output.offsetToIndices('global_idx')}; + var a_indices: ${a.type.indices} = output_indices; + var n = ${output.indicesGet('output_indices', aRank - 1)}; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? ` + var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : + ''} + var scale_idex = n * ${nBlocksPerCol}; + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'n')}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_idex')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point: ${dataType} = ${ + zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${wordPerBlob}; word++) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_value = ${b.getByIndices('b_indices')}; + let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var i: u32 = 0; i < 8; i++) { + ${a.indicesSet('a_indices', aRank - 1, 'offset')}; + let a_value = ${a.getByIndices('a_indices')}; + let b_quantized_value = b_quantized_values[i]; + let b_dequantized_value = (b_quantized_value - zero_point) * scale; + value += a_value * b_dequantized_value; + offset++; + } + word_offset += 8; + } + scale_idex++; + ${ + zeroPoints ? 
` + if (zero_point_offset == 28) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + } else { + zero_point_offset += 4; + }` : + ''} + block_offset += uniforms.block_size; + } + ${output.setByOffset('global_idx', 'value')}; + } + `; + }; + return { + name: 'MatMulNBits', + shaderCache: + {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 0000000000000..c57c431afb3ce --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,1527 @@ +[ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 
124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 
201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, + 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, + 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, + 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, + 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, + 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, + 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, + 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256, + 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, + 360008, 226848, 451256, 292432, 550440 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; 
symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 
200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, 
-46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + 
-223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 
330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 
156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 
1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 
530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 
61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": 
"N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 
157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 
21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 
734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 
403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 
714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, -394148, 
-328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 
43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 
752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 
421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 
243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 
1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 
67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 
771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 
440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, 
-675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, 
-1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 
205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 
902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 
128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 
3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 
1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 55b21283025c2..1c61518ddcdd2 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1362,6 +1362,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index bd58dded026a6..25e7567a2e9fc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,13 +8,14 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + SkipLayerNormalization)>}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..888db0fd161f2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..cca2c4757765b --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime From b55260d076da309f3a4634eb5248a0eb541e8ca0 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 19 Feb 2024 10:21:19 +0800 Subject: [PATCH 097/207] Minor fix for cmake (#19552) ### Minor fix for cmake When build on Linux, get a warning saying " CMake Warning at CMakeLists.txt:1603 (message): MPI and NCCL disabled on Win build. " This message is not correct. So have such a fix to avoid any misunderstanding from users. ![image](https://github.com/microsoft/onnxruntime/assets/10530022/848c2d77-a538-4e31-8e0d-4b539233e515) ### Motivation and Context --- cmake/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ff1c7a84f077f..c9be4aa65d0cc 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1600,7 +1600,7 @@ if (UNIX AND onnxruntime_USE_NCCL) else() set(onnxruntime_USE_NCCL OFF) set(onnxruntime_USE_MPI OFF) -message( WARNING "MPI and NCCL disabled on Win build." ) + message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." ) endif() if (onnxruntime_USE_MPI) From f3e3b531fe4c0d33d70928b101fb5d445e4174a8 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Tue, 20 Feb 2024 10:31:39 +0800 Subject: [PATCH 098/207] Update build directory clean up stage for python package pipeline (#19553) Fix to make clean up stage take effect. If the `SourceFolder ` is empty, the task deletes files from the root folder of the repository as though [$(Build.SourcesDirectory)](https://learn.microsoft.com/en-us/azure/devops/pipelines/build/variables) was specified. 
--- .../component-governance-component-detection-steps.yml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index c2ef565a6e9ee..f1418e75bffa2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -5,10 +5,12 @@ parameters: default: 'succeeded' # could be 'ci_only', 'always', 'succeeded' steps: -- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: +- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - task: DeleteFiles@1 inputs: - contents: $(Build.BinariesDirectory)/* + SourceFolder: '$(Build.BinariesDirectory)' + contents: | + **/* displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 From e832562d70685ffeaab7e3bfa20cd5e9aec916a3 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Tue, 20 Feb 2024 09:06:03 +0100 Subject: [PATCH 099/207] Fix invalid usage of designated initializers. (#19497) ### Description I've replaces all ocurances of C++ designated initializers in the CUDA NHWC Tests by member initialization. ### Motivation and Context C++ designated initializers have been introduced in C++ 20. Yet GCC accepts designated initializers in C++17 which is the standard used to compile onnxruntime. Yet MSVC is standard conform and accepts this feature starting C++20 which leads to compile failures on Windows without this change. --- .../test/providers/cuda/nhwc/conv_test.cc | 23 +++++++--- .../cuda/nhwc/conv_transpose_test.cc | 40 +++++++++------- .../providers/cuda/nhwc/nhwc_cuda_helper.h | 6 ++- .../test/providers/cuda/nhwc/norm_test.cc | 7 ++- .../test/providers/cuda/nhwc/pool_test.cc | 46 ++++++++++--------- 5 files changed, 72 insertions(+), 50 deletions(-) diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc index 13d4546d669e3..b6a760f7041ad 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -52,20 +52,31 @@ struct ConvOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvNhwcBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .bias = true}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.bias = true; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcGroupNoBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcPadding) { - auto op = - ConvOp{.input_dims = {2, 4, 64, 64}, .kernel_shape = {3, 3}, .channels = 4, .padding = {4, 4, 4, 4}}; + auto op = ConvOp{}; + op.input_dims = {2, 4, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 4; + op.padding = {4, 4, 4, 
4}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 6514feadf0ff7..786b2cb4cedc4 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvTransposeOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -60,15 +60,21 @@ struct ConvTransposeOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcGroupNoBias) { - auto op = - ConvTransposeOp{.input_dims = {8, 8, 32, 32}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvTransposeOp{}; + op.input_dims = {8, 8, 32, 32}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { - auto op = - ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 8, 80, 80}; + op.kernel_shape = {5, 5}; + op.channels = 16; + op.bias = true; if (HasCudaEnvironment(800)) { MAKE_PROVIDERS_EPS(1e-2) @@ -78,21 +84,23 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { - auto op = ConvTransposeOp{.input_dims = {1, 16, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .padding = {2, 2, 2, 2}, - .output_padding = {}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 16, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.padding = {2, 2, 2, 2}; + op.output_padding = {}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcOutPad) { - auto op = ConvTransposeOp{.input_dims = {1, 32, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .strides = {2, 2}, - .output_padding = {1, 1, 1, 1}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 32, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.strides = {2, 2}; + op.output_padding = {1, 1, 1, 1}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 2c942bb790096..82b6a286409cd 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -16,11 +16,13 @@ #define MAKE_PROVIDERS_EPS(eps) \ std::vector> execution_providers; \ - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; \ + OrtCUDAProviderOptionsV2 nhwc{}; \ + nhwc.prefer_nhwc = true; \ execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \ \ double error_tolerance = eps; \ - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; \ + OrtCUDAProviderOptionsV2 nchw{}; \ + nchw.prefer_nhwc = false; \ auto source_ep = CudaExecutionProviderWithOptions(&nchw); \ auto test = op.get_test(); \ test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); diff --git a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc index 52da8ba557c2d..40f69e3bd5b4f 100644 --- a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc @@ -9,7 +9,7 @@ namespace test { template struct BatchNormOp { - const std::vector input_dims; + std::vector input_dims; 
std::unique_ptr get_test() { // create rand inputs @@ -40,9 +40,8 @@ struct BatchNormOp { }; TYPED_TEST(CudaNhwcTypedTest, BatchNormNhwc) { - auto op = BatchNormOp{ - .input_dims = {4, 16, 64, 64}, - }; + auto op = BatchNormOp{}; + op.input_dims = {4, 16, 64, 64}; MAKE_PROVIDERS() } diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc index e0d59901da80c..426170b9588f1 100644 --- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc @@ -9,9 +9,9 @@ namespace test { template struct PoolOp { - const std::string pooling_type; - const std::vector input_dims; - const std::vector kernel_shape; + std::string pooling_type; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; std::vector strides = {1, 1}; @@ -41,22 +41,21 @@ struct PoolOp { }; TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwc) { - auto op = PoolOp{ - .pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + MAKE_PROVIDERS() } TYPED_TEST(CudaNhwcTypedTest, MaxPoolNhwc) { - auto op = PoolOp{ - .pooling_type = "MaxPool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "MaxPool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; MAKE_PROVIDERS() } @@ -72,21 +71,24 @@ TYPED_TEST(CudaNhwcTypedTest, GlobalMaxPoolNhwc) { test->AddOutput("Y", output_dims, output_data); std::vector> execution_providers; - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; + OrtCUDAProviderOptionsV2 nhwc{}; + nhwc.prefer_nhwc = true; execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); double error_tolerance = 1e-3; - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; + OrtCUDAProviderOptionsV2 nchw{}; + nchw.prefer_nhwc = false; auto source_ep = CudaExecutionProviderWithOptions(&nchw); test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); } TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwcPad) { - auto op = PoolOp{.pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - .padding = {2, 2, 2, 2}}; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.padding = {2, 2, 2, 2}; MAKE_PROVIDERS() } From 7efb0dbe12cf8736d97dcc3b8f41eb96c5c34719 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 20 Feb 2024 17:22:44 +0100 Subject: [PATCH 100/207] add option DefaultTensorType to specify the default tensor type to quantize (#19455) ### Description The current quantization tool relies on shape inference to provide the type of every intermediate tensor, then the tool knows which type it must dequantize into (float32, float16). However, this information is not available if shape inference fails. That happens every time the model include an operator from a custom domain such as com.microsoft. This PR introduces an extra option `DefaultTensorType` as a fall back when the quantizer cannot find the type it needs. ### Motivation and Context This fixes issue #19409. 
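A minimal usage sketch of the new option (the model paths here are placeholders, not files from this PR): when the graph contains custom-domain operators such as com.microsoft ops that stop shape inference from reporting tensor types, the caller can supply the fallback type through `extra_options`:

```python
import onnx
from onnxruntime.quantization import QuantType, quantize_dynamic

# Sketch only: "model.onnx" / "model_quant.onnx" are placeholder paths.
# DefaultTensorType is the element type the quantizer assumes whenever
# shape inference cannot provide one (e.g. downstream of a com.microsoft op).
quantize_dynamic(
    "model.onnx",
    "model_quant.onnx",
    weight_type=QuantType.QInt8,
    extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT},
)
```

Without `DefaultTensorType`, the same call fails with `RuntimeError: Unable to find data type for weight_name=...` as soon as the quantizer reaches a tensor whose type shape inference could not resolve.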
--- .../tools/quantization/onnx_quantizer.py | 25 ++++- .../tools/transformers/quantize_helper.py | 3 +- .../test_quantizer_shape_inference.py | 92 +++++++++++++++++++ 3 files changed, 115 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_quantizer_shape_inference.py diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index ecfbaa569ca0a..9450426f12444 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -385,7 +385,7 @@ def add_new_nodes(self, nodes): def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( - "Please check if the model is already quantized." + "Please check if the model is already quantized. " "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." ) @@ -442,6 +442,23 @@ def is_valid_quantize_weight(self, weight_name): return False return self.parent.is_valid_quantize_weight(weight_name) + def _get_default_tensor_type(self, tensor_name): + if "DefaultTensorType" in self.extra_options: + logging.info( + "get_tensor_type returns DefaultTensorType for tensor name %r, use %d", + tensor_name, + self.extra_options["DefaultTensorType"], + ) + return self.extra_options["DefaultTensorType"] + raise RuntimeError( + f"Unable to find data type for weight_name={tensor_name!r}. " + f"shape_inference failed to return a type probably this node is " + f"from a different domain or using an input produced by such an operator. " + f"This may happen if you quantize a model already quantized. " + f"You may use extra_options `DefaultTensorType` to indicate " + f"the default weight type, usually `onnx.TensorProto.FLOAT`." 
+ ) + def get_tensor_type(self, tensor_name, mandatory=False): weight = find_by_name(tensor_name, self.model.initializer()) if weight is not None: @@ -450,11 +467,11 @@ def get_tensor_type(self, tensor_name, mandatory=False): vi = self.value_infos[tensor_name] if vi.type.HasField("tensor_type"): if mandatory and vi.type.tensor_type.elem_type == 0: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return vi.type.tensor_type.elem_type if (not self.enable_subgraph_quantization) or (self.parent is None): if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None otype = self.parent.is_valid_quantize_weight(tensor_name) if otype is not None: @@ -464,7 +481,7 @@ def get_tensor_type(self, tensor_name, mandatory=False): if res is not None: return res if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None def is_float_tensor(self, tensor_name): diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index a449e881ad361..6a25196dbc24c 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -7,7 +7,7 @@ import logging import os -import onnx # noqa: F401 +import onnx import torch from transformers.modeling_utils import Conv1D @@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data onnx_model_path, quantized_model_path, use_external_data_format=use_external_data_format, + extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT}, ) logger.info(f"quantized model saved to:{quantized_model_path}") # TODO: inlcude external data in total model size. diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py new file mode 100644 index 0000000000000..2b5d1f36070e5 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. 
+# -------------------------------------------------------------------------- +import unittest + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh + +from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer +from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType + + +class TestQuantizerShapeInference(unittest.TestCase): + def test_com_microsoft(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("MatMul", ["X", "W1"], ["T1"]), + oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"), + oh.make_node("MatMul", ["T2", "W3"], ["T3"]), + oh.make_node("MatMul", ["T3", "W4"], ["Y"]), + ], + "name", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])], + [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])], + [ + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"), + ], + ), + opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)], + ) + model_shaped = onnx.shape_inference.infer_shapes(model) + shaped_results = set(t.name for t in model_shaped.graph.value_info) + # every result after T1 depends on T2 coming from a node com.microsoft, + # shape_inference cannot go beyond this point + self.assertEqual(shaped_results, {"T1"}) + + # first try: checks it raises an exception + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + {"MatMulConstBOnly": True}, # extra_options, + # {'DefaultTensorType': 1, } + ) + + with self.assertRaises(RuntimeError) as e: + quantizer.quantize_model() + self.assertIn("Unable to find data type for weight_name=", str(e)) + + # second try: checks it works + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + { + "MatMulConstBOnly": True, + "DefaultTensorType": 1, + }, + ) + + model = quantizer.quantize_model() + ops = {n.op_type for n in model.graph.node} + self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"}) + + +if __name__ == "__main__": + unittest.main(verbosity=2) From 1b48054e1b7991ccef664fbedd659ec95d0e7ca7 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Wed, 21 Feb 2024 01:24:34 +0800 Subject: [PATCH 101/207] [js/webgpu] Create Split indices helpers by rank, not by shape (#19554) ### Description This is required to make shape uniforms really work. ### Motivation and Context The bug was unveiled in a model with multiple Split nodes. The later nodes would try to reuse a previous pipeline cache, while the old shapes were hardcoded as constants in cache. 
--- js/web/lib/wasm/jsep/webgpu/ops/split.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 14d6f37927590..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; @@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } programUniforms.push( From 3c49aacd5667b320a4e02626a176098f7423d7c0 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Tue, 20 Feb 2024 13:13:40 -0800 Subject: [PATCH 102/207] Disable __cpuid check on arm64 builds as intrinsic is not available (#19574) Disable __cpuid check on arm64 builds as intrinsic is not available Motivation Breaking the arm64 build. Co-authored-by: Sheil Kumar --- winml/lib/Api/HardwareCoreEnumerator.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index fa069c7fb66a7..b6b44690f4f6c 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -84,6 +84,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); +#if !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" int regs_leaf0[4]; int regs_leaf7[4]; @@ -100,6 +101,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores return cores.PhysicalCores - cores.Num2CacheCores; } +#endif return cores.PhysicalCores; } From ec9c8cbdc9686ccda6553674d6aab61cfd245cf0 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 21 Feb 2024 07:40:35 +1000 Subject: [PATCH 103/207] Use xcode parallel build flags to speed up iOS CI that is timing out (#19570) ### Description Provide specific xcodebuild flags instead of depending on cmake to do the right thing. This built in just over an hour with a ccache miss. Previous CIs with a ccache miss were timing out after 150 minutes. 
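As a rough sketch of the effect (the build directory and job count below are illustrative, not values from this pipeline), with the Xcode generator the parallelism flags are now appended to the native-tool arguments that `cmake --build` forwards after `--`, instead of relying on cmake's own `--parallel` option:

```python
# Sketch: how build.py now assembles the native-tool arguments for Xcode.
# "build/iOS" and the job count are placeholder values.
num_parallel_jobs = 8
build_tool_args = ["-parallelizeTargets", "-jobs", str(num_parallel_jobs)]

# cmake passes everything after "--" straight through to xcodebuild.
cmd = ["cmake", "--build", "build/iOS", "--config", "Release", "--", *build_tool_args]
print(" ".join(cmd))
# -> cmake --build build/iOS --config Release -- -parallelizeTargets -jobs 8
```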
### Motivation and Context --- tools/ci_build/build.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 244bebd81474d..5b715bb29e5a1 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1631,9 +1631,11 @@ def generate_build_tree( [ *temp_cmake_args, f"-DCMAKE_BUILD_TYPE={config}", - f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" - if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) - else "", + ( + f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" + if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) + else "" + ), ], cwd=config_build_dir, cuda_home=cuda_home, @@ -1667,8 +1669,11 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": - # CMake will generate correct build tool args for Xcode - cmd_args += ["--parallel", str(num_parallel_jobs)] + build_tool_args += [ + "-parallelizeTargets", + "-jobs", + str(num_parallel_jobs), + ] else: build_tool_args += [f"-j{num_parallel_jobs}"] From 7a5860e4909387448cb51351d3af50933238ba10 Mon Sep 17 00:00:00 2001 From: Jake Mathern Date: Tue, 20 Feb 2024 13:41:40 -0800 Subject: [PATCH 104/207] Fix cmake function duplicate lib (#19547) ### Description Fixes cmake function definition in winml.cmake to copy link flags. ### Motivation and Context XFGCheck errors in WindowsAI because this function does not transfer linker flags --- cmake/winml.cmake | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 268ee3960e75a..57cecd3e66adb 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -827,6 +827,7 @@ if (winml_is_inbox) get_target_property(compile_options ${target} COMPILE_OPTIONS) get_target_property(include_directories ${target} INCLUDE_DIRECTORIES) get_target_property(link_libraries ${target} LINK_LIBRARIES) + get_target_property(link_flags ${target} LINK_FLAGS) get_target_property(link_options ${target} LINK_OPTIONS) add_library(${new_target} SHARED ${sources}) @@ -835,6 +836,7 @@ if (winml_is_inbox) target_compile_options(${new_target} PRIVATE ${compile_options}) target_include_directories(${new_target} PRIVATE ${include_directories}) target_link_libraries(${new_target} PRIVATE ${link_libraries}) + set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}") target_link_options(${new_target} PRIVATE ${link_options}) endfunction() From 97ff17c2cbb6ee6f27c052e9c4302c70a41af485 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:02:11 -0800 Subject: [PATCH 105/207] update script of run CI for external PRs to add "Big Models" (#19576) ### Description update script of run CI for external PRs to add "Big Models" --- tools/python/run_CIs_for_external_pr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 7a77839c4a4e7..df4e70b1e51fe 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,8 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # big models + "Big Models", # not currently required, but running ensures we're hitting all mobile platforms "Android CI Pipeline", "iOS CI Pipeline", From 3fe2c137ee5923ee369062453d528fe0e33bf4bc Mon Sep 17 00:00:00 2001 From: 
Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:23:01 -0800 Subject: [PATCH 106/207] [js] small fix to workaround formatter (#19400) ### Description Rename shader variable names to snake_case naming and also to avoid formatter behaving inconsistently in win/linux. --- js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3f73d9cb7c5bc..d5f97213e49ce 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -85,28 +85,28 @@ const createLayerNormProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} let offset = global_idx * uniforms.norm_size_vectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let invStdDev = - inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${ + sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; From 70567a4b3a8bc74fb0f1a9ed9ea5a5be6b99b378 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:33:21 -0800 Subject: [PATCH 107/207] [js/web] use ApiTensor insteadof onnxjs Tensor in TensorResultValidator (#19358) ### Description use ApiTensor insteadof onnxjs Tensor in TensorResultValidator. Make test runner less depend on onnxjs classes. --- js/web/test/test-runner.ts | 26 +++++++------------ .../unittests/backends/webgl/test-conv-new.ts | 4 ++- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index b01d474788f25..ecc7d4b4a09a5 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; */ const now = (typeof performance !== 'undefined' && performance.now) ? 
() => performance.now() : Date.now; -function toInternalTensor(tensor: ort.Tensor): Tensor { - return new Tensor( - tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType); -} function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } @@ -330,6 +326,10 @@ export class TensorResultValidator { } checkTensorResult(actual: Tensor[], expected: Tensor[]): void { + this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor)); + } + + checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { // check output size expect(actual.length, 'size of output tensors').to.equal(expected.length); @@ -347,10 +347,6 @@ export class TensorResultValidator { } } - checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { - this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor)); - } - checkNamedTensorResult(actual: Record, expected: Test.NamedTensor[]): void { // check output size expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length); @@ -364,7 +360,7 @@ export class TensorResultValidator { } // This function check whether 2 tensors should be considered as 'match' or not - areEqual(actual: Tensor, expected: Tensor): boolean { + areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean { if (!actual || !expected) { return false; } @@ -392,13 +388,13 @@ export class TensorResultValidator { switch (actualType) { case 'string': - return this.strictEqual(actual.stringData, expected.stringData); + return this.strictEqual(actual.data, expected.data); case 'float32': case 'float64': return this.floatEqual( - actual.numberData as number[] | Float32Array | Float64Array, - expected.numberData as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + expected.data as number[] | Float32Array | Float64Array); case 'uint8': case 'int8': @@ -409,10 +405,8 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array, - expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); default: throw new Error('type not implemented or not supported'); diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 8c186b9b36451..014fc57f21558 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -893,7 +893,9 @@ describe('New Conv tests', () => { const expected = cpuConv( inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, testData.strides); - if (!validator.areEqual(actual, expected)) { + try { + validator.checkTensorResult([actual], [expected]); + } catch { console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`); console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`); throw new Error('Expected and Actual did not match'); From 6e04e36e3faf2d8115c0962c85b86a6a8b48ac5b Mon 
Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:33:37 -0800 Subject: [PATCH 108/207] [js/common] upgrade tsc in common from 4.9.5 to 5.2.2 (#19317) ### Description upgrade tsc in common from 4.9.5 to 5.2.2 --- js/common/package-lock.json | 106 +++++++++++++++++------------------ js/common/package.json | 4 +- js/common/test/tsconfig.json | 2 +- 3 files changed, 56 insertions(+), 56 deletions(-) diff --git a/js/common/package-lock.json b/js/common/package-lock.json index a5ada877b916a..3988ac80707e0 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -9,13 +9,13 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": "bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 
@@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": 
{ "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index 64ab2736adbe3..cd2612aab4984 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -9,7 +9,7 @@ }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b325..e9068ad837a81 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } From 45e20bf7810689ecf385957c34434c6d2456e32b Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 21 Feb 2024 12:38:37 +1000 Subject: [PATCH 109/207] Use build.py to build in py-win-gpu.yml so parallelization parameters are set (#19578) ### Description build.py sets a few parallelization parameters when building. Using msbuild directly lacks those. https://github.com/microsoft/onnxruntime/blob/7a5860e4909387448cb51351d3af50933238ba10/tools/ci_build/build.py#L1665-L1669 Changed to use build.py. If there's a concern with that we _could_ set the parameters in the yaml, but that will be uglier due to duplicating logic in multiple places. 
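For reference, an illustrative sketch of the kind of msbuild parallelism arguments build.py adds for Visual Studio generators (flag names based on the referenced build.py snippet and standard msbuild switches; the helper name here is made up for illustration, not the real implementation):

```python
# Sketch only, not the actual build.py code.
def msbuild_parallel_args(num_parallel_jobs: int) -> list[str]:
    return [
        f"/maxcpucount:{num_parallel_jobs}",   # build projects in parallel
        f"/p:CL_MPCount={num_parallel_jobs}",  # parallel C++ compilation within a project
    ]

print(msbuild_parallel_args(8))  # ['/maxcpucount:8', '/p:CL_MPCount=8']
```

The second flag parallelizes compilation within each project, which is the part a bare msbuild invocation tends to miss.
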
### Motivation and Context --- .../azure-pipelines/templates/py-win-gpu.yml | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 18368e59cad52..4315eae503ebd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -120,17 +120,17 @@ jobs: $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 + # building with build.py so the parallelization parameters are added to the msbuild command + - task: PythonScript@0 displayName: 'Build' inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: x64 - configuration: RelWithDebInfo - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --parallel --build + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - template: win-esrp-dll.yml @@ -188,7 +188,7 @@ jobs: condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) inputs: GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - template: component-governance-component-detection-steps.yml parameters: From 0c4421cb7867434e1e08b4274f16f6c2f14cb4ce Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 21 Feb 2024 03:39:43 +0100 Subject: [PATCH 110/207] Fix compile warnings (as errors) for functions which miss returning required return value (#19079) Added dummy return values to functions which specify a return value, but do not return an value value. ### Motivation and Context Fix compiler errors with 'warnings as errors' enabled. From 8fadc6c913bc30edff2e89756da515b9bd75d256 Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:41:42 +0800 Subject: [PATCH 111/207] Zhijxu/cleanup cached tensors when oom (#19306) in pytorch, when oom happens at bp, user could decrease the batch size and rerun it without restarting the process. while in ORT, the intermediate tensors are kept even OOM, so decrease batch size still fail. this is torch run, we can see after oom failure, torch will release tensor before next step ![image](https://github.com/microsoft/onnxruntime/assets/43435212/92b8a2e3-454b-448a-a223-17cb91d463c2) this is from ort, we can see ort not release its tensors after OOM failure. ![image](https://github.com/microsoft/onnxruntime/assets/43435212/bb6a3882-8e14-4f37-8079-e7f70fc2546b) ort with the PR, we can see memory is released, **the 4GB memory is not own by ort, and will be released by torch at the end**. 
![image](https://github.com/microsoft/onnxruntime/assets/43435212/7f39d711-4e36-47d5-aecf-3805433a6d01) --- onnxruntime/core/framework/execution_frame.cc | 21 +++++++++++++++ onnxruntime/core/framework/execution_frame.h | 2 ++ .../training/ortmodule/_training_manager.py | 26 ++++++++++--------- 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8c08152986cf6..32a5f749af084 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! // Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << "so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684faa..18d210ffd48f7 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map& initializers); Status GetOutputs(gsl::span fetch_mlvalue_idxs, std::vector& fetches); + // if OOM happens, then release all values, so session can run next batch. + void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index cc533e549db92..73c32a2f51e41 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -196,18 +196,20 @@ def backward(ctx, *grad_outputs): # Run and get results backward_outputs = C.OrtValueVector() - self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) - # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not - # affect peak memory usage in a subsequent graph run. 
- del ctx.run_info.state - - # Fast version: all backward_outputs are converted first. - # This version only works if backward_outputs is an OrtValueVector. - transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - - self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) - - return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + try: + self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. + + # Fast version: all backward_outputs are converted first. + # This version only works if backward_outputs is an OrtValueVector. + transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) + + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) + res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + return res + finally: + del ctx.run_info.state return _ORTModuleFunction From 6226c5f62f3d16b9702d5c40993ee9bf1cbd119c Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:08:48 +0800 Subject: [PATCH 112/207] [ROCm] Add SkipGroupNorm for ROCm EP (#19303) Add SkipGroupNorm for ROCm EP. --------- Co-authored-by: Peixuan Zuo --- cmake/onnxruntime_rocm_hipify.cmake | 5 - .../contrib_ops/rocm/diffusion/group_norm.cc | 152 ------------- .../rocm/diffusion/group_norm_ck.cuh | 35 +-- .../diffusion/group_norm_ck_impl/impl.cuh | 10 +- .../diffusion/group_norm_ck_impl/impl_fp16.cu | 8 +- .../diffusion/group_norm_ck_impl/impl_fp32.cu | 8 +- .../rocm/diffusion/group_norm_common.h | 125 +++------- .../rocm/diffusion/group_norm_impl.cu | 47 ++-- .../rocm/diffusion/group_norm_impl.h | 47 ---- .../rocm/diffusion/group_norm_impl_kernel.cuh | 213 ------------------ .../rocm/diffusion/group_norm_triton.cuh | 29 +-- .../rocm/diffusion/group_norm_triton.py | 16 +- .../rocm/diffusion/group_norm_tunable_op.h | 153 +++++++------ .../contrib_ops/rocm/rocm_contrib_kernels.cc | 2 + .../kernel_explorer/kernels/groupnorm_test.py | 136 ++++++++--- .../kernels/rocm/group_norm.cu | 112 +++++---- .../contrib_ops/skip_group_norm_op_test.cc | 14 +- tools/ci_build/amd_hipify.py | 2 + 18 files changed, 382 insertions(+), 732 deletions(-) delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h delete mode 100644 onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d485abe6bb1a6..85a9bf50460d3 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -44,12 +44,7 @@ set(contrib_ops_excluded_files "bert/packed_multihead_attention.cc" "bert/packed_multihead_attention_impl.h" "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" - "diffusion/group_norm_impl.h" - "diffusion/group_norm_impl_kernel.cuh" - "diffusion/group_norm_common_base.h" - "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc deleted file mode 100644 index 
e82e15a304f4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define GROUP_NORM_TYPES float, MLFloat16 - -ONNX_OPERATOR_KERNEL_EX( - GroupNorm, kMSDomain, 1, kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); - -using namespace ONNX_NAMESPACE; - -namespace { -template -struct DispatchGroupNorm { - Status operator()(RocmTuningContext* tuning_ctx, - Stream* stream, - Tensor* output, - const Tensor* input, - const Tensor* gamma, - const Tensor* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_swish_activation) { - typedef typename ToHipType::MappedType HipT; - return LaunchGroupNormKernel( - tuning_ctx, - stream, - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - workspace, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_swish_activation); - } -}; - -} // namespace - -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { - epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(epsilon_ >= 0); - - int64_t num_groups; - ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); - ORT_ENFORCE(num_groups >= 0); - num_groups_ = static_cast(num_groups); - - int64_t activation; - ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish - use_swish_activation_ = (activation == 1); - - channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); -} - -Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - return Status::OK(); -} - -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* gamma = context->Input(1); - const Tensor* beta = context->Input(2); - Tensor* output = context->Output(0, input->Shape()); - - if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "only the channels_last layout is supported"); - } - - const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 4 dimensions, got ", input_dims.size()); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in gamma and input does not match"); - } - - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta 
and input does not match"); - } - - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisible by num_groups"); - } - - if (context->GetUseDeterministicCompute()) { - static std::once_flag log_warning; - std::call_once(log_warning, []() { - LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic."; - }); - } - - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - - utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(GetTuningContext(), context->GetComputeStream(), - output, input, gamma, beta, workspace.get(), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index fb7091592c16e..d0a0d09fcbae3 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -26,13 +26,18 @@ namespace rocm { using onnxruntime::rocm::CKDataTypeAdaptor; -using Swish = ck::tensor_operation::element_wise::Swish; +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 5; constexpr int NumReduceDim = 3; -template +template auto GetCKGroupNormNHWCTypeStringAndOps() { using XDataType = typename CKDataTypeAdaptor::type; using YDataType = typename CKDataTypeAdaptor::type; @@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { using GammaDataType = float; using BetaDataType = float; - using Activation = std::conditional_t; + using Activation = std::conditional_t; - std::vector>>> ret; + std::vector>>> ret; for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + std::string silu_suffix = WithSilu ? 
"_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; auto invoker = impl->MakeInvokerPointer(); - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { - if constexpr (WithSwish) { + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->withSwish, "Swish version only support groupnorm with swish"); + !params->use_silu, "Silu version only support groupnorm with silu"); } else { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->withSwish, "Pass version only support groupnorm without swish"); + params->use_silu, "Pass version only support groupnorm without silu"); } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; std::vector reduce_dims{1, 2, 4}; auto activation = Activation{}; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 19b081881dcec..4cb371fdcf960 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -18,7 +18,7 @@ namespace internal { using F16 = ck::half_t; using F32 = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() { template <> std::vector>> + F16, F32, F32, F16, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Silu, 5, 3>(); template <> std::vector std::vector>> + F32, F32, F32, F32, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Swish, 5, 3>(); + F32, F32, F32, F32, F32, Silu, 5, 3>(); template <> std::vector -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f16_instances{}); + device_normalization_f16_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 9b0ccab17b4c1..ceb53ed442abc 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -11,12 +11,12 @@ namespace rocm { namespace internal { template 
<> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f32_instances{}); + device_normalization_f32_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 008ae20b0561f..7cff640db2f34 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -8,110 +8,47 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::CeilDiv; - -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; - } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; - } - } - } - return maxDivisor; -} - template -struct GroupNormNHWCParams : OpParams { - GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma, - const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish) - : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) { - int32_t maxBlocksPerHW = 1024; - switch (c) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; - } - - hw = h * w; - const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW); - hwPerBlock = CeilDiv(hw, blocksPerHW); - cPerGroup = c / groups; - hwc = hw * c; - invHWC = 1.F / (float)(hw * cPerGroup); - groupsPerBlock = cPerBlock / cPerGroup; - } +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} std::string Signature() const override { - std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? 
"_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; return sig; } - - // The output buffer. Layout NHWC. - T* dst; - // The input buffer. Layout NHWC. - T const* src; - // The gamma scaling factor. - float const* gamma; - // The beta term to add in GN. - float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - float epsilon; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. - int32_t h; - int32_t w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index dbd5009e63676..142aaf14e8d2d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -15,9 +15,12 @@ namespace rocm { template Status LaunchGroupNormKernel( RocmTuningContext* tuning_ctx, - Stream* stream, + Stream* ort_stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -27,19 +30,26 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. 
Got", num_groups); + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, - batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); if (tuning_ctx->IsTunableOpEnabled()) { static GroupNormNHWCTunableOp op; @@ -50,14 +60,17 @@ Status LaunchGroupNormKernel( } template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h deleted file mode 100644 index a0f7e0aca5def..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::RocmTuningContext; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; -} - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh deleted file mode 100644 index d6322a12a9363..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCm kernel is modified from TensorRT 8.5. -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sumSq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? 
b.sumSq : (a.sumSq + b.sumSq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = static_cast(input_v.val[i]); - sum += val; - sumSq += val * val; - } -} - -template -__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw, - int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) { - // The object in charge of doing the sums for the different blocks. - typedef hipcub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[ThreadsPerBlock]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (ILP channels per thread). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - if (ci < c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * hwc + static_cast(hwi) * c + ci; - UpdateSum(src, offset, sum, sumSq); - } - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * ILP / cPerGroup; - int32_t cj = threadIdx.x * ILP - cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == cPerGroup - ILP) { // ILP channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. - int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= groupsPerBlock || gj >= groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. 
- atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x); - atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y); -} - -template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev, - const U* gamma_v, const U* beta_v, bool swish) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - VecT output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - U val = static_cast(input_v.val[i]); - val = (val - mean) * invStdDev; - val = gamma_v[i] * val + beta_v[i]; - - if (swish) { - val = val * sigmoid(val); - } - output_v.val[i] = static_cast(val); - } - *(reinterpret_cast(dst + offset)) = output_v; -} - -template -__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock, - int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) { - // The channel loaded by that thread (ILP channels per thread for F16x2). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - if (ci >= c) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < groups) { - sum = redBuffer[(2 * ni + 0) * groups + gi]; - sumSq = redBuffer[(2 * ni + 1) * groups + gi]; - } - - using VecF = onnxruntime::rocm::aligned_vector; - - const VecF gamma_v = *reinterpret_cast(gamma + ci); - const VecF beta_v = *reinterpret_cast(beta + ci); - - // Compute the mean. - float mean = sum * invHWC; - // Compute the variance. - float var = sumSq * invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * hwc + hwi * c + ci; - - // Fetch ILP channels per thread. - computeGroupNorm(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b7b9441ac997d..b3d3e92209b39 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -20,21 +20,21 @@ namespace rocm { namespace { -template +template std::string GetGroupNormTritonGroupName() { std::string ret = "GroupNormTriton_"; - std::string swish_suffix = WithSwish ? "Swish_" : "Pass_"; - ret += swish_suffix; + std::string silu_suffix = WithSilu ? 
"Silu_" : "Pass_"; + ret += silu_suffix; ret += GetDataTypeName(); return ret; } } // namespace -template +template auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); auto* kernel_list = GetOrtTritonKernelByGroup(group_name); if (kernel_list == nullptr) { return ret; @@ -45,16 +45,19 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto* metadata = GetOrtTritonKernelMetadata(i); auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCParams* params) -> Status { + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSwish) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); } // Construct args for launch kernel struct { @@ -73,7 +76,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { (const void*)params->beta, params->hw, params->c, - params->cPerGroup, + params->channels_per_group, params->epsilon}; // Grid dim is (batch_count, groups, 1) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 56b3a030b289e..5368cb1cf635b 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -21,7 +21,7 @@ def group_norm_kernel( eps, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, - ACTIVATION_SWISH: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, ): row_x = tl.program_id(0) row_y = tl.program_id(1) @@ -62,7 +62,7 @@ def group_norm_kernel( x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta - if ACTIVATION_SWISH: + if ACTIVATION_SILU: y *= tl.sigmoid(y) tl.store(y_ptr + offsets, y, mask=mask) @@ -71,7 +71,7 @@ def group_norm_kernel( # blocks = [16, 32, 64, 128, 256, 512] # hw_sizes = [8, 16, 32, 64, 128, 256, 512] # but this will result in too many functions and slow down the compilation. 
-with_swish = [True, False] +with_silu = [True, False] dtypes = ["fp32", "fp16"] blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] @@ -84,14 +84,14 @@ def group_norm_kernel( def get_function_table(): func_table = [] - for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks): - swish_suffix = "Swish" if swish else "Pass" - name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(swish_suffix, dtype) + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) sig = sig_pattern.format(dtype, dtype) kwargs = { "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)}, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, } func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} func_table.append(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index 25d820f7ed326..e6831f764b418 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -20,115 +20,117 @@ namespace rocm { using onnxruntime::rocm::GPU_WARP_SIZE; template -void groupNormNHWCSum(const GroupNormNHWCParams* params) { - // Make sure the values are as we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - groupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->src, params->redBuffer, params->cPerBlock, \ - params->hwPerBlock, params->hw, params->hwc, params->c, \ - params->cPerGroup, params->groups, params->groupsPerBlock); \ +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SUM(256, 2) - case 480: - LAUNCH_GROUPNORM_SUM(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SUM(128, 2) + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SUM(64, 2) + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCSumOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCSumKernel + GroupNormNHWCSumKernel <<StreamHandle()>>>( - params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock, - params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock); + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); return HIP_CALL(hipGetLastError()); } template -void groupNormNHWCScale(const GroupNormNHWCParams* params) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - groupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->gamma, params->beta, \ - params->redBuffer, params->epsilon, params->c, params->cPerBlock, \ - params->cPerGroup, params->groups, params->hwc, params->invHWC, \ - params->hw, params->hwPerBlock, params->withSwish); \ +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SCALE(256, 2) - case 480: - LAUNCH_GROUPNORM_SCALE(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
+ switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SCALE(128, 2) + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SCALE(64, 2) + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCScaleOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCScaleKernel + GroupNormNHWCScaleKernel <<StreamHandle()>>>( - params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock, - params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish); + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); return HIP_CALL(hipGetLastError()); } template class GroupNormNHWCOp { public: - Status operator()(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); auto status = GroupNormNHWCSumOp(params); ORT_RETURN_IF_ERROR(status); HIP_RETURN_IF_ERROR(hipGetLastError()); @@ -138,29 +140,30 @@ class GroupNormNHWCOp { return Status::OK(); } - Status IsSupported(const GroupNormNHWCParams* params) { + Status IsSupported(const GroupNormNHWCTunableParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), - "The value of attributes don't meet the requirements."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); + VecSize, ") is redundant for the number of channels per group: ", + 
params->channels_per_block); return Status::OK(); } }; template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); - groupNormNHWCSum(params); +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); HIP_RETURN_IF_ERROR(hipGetLastError()); - groupNormNHWCScale(params); + GroupNormNHWCScale(params); HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); } @@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template -class GroupNormNHWCTunableOp : public TunableOp> { +class GroupNormNHWCTunableOp : public TunableOp> { public: GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 55cd6a1d112f5..382a3951f3a83 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index e32cb032798fc..8334d20e47c86 100644 --- 
a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -35,7 +35,11 @@ def sigmoid_function(x): return 1.0 / (1.0 + np.exp(-x)) -def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): +def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip): + add_output = None + if has_skip: + input_x = input_x + skip_x + bias_x + add_output = input_x n, h, w, c = input_x.shape input_x = input_x.transpose([0, 3, 1, 2]) assert c % num_groups == 0 @@ -45,46 +49,70 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): x = x.transpose([0, 2, 3, 1]) x = x * gamma + beta - if with_swish: + if with_silu: x = x * sigmoid_function(x) - return x + return x, add_output -def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func): +def run_group_norm( + batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func +): np.random.seed(0) width = height input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) # the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18 - workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32) + workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32) epsilon = 1e-05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish - host_x = input_x.astype(dtype) - input_d = ke.DeviceArray(host_x) + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(np.float32) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization + + input_d = ke.DeviceArray(input_x.astype(dtype)) + skip_d = ke.DeviceArray(skip_x.astype(dtype)) + bias_d = ke.DeviceArray(bias_x.astype(dtype)) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) - y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype) + y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip) + y_ref = y_ref.astype(dtype) for impl in my_op.ListOps(): if not my_op.SelectOp(impl): @@ -95,6 +123,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: y_d.UpdateHostNumpyArray() np.testing.assert_allclose(y_ref, output_y, atol=1e-02) + if has_skip: + y_add_d_ref = y_add_d_ref.astype(dtype) + y_add_d.UpdateHostNumpyArray() + np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02) dtypes = ["float32", "float16"] @@ 
-102,19 +134,21 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm(sd_sizes, dtype, swish): +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [True, False]) +def test_group_norm(sd_sizes, dtype, silu, has_skip): for func in dtype_to_funcs(dtype): - run_group_norm(*sd_sizes, dtype, swish, func) + run_group_norm(*sd_sizes, dtype, silu, has_skip, func) @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm_ck(sd_sizes, dtype, swish): - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - run_group_norm(*sd_sizes, dtype, swish, ck_f_name) +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [False]) +def test_group_norm_ck(sd_sizes, dtype, silu, has_skip): + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name) @dataclass @@ -136,37 +170,67 @@ def report(self): def profile_group_norm_func( - batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func + batch_size: int, + height: int, + width: int, + num_channels: int, + num_groups: int, + dtype: str, + silu: bool, + has_skip: bool, + func, ): np.random.seed(0) input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) - workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32) + workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32) epsilon = 0.05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish + + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x) + skip_d = ke.DeviceArray(skip_x) + bias_d = ke.DeviceArray(bias_x) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) for impl in my_op.ListOps(): duration_ms = -1 @@ -181,14 +245,14 @@ def profile_group_norm_func( ) -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True): +def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True): with ke.benchmark(sort): for func in dtype_to_funcs(dtype): 
- profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func) # ck function - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name) + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name) sd_profile_sizes = [ @@ -227,7 +291,8 @@ def profile(): group.add_argument("num_channels", type=int) group.add_argument("num_groups", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--swish", action="store_true") + group.add_argument("--silu", action="store_true") + group.add_argument("--has_skip", action="store_true") group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: @@ -241,6 +306,7 @@ def profile(): args.num_channels, args.num_groups, args.dtype, - args.swish, + args.silu, + args.has_skip, args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu index 0bd47b2c0387e..6af163ab94b10 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu @@ -12,17 +12,21 @@ #include "python/tools/kernel_explorer/kernel_explorer_interface.h" namespace py = pybind11; - +using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes; namespace onnxruntime { template class GroupNormNHWC : public IKernelExplorer { public: - GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias, + DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, + bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize); } @@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCOp op_{}; std::string type_string_{}; @@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer { template class GroupNormNHWCStaticSelection : public 
IKernelExplorer { public: - GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWCStaticSelection"; } @@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; std::string type_string_{}; }; @@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { template class GroupNormNHWCTunable : public IKernelExplorer { public: - GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { params_.TuningContext()->EnableTunableOpAndTuning(); } @@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCTunableOp op_{}; }; #ifdef USE_COMPOSABLE_KERNEL -template +template class CKGroupNormNHWC : public IKernelExplorer { public: - CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool 
use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { + CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer { #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL -template +template class GroupNormNHWCTriton : public IKernelExplorer { public: - GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { + GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { name_strings_.emplace_back(name); ops_.emplace_back(std::move(op)); } @@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP(name, type, threads_per_block, vec_size) 
\ py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &name::SetRepeats) \ .def("Profile", &name::Profile) \ .def("Run", &name::Run) \ @@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_COMMON(name, type, ...) \ py::class_>(m, name) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ .def("Profile", &type<__VA_ARGS__>::Profile) \ .def("Run", &type<__VA_ARGS__>::Run) \ @@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP_TYPED(name, type) \ REGISTER_COMMON(#name "_" #type, name, type) -#define REGISTER_CK(type, with_swish, swish_suffix) \ - REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish) +#define REGISTER_CK(type, with_silu, silu_suffix) \ + REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu) -#define REGISTER_TRITON(type, with_swish, swish_suffix) \ - REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish) +#define REGISTER_TRITON(type, with_silu, silu_suffix) \ + REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu) KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half); @@ -248,16 +270,16 @@ KE_REGISTER(m) { #ifdef USE_COMPOSABLE_KERNEL REGISTER_CK(half, false, "Pass"); - REGISTER_CK(half, true, "Swish"); + REGISTER_CK(half, true, "Silu"); REGISTER_CK(float, false, "Pass"); - REGISTER_CK(float, true, "Swish"); + REGISTER_CK(float, true, "Silu"); #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL REGISTER_TRITON(half, false, "Pass"); - REGISTER_TRITON(half, true, "Swish"); + REGISTER_TRITON(half, true, "Silu"); REGISTER_TRITON(float, false, "Pass"); - REGISTER_TRITON(float, true, "Swish"); + REGISTER_TRITON(float, true, "Silu"); #endif } diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc index fefd5722054de..ea8537f243f5d 100644 --- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array channels_last_values = {-1, 1}; for (const int channels_last : channels_last_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; @@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array has_add_out_values = {true, false}; std::array skip_dims = {2, 4}; @@ -237,12 +243,16 @@ 
TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { constexpr int channels_last = 1; for (const int skip_dim : skip_dims) { for (const bool has_add_out : has_add_out_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index e286236ba6447..f1d3702e3245e 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -181,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("rocm_device_prop_", "cuda_device_prop_") s = s.replace("rocm_device_arch_", "cuda_device_arch_") + s = s.replace("HipTuningContext", "RocmTuningContext") + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names # And we do this last, undoing or fixing hipify mistakes. if "fft" in src_file_path: From 124bde985ae883566c44f5cd84d351612006100c Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Tue, 20 Feb 2024 19:20:42 -0800 Subject: [PATCH 113/207] Bring QAT POC back to a functional state (#19290) --- .../test/python/qat_poc_example/README.md | 2 +- .../test/python/qat_poc_example/model.py | 56 +++++++------------ .../test/python/qat_poc_example/qat.py | 2 +- .../test/python/qat_poc_example/train.py | 18 ++---- 4 files changed, 27 insertions(+), 51 deletions(-) diff --git a/orttraining/orttraining/test/python/qat_poc_example/README.md b/orttraining/orttraining/test/python/qat_poc_example/README.md index 6840e98bd9c86..05072b410b730 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/README.md +++ b/orttraining/orttraining/test/python/qat_poc_example/README.md @@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC. -> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True` +> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True` > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. 
So, we disable quantizing the bias term using the flag QuantizeBias=False` diff --git a/orttraining/orttraining/test/python/qat_poc_example/model.py b/orttraining/orttraining/test/python/qat_poc_example/model.py index 91d7ccd7294f5..601362a59e379 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/model.py +++ b/orttraining/orttraining/test/python/qat_poc_example/model.py @@ -5,7 +5,7 @@ import onnx import torch -import onnxruntime.training.onnxblock as onnxblock +from onnxruntime.training import artifacts class MNIST(torch.nn.Module): @@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix): 4. The checkpoint file """ - class MNISTWithLoss(onnxblock.TrainingModel): - def __init__(self): - super().__init__() - self.loss = onnxblock.loss.CrossEntropyLoss() - - def build(self, output_name): - return self.loss(output_name) - - mnist_with_loss = MNISTWithLoss() - onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None - - # Build the training and eval graphs - logging.info("Using onnxblock to create the training artifacts.") - with onnxblock.onnx_model(onnx_model) as model_accessor: - _ = mnist_with_loss(onnx_model.graph.output[0].name) - eval_model = model_accessor.eval_model - - # Build the optimizer graph - optimizer = onnxblock.optim.AdamW() - with onnxblock.onnx_model() as accessor: - _ = optimizer(mnist_with_loss.parameters()) - optimizer_model = accessor.model + onnx_model = onnx.load(model_path) + + requires_grad = [ + param.name + for param in onnx_model.graph.initializer + if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point")) + ] + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=artifacts_dir, + prefix=model_prefix, + ) # Create the training artifacts - train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx") - logging.info(f"Saving the training model to {train_model_path}.") - onnx.save(onnx_model, train_model_path) - eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx") - logging.info(f"Saving the eval model to {eval_model_path}.") - onnx.save(eval_model, eval_model_path) - optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx") - logging.info(f"Saving the optimizer model to {optimizer_model_path}.") - onnx.save(optimizer_model, optimizer_model_path) - trainable_params, non_trainable_params = mnist_with_loss.parameters() - checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt") - logging.info(f"Saving the checkpoint to {checkpoint_path}.") - onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path) + train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx") + eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx") + optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx") + checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint") return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path diff --git a/orttraining/orttraining/test/python/qat_poc_example/qat.py b/orttraining/orttraining/test/python/qat_poc_example/qat.py index 51a15475ee911..dcc9e116fda7d 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/qat.py +++ b/orttraining/orttraining/test/python/qat_poc_example/qat.py @@ -46,7 +46,7 @@ ) 
logging.info("Preparing the training artifacts for QAT.") - training_model_name = "mnist_qat" + training_model_name = "mnist_qat_" artifacts_dir = os.path.join(model_dir, "training_artifacts") utils.makedir(artifacts_dir) training_artifacts = create_training_artifacts( diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py index 9a429d2adc6f1..a25c071c58a48 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/train.py +++ b/orttraining/orttraining/test/python/qat_poc_example/train.py @@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader): model.train() cumulative_loss = 0 for data, target in train_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - train_loss = model(forward_inputs) + train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) optimizer.step() model.lazy_reset_grad() - cumulative_loss += train_loss[0] + cumulative_loss += train_loss return cumulative_loss / len(train_loader) @@ -43,12 +39,8 @@ def _eval(model, test_loader): model.eval() cumulative_loss = 0 for data, target in test_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - test_loss = model(forward_inputs) - cumulative_loss += test_loss[0] + test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) + cumulative_loss += test_loss return cumulative_loss / len(test_loader) @@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp train_loader, test_loader = _get_dataloaders("data", batch_size) # Load the checkpoint state. - state = orttraining.CheckpointState(qat_checkpoint) + state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint) # Create the training module. model = orttraining.Module(qat_train_model, state, qat_eval_model) From 8092a89688f92dee83d1d0111acaa1e1d2dfdb85 Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Tue, 20 Feb 2024 21:18:54 -0800 Subject: [PATCH 114/207] Changed command line argpasrse to process '--symmetric [True|False]'. (#19577) ### Description Accept the command line option --symmetric and its optional value correctly. If the optional value matches uncased to 'True' then set symmetric to True else set symmetric to False. Asymmetric quantization will generate zero_point input. ``` usage: matmul_4bits_quantizer.py [-h] --input_model INPUT_MODEL --output_model OUTPUT_MODEL [--block_size BLOCK_SIZE] [--symmetric [{True,False}]] [--accuracy_level ACCURACY_LEVEL] [-v] [--nodes_to_exclude NODES_TO_EXCLUDE [NODES_TO_EXCLUDE ...]] ``` ### Motivation and Context --- .../python/tools/quantization/matmul_4bits_quantizer.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 3e9f9a6544a71..eb7bbec997d59 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -349,6 +349,10 @@ def process(self): self.int4_quant_algo() +def ort_convert_str_to_bool(value): + return value.lower() in ("true", "1") + + def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise int4 quantization for MatMul 2D weight matrices. 
@@ -366,7 +370,10 @@ def parse_args(): "--symmetric", required=False, default=True, - type=bool, + const=True, + nargs="?", + type=ort_convert_str_to_bool, + choices=[True, False], help="Indicate whether to quantize the model symmetrically", ) parser.add_argument( From 58f4921686bf0a5b0442fb6df92d1b1972a118cc Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 21 Feb 2024 00:31:06 -0800 Subject: [PATCH 115/207] [js] changes to allow Float16Array if any polyfill is available (#19305) ### Description This change adds only necessary code to enable ort-web works with any Float16Array polyfill. Unlike #19302, in this PR, ort-web does not include any specific polyfill; instead, it's user's choice for how to use a polyfill. ORT-web uses Float16Array if it's available; otherwise, fallback to use Uint16Array. ```js // case 1: user does not use polyfill: import * as ort from 'onnxruntime-web'; const myF16Data = new Uint16Array(...); // need to use Uint16Array const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` ```js // case 2: user use polyfill: import * as ort from 'onnxruntime-web'; import { Float16Array, isFloat16Array, isTypedArray, getFloat16, setFloat16, f16round, } from "@petamoriken/float16"; globalThis.Float16Array = Float16Array; // ort-web will pick the global Float16Array const myF16Data = new Float16Array(...); // Use the polyfilled Float16Array type const myF16tensor = new ort.Tensor('float16', myF16Data, dims); ``` --- js/common/lib/tensor-impl-type-mapping.ts | 34 +++++++++++++++-------- js/common/lib/tensor-impl.ts | 10 ++++--- js/web/lib/wasm/wasm-common.ts | 9 +++++- 3 files changed, 37 insertions(+), 16 deletions(-) diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea1..b29cb8cbd6d35 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. 
+let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index de18126a9d0ae..56682ef98e117 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. 
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 93910af1f1bf0..54eaf5e0c43cc 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': From 57d6819212464f49b30db047528be0f409dadc67 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Thu, 22 Feb 2024 00:08:47 +0800 Subject: [PATCH 116/207] [js/web] Fix fused-conv is not included in npm test (#19581) BUG: https://github.com/microsoft/onnxruntime/issues/18855 ### Description ### Motivation and Context --- js/web/test/suite-test-list.jsonc | 1 + 1 file changed, 1 insertion(+) diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 1c61518ddcdd2..b43b1ac37e37d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1354,6 +1354,7 @@ "expand.jsonc", "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", From e5ce81ae847d0b347a3dfe95abfc9e407e2f0469 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Wed, 21 Feb 2024 15:24:41 -0500 Subject: [PATCH 117/207] [java] Adding ML program flag for CoreML (#19551) ### Description Adds the new CoreML enum flags to enable ML Program support in Java. ### Motivation and Context Adds support for #19347 to the Java API. --- .../ai/onnxruntime/providers/CoreMLFlags.java | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f3..cec3fadf446ca 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs. */ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. 
+ */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; From 3afb38cfb7d4263f262dea33bcfa16d35c67fede Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 Feb 2024 12:46:16 -0800 Subject: [PATCH 118/207] [CUDA] Add use_tf32 cuda provider option (for FP32 Conv) (#19426) Follow up of https://github.com/microsoft/onnxruntime/pull/19357 to apply the use_tf32 option on fp32 cuDNN convolution. When use_tf32 = 0, we will disable TF32 in cuDNN convolution for FP32 inputs. https://docs.nvidia.com/deeplearning/cudnn/api/cudnn-graph-library.html#cudnnmathtype-t **CUDNN_FMA_MATH** - Restricted to only kernels that use FMA instructions. - On pre-NVIDIA A100 GPU devices, CUDNN_DEFAULT_MATH and CUDNN_FMA_MATH have the same behavior: Tensor Core kernels will not be selected. - With NVIDIA Ampere architecture and CUDA toolkit 11, CUDNN_DEFAULT_MATH permits TF32 Tensor Core operation and CUDNN_FMA_MATH does not. - The TF32 behavior for CUDNN_DEFAULT_MATH and the other Tensor Core math types can be explicitly disabled by the environment variable NVIDIA_TF32_OVERRIDE=0. --- onnxruntime/core/providers/cuda/nn/conv.cc | 17 ++++++++++++++--- onnxruntime/core/providers/cuda/nn/conv.h | 3 ++- .../core/providers/cuda/nn/conv_transpose.cc | 10 ++++++++-- .../training_ops/cuda/nn/conv_grad.cc | 3 ++- .../training_ops/cuda/nn/conv_shared.cc | 6 ++++-- .../training_ops/cuda/nn/conv_shared.h | 2 +- .../training_ops/cuda/nn/conv_transpose_grad.cc | 6 ++++-- 7 files changed, 35 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 82f3503919237..a417be5a86c32 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -326,7 +326,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -351,8 +352,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; @@ -399,6 +405,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); if (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; } else { perf.mathType = CUDNN_DEFAULT_MATH; } @@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set( const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t 
data_type) { + cudnnDataType_t data_type, + bool use_tf32) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); @@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set( CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); if (data_type == CUDNN_DATA_HALF) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index bcaa4d855b81e..181fbc99fd8e9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final { const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type); + cudnnDataType_t data_type, + bool use_tf32); operator cudnnConvolutionDescriptor_t() const { return desc_; } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 55dceaa2698e8..939b9959af818 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -167,7 +167,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (has_bias) { const auto& b_shape = p.B->Shape(); @@ -187,8 +188,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f6c58445c0a5d..fc5d9b65d0f89 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -114,7 +114,8 @@ Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args_.params.data_type)); + args_.params.data_type, + UseTF32())); if (dB) { const TensorShape& db_shape = dB->Shape(); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index 5dc16c68f6210..d23905496c9bb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const } template -Status 
AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32) { perf_results.resize(1); perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; if (args.params.data_type == CUDNN_DATA_HALF) { perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) { + perf_results[0].mathType = CUDNN_FMA_MATH; } else { perf_results[0].mathType = CUDNN_DEFAULT_MATH; } @@ -256,7 +258,7 @@ Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const std::vector perf_results; ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) + ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32()) : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); for (auto& algo_perf : perf_results) { if (f(algo_perf) == Status::OK()) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h index a2d4bf3bdc006..3fdb4306bfbbb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -75,7 +75,7 @@ class AlgoIterator { Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f); - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32); private: const ConvArgs& args_; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index 5f7206fc121ec..d3f5a89434a48 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -182,7 +182,8 @@ Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tenso ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); } return Status::OK(); @@ -287,7 +288,8 @@ Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, cons ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); if (dB) { const auto& b_shape = dB->Shape(); From ebd220b0730f9898aaa0275ef0d8195ce70057d0 Mon Sep 17 00:00:00 2001 From: Matttttt <18152455+martholomew@users.noreply.github.com> Date: Wed, 21 Feb 2024 21:38:18 +0000 Subject: [PATCH 119/207] Misspelling in README.md (#19433) Fixed a misspelling. 
--- js/web/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/web/README.md b/js/web/README.md index c75a40ad6da28..906c78a1b7ec4 100644 --- a/js/web/README.md +++ b/js/web/README.md @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience. -ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. +ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend. See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports. @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun ## Documents -### Developement +### Development Refer to the following links for development information: From 38c34323939bac03b9648b2e59dbbe8de0bd7092 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 21 Feb 2024 13:58:53 -0800 Subject: [PATCH 120/207] Bump ip from 1.1.8 to 1.1.9 in /js/react_native (#19582) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 4dca90d7415cf..bbb0c4f3d1e22 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -3701,9 +3701,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-absolute@^1.0.0: version "1.0.0" From 5197db19802a39e47d19ac829cd08a94bacbdfbb Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Wed, 21 Feb 2024 15:45:44 -0800 Subject: [PATCH 121/207] Diable __cpuid call for ARM64EC (#19592) Diable __cpuid call for ARM64EC Co-authored-by: Sheil Kumar --- winml/lib/Api/HardwareCoreEnumerator.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index b6b44690f4f6c..d04e276347170 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -84,7 +84,7 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); -#if !defined(_M_ARM64) && !defined(__aarch64__) +#if !defined(_M_ARM64EC) && !defined(_M_ARM64) && !defined(__aarch64__) const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" int regs_leaf0[4]; int regs_leaf7[4]; From 3d88487c96bf467c4b83dff179c9e282602e2d64 Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 22 Feb 2024 10:35:26 +0800 Subject: [PATCH 122/207] Minor Triton Fix (#19589) Including removing a unnecessary assert, and add support of passing string attribute from ONNX node attribute to python functoin kwargs (mainly for passing debug info from graph to python for now). 
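The `_utils.py` hunk below removes the 32-bit-only assert from `next_power_of_2` and adds a final `n |= n >> 32`, extending the bit-smearing trick to 64-bit sizes. A minimal standalone Python sketch of the same idea (not part of the patch itself), with a loop standing in for the unrolled shifts:

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (valid for 1 <= n < 2**64)."""
    n -= 1
    # Smear the highest set bit into every lower bit, then add 1.
    for shift in (1, 2, 4, 8, 16, 32):
        n |= n >> shift
    return n + 1

assert next_power_of_2(1) == 1
assert next_power_of_2(1025) == 2048
assert next_power_of_2(2**33 + 1) == 2**34  # needs the added >> 32 step
```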
--- .../orttraining/core/framework/triton/triton_op_executor.cc | 2 ++ orttraining/orttraining/python/training/ort_triton/_utils.py | 3 ++- orttraining/orttraining/training_ops/cpu/triton/triton_op.h | 5 ++++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc index 092ab89d5d760..f30d6ddee253a 100644 --- a/orttraining/orttraining/core/framework/triton/triton_op_executor.cc +++ b/orttraining/orttraining/core/framework/triton/triton_op_executor.cc @@ -106,6 +106,8 @@ void TritonOpExecutor::ExecuteByFuncName(const std::string& func_name, const Inl PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyLong_FromLongLong(std::stoll(kv.second.first))); } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyFloat_FromDouble(std::stod(kv.second.first))); + } else if (kv.second.second == ONNX_NAMESPACE::TensorProto_DataType_STRING) { + PyDict_SetItemString(python_kwargs.get(), kv.first.c_str(), PyUnicode_FromString(kv.second.first.c_str())); } else { ORT_THROW("Unsupported kwargs data type: ", kv.second.second); } diff --git a/orttraining/orttraining/python/training/ort_triton/_utils.py b/orttraining/orttraining/python/training/ort_triton/_utils.py index 95e6703be8783..877eacc0b775f 100644 --- a/orttraining/orttraining/python/training/ort_triton/_utils.py +++ b/orttraining/orttraining/python/training/ort_triton/_utils.py @@ -141,13 +141,14 @@ def get_reduce_info(node: NodeProto, graph: GraphProto, input_rank: int) -> Tupl def next_power_of_2(n: int) -> int: - assert n <= 2**32, "32-bit only" + """Return the smallest power of 2 greater than or equal to n""" n -= 1 n |= n >> 1 n |= n >> 2 n |= n >> 4 n |= n >> 8 n |= n >> 16 + n |= n >> 32 n += 1 return n diff --git a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h index f226db76f7ed7..db8e8558ab884 100644 --- a/orttraining/orttraining/training_ops/cpu/triton/triton_op.h +++ b/orttraining/orttraining/training_ops/cpu/triton/triton_op.h @@ -25,12 +25,15 @@ class TritonOp final : public OpKernel { attr.first == "onnx_string") { continue; } - // Support int64 and float only for now, skip other types. + // Support int64, float and string only for now, skip other types. 
if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_INT) { kwargs_.insert({attr.first, {std::to_string(attr.second.i()), ONNX_NAMESPACE::TensorProto_DataType_INT64}}); } else if (attr.second.type() == ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_FLOAT) { kwargs_.insert({attr.first, {std::to_string(attr.second.f()), ONNX_NAMESPACE::TensorProto_DataType_FLOAT}}); + } else if (attr.second.type() == + ONNX_NAMESPACE::AttributeProto::AttributeType::AttributeProto_AttributeType_STRING) { + kwargs_.insert({attr.first, {attr.second.s(), ONNX_NAMESPACE::TensorProto_DataType_STRING}}); } } } From 8354329086ebb190db9ea0cb6a3fa72f53f8f881 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:45 +0800 Subject: [PATCH 123/207] [ROCm] SkipGroupNorm triton (#19408) Change GroupNorm triton to support SkipGroupNorm --- .../rocm/diffusion/group_norm_triton.cuh | 23 ++++++++--- .../rocm/diffusion/group_norm_triton.py | 39 +++++++++++++++++-- .../kernel_explorer/kernels/groupnorm_test.py | 12 ++++++ 3 files changed, 64 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b3d3e92209b39..c6ca16bfdfc80 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -46,8 +46,6 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), - "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", @@ -61,23 +59,36 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { } // Construct args for launch kernel struct { - void* X; - void* Y; + const void* src; + const void* skip; + const void* bias; + void* out; + void* add_out; const void* gamma; const void* beta; int hw; int c; int c_per_group; float eps; + bool has_skip; + bool has_bias; + bool broadcast_skip; } args = { - (void*)params->src, + (const void*)params->src, + (const void*)params->skip, + (const void*)params->bias, (void*)params->dst, + (void*)params->skip_workspace, (const void*)params->gamma, (const void*)params->beta, params->hw, params->c, params->channels_per_group, - params->epsilon}; + params->epsilon, + params->skip != nullptr, + params->bias != nullptr, + params->broadcast_skip, + }; // Grid dim is (batch_count, groups, 1) return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args)); diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 5368cb1cf635b..5ba96ebc117f0 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -12,13 +12,19 @@ @triton.jit def group_norm_kernel( input_ptr, + skip_ptr, + bias_ptr, output_ptr, + add_out_ptr, gamma_ptr, beta_ptr, img_size, c, c_per_group, eps, + has_skip, + 
has_bias, + broadcast_skip, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, ACTIVATION_SILU: tl.constexpr, @@ -36,14 +42,35 @@ def group_norm_kernel( offsets = hw[:, None] * c + cols[None, :] mask = (cols < c_per_group)[None, :] + bias = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + if has_skip: + add_out_ptr += row_x * stride + row_y * c_per_group + if broadcast_skip: + broadcast_skip_ptr = skip_ptr + row_x * c + row_y * c_per_group + bias += tl.load(broadcast_skip_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + else: + skip_ptr += row_x * stride + row_y * c_per_group + if has_bias: + bias_ptr += row_y * c_per_group + bias += tl.load(bias_ptr + cols, mask=cols < c_per_group, other=0.0).to(tl.float32) + # Calculate mean and variance _sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) _square_sum = tl.zeros([HW_SIZE, BLOCK_SIZE], dtype=tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): x_ptr = input_ptr + i * HW_SIZE * c a = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip and not broadcast_skip: + s_ptr = skip_ptr + i * HW_SIZE * c + s = tl.load(s_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + a += s + if has_bias or broadcast_skip: + a += bias _sum += a _square_sum += a * a + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + tl.store(add_y_ptr + offsets, a, mask=mask) # Set axis=None (or leave it unspecified) to reduce all axes. # TODO: In older Triton we have to reduce an axis at a time, but in our case @@ -57,9 +84,13 @@ def group_norm_kernel( gamma = tl.load(gamma_ptr + cols, mask=cols < c_per_group).to(tl.float32) beta = tl.load(beta_ptr + cols, mask=cols < c_per_group).to(tl.float32) for i in range(tl.cdiv(img_size, HW_SIZE)): - x_ptr = input_ptr + i * HW_SIZE * c y_ptr = output_ptr + i * HW_SIZE * c - x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + if has_skip: + add_y_ptr = add_out_ptr + i * HW_SIZE * c + x = tl.load(add_y_ptr + offsets, mask=mask, other=0.0).to(tl.float32) + else: + x_ptr = input_ptr + i * HW_SIZE * c + x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta if ACTIVATION_SILU: @@ -77,7 +108,7 @@ def group_norm_kernel( hw_sizes = [8, 16, 32, 64, 128, 256] warps = [1, 2, 4, 8, 16] name_pattern = "GroupNormTriton_{}_{}_b{}_hw{}_w{}" -sig_pattern = "*{},*{},*fp32,*fp32,i32,i32,i32,fp32" +sig_pattern = "*{},*{},*{},*{},*{},*fp32,*fp32,i32,i32,i32,fp32,i1,i1,i1" group_pattern = "GroupNormTriton_{}_{}" @@ -88,7 +119,7 @@ def get_function_table(): silu_suffix = "Silu" if silu else "Pass" name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) group = group_pattern.format(silu_suffix, dtype) - sig = sig_pattern.format(dtype, dtype) + sig = sig_pattern.format(dtype, dtype, dtype, dtype, dtype) kwargs = { "num_warps": warp, "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index 8334d20e47c86..400a9d8a7a187 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -80,6 +80,18 @@ def run_group_norm( ) use_silu = silu broadcast_skip = False + if has_skip: + skip_x_shape = skip_x.shape + b2 = len(skip_x_shape) == 2 and skip_x_shape[0] == batch_size and skip_x_shape[1] == num_channels + b4 = ( + len(skip_x_shape) == 4 + and skip_x_shape[0] == 
batch_size + and skip_x_shape[1] == 1 + and skip_x_shape[2] == 1 + and skip_x_shape[3] == num_channels + ) + if b2 or b4: + broadcast_skip = True channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x.astype(dtype)) From 05ed89f46980b7e5a5328bc20af8b32ca9f1f715 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:34:55 +0800 Subject: [PATCH 124/207] [ROCm] Add excluded libs for ROCm python package (#19586) The rocm lib version has changed in rocm 6.0 Using libs packaged in whl might cause errors. For example, `libamdhip64.so.6` packaged in whl will cause compute error when training gpt2 model. The root cause still in investigating. --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 03e1cb75ba581..9a5fc29dd5e02 100644 --- a/setup.py +++ b/setup.py @@ -205,18 +205,23 @@ def run(self): rocm_dependencies = [ "libamd_comgr.so.2", "libamdhip64.so.5", + "libamdhip64.so.6", "libdrm.so.2", "libdrm_amdgpu.so.1", "libelf.so.1", "libhipfft.so.0", "libhiprtc.so.5", + "libhiprtc.so.6", "libhsa-runtime64.so.1", "libMIOpen.so.1", "libnuma.so.1", "librccl.so.1", "librocblas.so.3", + "librocblas.so.4", "librocfft.so.0", + "libroctx64.so.4", "librocm_smi64.so.5", + "librocm_smi64.so.6", "libroctracer64.so.4", "libtinfo.so.6", "libmigraphx_c.so.3", From 6b73ab3e3e72a9f2008e8d0e221b0be77d2993b1 Mon Sep 17 00:00:00 2001 From: cao lei Date: Thu, 22 Feb 2024 10:19:08 -0800 Subject: [PATCH 125/207] Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream (#19515) ### Description Introduce reused_buffer_index_per_stream in allocation planner which will be reset after computing the reuse buffer for each stream. So if a NodeArg is an input of several Ops across different streams and reuses other NodeArg, the reused NodeArg won't be involved when computing the second stream's reuse plan. ### Motivation and Context This is to fix https://github.com/microsoft/onnxruntime/issues/19480, which is a crash for the scenario mentioned above. --------- Co-authored-by: Lei Cao --- .../core/framework/allocation_planner.cc | 44 ++++++------ .../test/framework/allocation_planner_test.cc | 68 ++++++++++++++++++ .../multi_stream_models/issue_19480.onnx | Bin 0 -> 760 bytes 3 files changed, 91 insertions(+), 21 deletions(-) create mode 100644 onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index ea7a6432a7507..158ab8ed610f4 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -182,7 +182,6 @@ class PlannerImpl { // upstream_node_0 and upstream_node_1 are the immmediate upstream nodes of downstream_node // upstream_node_2 is the immediate nodes ahead of downstream_node in the same logic stream InlinedHashMap> dependence_graph_; - InlinedHashMap> value_consumer_map_; InlinedHashMap value_node_map_; // OrtValueInfo: Auxiliary information about an OrtValue used only during plan-generation: @@ -295,7 +294,7 @@ class PlannerImpl { } #endif - // Find if there exists some input tensor that we can use in-place for output_arg_num-th input in the node. + // Find if there exists some input tensor that we can use in-place for output_arg_num-th output in the node. 
bool FindReusableInput(const onnxruntime::Node& node, int output_arg_num, OrtValueIndex* reusable_input, bool* is_strided_tensor) { *is_strided_tensor = false; @@ -530,6 +529,7 @@ class PlannerImpl { // Initialize allocation plan: plan_.allocation_plan.resize(num_ml_values); + for (int i = 0; static_cast(i) < num_ml_values; i++) AllocPlan(i).reused_buffer = i; } bool HasExternalOutputs(const Node& node) const { @@ -1065,7 +1065,8 @@ class PlannerImpl { // build the consumer list for each value int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1; - value_consumer_map_.reserve(num_ml_values); + InlinedHashMap> value_consumer_map; + value_consumer_map.reserve(num_ml_values); // iterate each stream from back, so the first element is the last consumer in single stream case for (auto& stream : stream_nodes_) { @@ -1078,10 +1079,10 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer - value_consumer_map_[origin].insert(node_index); + value_consumer_map[origin].insert(node_index); } } return Status::OK(); @@ -1138,8 +1139,8 @@ class PlannerImpl { std::cout << p_input_arg->Name() << " reused by " << p_output_arg->Name() << " as input" << std::endl; allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); found_reusable = true; break; @@ -1168,8 +1169,8 @@ class PlannerImpl { allocation_plan[reusable_input].alloc_kind == AllocKind::kAllocate) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = reusable_input; - value_consumer_map_[reusable_input].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[reusable_input].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(reusable_input); continue; } // if @@ -1187,11 +1188,11 @@ class PlannerImpl { OrtValueIndex input_arg_index{}; if (value_map.GetIdx(p_input_arg->Name(), input_arg_index).IsOK() && allocation_plan[input_arg_index].alloc_kind == AllocKind::kAllocate) { - if (value_consumer_map_[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { + if (value_consumer_map[input_arg_index].size() == 1 && SameSize(*p_input_arg, *p_output_arg)) { allocation_plan[output_idx_global].alloc_kind = AllocKind::kReuse; allocation_plan[output_idx_global].reused_buffer = input_arg_index; - value_consumer_map_[input_arg_index].insert(value_consumer_map_[output_idx_global].begin(), - value_consumer_map_[output_idx_global].end()); + value_consumer_map[input_arg_index].insert(value_consumer_map[output_idx_global].begin(), + value_consumer_map[output_idx_global].end()); reused.insert(input_arg_index); } } @@ -1266,7 +1267,7 @@ class PlannerImpl { } bool all_covered = true; - for (auto consumer : 
value_consumer_map_[output_idx_global]) { + for (auto consumer : value_consumer_map[output_idx_global]) { if (deps->find(consumer) == deps->end()) { all_covered = false; break; @@ -1277,9 +1278,9 @@ class PlannerImpl { allocation_plan[downstream_value].reused_buffer = output_idx_global; get_reused = true; // add new consumer for the value to be reused - value_consumer_map_[output_idx_global].insert(value_node_map_[downstream_value]); - value_consumer_map_[output_idx_global].insert(value_consumer_map_[downstream_value].begin(), - value_consumer_map_[downstream_value].end()); + value_consumer_map[output_idx_global].insert(value_node_map_[downstream_value]); + value_consumer_map[output_idx_global].insert(value_consumer_map[downstream_value].begin(), + value_consumer_map[downstream_value].end()); node_iter = size_iter->second.erase(node_iter); if (size_iter->second.empty()) { local_iter->second.erase(size_iter); @@ -1342,8 +1343,9 @@ class PlannerImpl { ort_value_usecount.reserve(ort_value_info_.size()); #endif for (size_t i = 0; i < stream_nodes_.size(); ++i) { - // compute use count first + // compute use count first. TODO(leca): call ComputeReuseCount() only once is enough! ORT_RETURN_IF_ERROR(ComputeReuseCount()); + for (int j = 0; static_cast(j) < ort_value_info_.size(); j++) Buffer(j) = j; #if !defined(ORT_MINIMAL_BUILD) && defined(ORT_MEMORY_PROFILE) if (i == 0) { for (auto ort_value_info : ort_value_info_) { @@ -1693,8 +1695,8 @@ class PlannerImpl { const auto& name = input.Name(); int value_idx; ORT_RETURN_IF_ERROR(ort_value_name_idx_map_.GetIdx(name, value_idx)); - auto origin = Buffer(value_idx); - if (origin != -1 && plan_.allocation_plan[origin].alloc_kind == AllocKind::kAllocate) { + auto origin = AllocPlan(value_idx).reused_buffer; + if (AllocPlan(origin).alloc_kind == AllocKind::kAllocate) { // add current node as consumer for origin buffer value_consumers[origin].push_back(node_index); } @@ -1889,7 +1891,7 @@ class PlannerImpl { // 2. the consumer is in the same stream(non-cpu device), but it consumes a CPU tensor from an non-shape op. // for example, a resize cuda kernel consumer a tensor from MemCpyToHost cuda kernel on the same stream. 
// in this case, the FIFO can't guarantee the cpu tensor is ready when resize kernel is launching - OrtDevice::DeviceType output_arg_device = plan_.allocation_plan[output_arg_idx].location.Type(); + OrtDevice::DeviceType output_arg_device = AllocPlan(output_arg_idx).location.Type(); WaitNotificationFn wait_handle = stream_handle_registry.GetWaitHandle(stream_device, output_arg_device); if ((node_stream_map_[it->Index()] != i || output_arg_device == OrtDevice::CPU) && wait_handle != nullptr) { if (node_to_notification.find(node_index) == node_to_notification.end()) { diff --git a/onnxruntime/test/framework/allocation_planner_test.cc b/onnxruntime/test/framework/allocation_planner_test.cc index d7b1de5c930c5..3e0d94e94e48c 100644 --- a/onnxruntime/test/framework/allocation_planner_test.cc +++ b/onnxruntime/test/framework/allocation_planner_test.cc @@ -1974,6 +1974,74 @@ TEST_F(PlannerTest, TestCpuIf) { ASSERT_TRUE(exe_plan[1]->steps_[6]->ToString().substr(0, WaitOnEPStep.size()) == WaitOnEPStep); } } + +// model looks like: +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// |-----------> Gather +// Shape ----------------> Reshape --> Shape ------------------> Reshape +// ^ ^ +// InstanceNormalization ----| InstanceNormalization ------| +// +// Python script to create this model: +// def CreateModelFor19480(): +// #shape->reshape->shape->reshape, 4 gather +// graphNodes = [] +// graphNodes.append(h.make_node('Shape', inputs=['shape_input'], outputs=['9'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in0_input', 'scale0', 'B0'], outputs=['8'])) +// graphNodes.append(h.make_node('Reshape', inputs=['8', '9'], outputs=['Reshape15_output'])) +// graphNodes.append(h.make_node('Shape', inputs=['Reshape15_output'], outputs=['281'])) +// graphNodes.append(h.make_node('InstanceNormalization', inputs=['in1_input', 'scale1', 'B1'], outputs=['293'])) +// graphNodes.append(h.make_node('Reshape', inputs=['293', '281'], outputs=['output0'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices1'], outputs=['output1'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices2'], outputs=['output2'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices3'], outputs=['output3'])) +// graphNodes.append(h.make_node('Gather', inputs=['281', 'indices4'], outputs=['output4'])) +// g = h.make_graph(graphNodes, 'issue_19480', +// [h.make_tensor_value_info('shape_input', tp.FLOAT, ['batch', 128, None, None]), +// h.make_tensor_value_info('in0_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale0', tp.FLOAT, [32]), +// h.make_tensor_value_info('B0', tp.FLOAT, [32]), +// h.make_tensor_value_info('in1_input', tp.FLOAT, ['batch', 32, None]), +// h.make_tensor_value_info('scale1', tp.FLOAT, [32]), +// h.make_tensor_value_info('B1', tp.FLOAT, [32]), +// h.make_tensor_value_info('indices1', tp.INT32, []), +// h.make_tensor_value_info('indices2', tp.INT32, []), +// h.make_tensor_value_info('indices3', tp.INT32, []), +// h.make_tensor_value_info('indices4', tp.INT32, [])], +// [h.make_tensor_value_info('output0', tp.FLOAT, None), +// h.make_tensor_value_info('output1', tp.INT64, None), +// h.make_tensor_value_info('output2', tp.INT64, None), +// h.make_tensor_value_info('output3', tp.INT64, None), +// h.make_tensor_value_info('output4', tp.INT64, None)]) +// model = h.make_model(g, opset_imports=[h.make_operatorsetid("", 17)], producer_name='producer_name') +// onnx.save(model, 'issue_19480.onnx') +// 
+TEST(AllocationPlannerTest, ReusedInputCrossDifferentStreams) { + SessionOptions sess_opt; + sess_opt.graph_optimization_level = TransformerLevel::Default; + + InferenceSession sess(sess_opt, GetEnvironment(), ORT_TSTR("./testdata/multi_stream_models/issue_19480.onnx")); + auto status = sess.RegisterExecutionProvider(DefaultCudaExecutionProvider()); + status = sess.Load(); + status = sess.Initialize(); + ASSERT_TRUE(status.IsOK()) << "No crash"; + const SequentialExecutionPlan* plan = sess.GetSessionState().GetExecutionPlan(); + ASSERT_EQ(plan->allocation_plan[14].alloc_kind, AllocKind::kReuse) << "The input of reshape and gather will reuse the output of shape"; + + int gather_count = 0; + for (size_t i = 0; i < plan->execution_plan[1]->steps_.size(); i++) { + if (strstr(typeid(*(plan->execution_plan[1]->steps_[i])).name(), "LaunchKernelStep")) { + const Node* node = sess.GetSessionState().GetGraphViewer().GetNode(plan->execution_plan[1]->steps_[i]->GetNodeIndex()); + if (node->OpType() == "Gather") + gather_count++; + else + FAIL() << "CPU stream should contain only gather ops"; + } + } + ASSERT_EQ(gather_count, 4) << "4 gather ops are all placed in CPU stream"; +} #endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx b/onnxruntime/test/testdata/multi_stream_models/issue_19480.onnx new file mode 100644 index 0000000000000000000000000000000000000000..dc7d39206dd49f4ef6daf65b7d58c5b456ecf331 GIT binary patch literal 760 zcmaixKTm@|7>Bw3f%9#m_0_6_F_p#1gab^#v5RqW(2b?J(o1>?g{IKO$uH`6@t{2^ z#m0q%=Y9D7j(aJ^(>&%f;j=_M&c!l&{_evy4DtnEiK$Fin*vE__dm*aU~nQ+XN$p9 zA11aA3GP#64zs z+VGAUzBYVq;6Ud2Mod}g2Tt_Ry!IQoq685v?9X@+FQ7}m2pC{Q_TCzB1Q$v>tF;at zE9X-02LY%OdZ2hTthTjJs;u4R{+GpCSxtg_7ivO}nrK8dbFt05KbWuCO#ReubJg)l W4Oj)N8n}nRI|Tj~OnP7p&wl_GGP8{U literal 0 HcmV?d00001 From 3bdb10d5ca4f258ec444863bcd5e839eeac5c238 Mon Sep 17 00:00:00 2001 From: jingyanwangms <47403504+jingyanwangms@users.noreply.github.com> Date: Thu, 22 Feb 2024 10:56:25 -0800 Subject: [PATCH 126/207] Move import to when needed to avoid circular dependency error (#19579) ### Description Move import to when needed to avoid circular dependency error ### Motivation and Context Fixes dependency error described here: https://github.com/microsoft/DeepSpeed/issues/5140 --------- Co-authored-by: Thiago Crepaldi --- .../python/training/ortmodule/_graph_execution_manager.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 779b6bfe50422..fda6e345da235 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -20,7 +20,6 @@ from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference from onnxruntime.training.utils import ORTModelInputOutputSchemaType, PTable, onnx_dtype_to_pytorch_dtype -from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . 
import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils from ._fallback import ( @@ -143,6 +142,9 @@ def __init__( self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: + # Move import to here to avoid circular dependency error + from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 # type: ignore[import] + # Cannot toggle feature enabling/disabling after the first time enabled. configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) From fe82fccf1a4d7ea6c24c8448d7264df36605c370 Mon Sep 17 00:00:00 2001 From: Xu Xing Date: Fri, 23 Feb 2024 05:09:28 +0800 Subject: [PATCH 127/207] [js/webgpu] Fix Conv2DTransposeMatMul f16 compilation failure (#19596) This is used in sam-h-decoder-f16. ### Description ### Motivation and Context --- .../ops/3rd-party/conv_backprop_mm_webgpu.ts | 22 +++++++++++-------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts index b5b6a2a15cd8c..11c8778b72335 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common'; import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types'; -import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common'; +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils'; -import {biasSnippet, typeSnippet} from './activation_util'; +import {biasSnippet} from './activation_util'; import {utilFunctions} from './conv_util'; import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; const conv2dTransposeCommonSnippet = - (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => { - const type = typeSnippet(innerElementSize, 'f32'); + (isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string, + innerElementSize = 4): string => { const getWSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet = let v1 = w[getIndexFromCoords4D(coord1, vec4(uniforms.w_shape))]; let v2 = w[getIndexFromCoords4D(coord2, vec4(uniforms.w_shape))]; let v3 = w[getIndexFromCoords4D(coord3, vec4(uniforms.w_shape))]; - return vec4(v0, v1, v2, v3); + return ${type}(v0, v1, v2, v3); `; default: throw new Error(`innerElementSize ${innerElementSize} is not supported.`); @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo = const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components); inputVariables.push(bias); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${bias.type.value} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? 
'/ 4' : ''}]; }`; } @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo = {name: 'pads', type: 'i32', length: pads.length} ]; appendActivationUniforms(attributes, uniforms); + const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1); + if (elemType !== 'f16' && elemType !== 'f32') { + throw new Error(`elemType ${elemType} is not supported.`); + } return ` ${utilFunctions('uniforms.result_strides')} ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)}; ${declareFunctions} - ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)} ${ isVec4 ? makeMatMulPackedVec4Source( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}`; }; From 09622418c45b265977a8f1f17581e15719357423 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 Feb 2024 13:15:13 -0800 Subject: [PATCH 128/207] Add special handling if there is only 1 graph inside the cached QNN context binary (#19594) Add special handling if there is only 1 graph inside the cached QNN context binary. No need to make the EPContext node name match the QNN graph name. This is for better backward compatibility in case the QNN context model is generated before the PR for QNN context binary model support multi-partition. --- .../qnn/builder/onnx_ctx_model_helper.cc | 6 +- .../qnn/builder/onnx_ctx_model_helper.h | 3 +- .../qnn/builder/qnn_backend_manager.cc | 15 ++-- .../providers/qnn/qnn_execution_provider.cc | 3 +- .../test/providers/qnn/qnn_ep_context_test.cc | 83 ++++++++++++++++++- 5 files changed, 99 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index c2e71081b898e..2d8ec295d613b 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -151,12 +151,14 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - std::unordered_map>& qnn_models) { + std::unordered_map>& qnn_models, + const logging::Logger& logger) { Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models); // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model if (!status.IsOK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage()); + LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. 
", status.ErrorMessage()); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index b1360b4e576fa..7d56b45a1dbcd 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -56,7 +56,8 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::PathString& ctx_onnx_model_path, QnnBackendManager* qnn_backend_manager, - std::unordered_map>& qnn_models); + std::unordered_map>& qnn_models, + const logging::Logger& logger); Status CreateEPContextNodes(Model* model, unsigned char* buffer, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5f0b87c7cb9d7..ca34a1efa6ca7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -573,11 +573,16 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t // More work to support multiple partition, how to map the graph name in compile to qnn graph name // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile - for (uint32_t i = 0; i < graph_count; ++i) { - std::string graph_name(graphs_info[i].graphInfoV1.graphName); - auto qnn_model_pos = qnn_models.find(graph_name); - ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); - ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + if (1 == graph_count) { + auto qnn_model_pose = qnn_models.begin(); + ORT_RETURN_IF_ERROR(qnn_model_pose->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); + } else { + for (uint32_t i = 0; i < graph_count; ++i) { + std::string graph_name(graphs_info[i].graphInfoV1.graphName); + auto qnn_model_pos = qnn_models.find(graph_name); + ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); + ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + } } qnn_sys_interface_.systemContextFree(sys_ctx_handle); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index f5a166d36b15a..9a6540a3efea5 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -670,7 +670,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, context_cache_path, qnn_backend_manager_.get(), - qnn_models)); + qnn_models, + logger)); for (auto fused_node_and_graph : fused_nodes_and_graphs) { const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index b1f3b52e77553..eaef6f6315157 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -463,7 +463,6 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_InvalidGraph) { InferenceSessionWrapper session_object{so, GetEnvironment()}; - std::string 
provider_type = kCpuExecutionProvider; ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); ASSERT_STATUS_OK(session_object.Load(qnn_ctx_model_data.data(), static_cast(qnn_ctx_model_data.size()))); // Verify the return status with code INVALID_GRAPH @@ -486,7 +485,6 @@ std::string CreateQnnCtxModelWithNonEmbedMode(std::string external_bin_path) { auto* graph_output = helper.MakeOutput(shape); Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); ep_context_node.AddAttribute("embed_mode", static_cast(0)); - // The .. in the path will cause INVALID_GRAPH ep_context_node.AddAttribute("ep_cache_context", external_bin_path); ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); ep_context_node.AddAttribute("source", "QNN"); @@ -651,6 +649,87 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary2InputsTest) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } +// Context binary only contains a single QNN graph, generated context cache model (detached mode) only has 1 EPContext node +// Create another Onnx model which also reference to the bin file, +// but the node name is not same with the QNN graph name inside the bin file. +// This is to support backward compitable for the models generated before the PR that +// make context generation support multi-partition +TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphNameInCtx) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + const std::string context_binary_file = "./qnn_context_cache_non_embed.onnx"; + std::filesystem::path context_bin = "qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"; + std::remove(context_binary_file.c_str()); + std::remove(context_bin.string().c_str()); + + std::unordered_map session_option_pairs; + session_option_pairs.emplace(kOrtSessionOptionEpContextEnable, "1"); + session_option_pairs.emplace(kOrtSessionOptionEpContextFilePath, context_binary_file); + session_option_pairs.emplace(kOrtSessionOptionEpContextEmbedMode, "0"); + + const TestInputDef input_def({1, 2, 3}, false, -10.0f, 10.0f); + const std::string op_type = "Atan"; + + // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. 
+ // 1st run will generate the Onnx skeleton file + Qnn context cache binary file + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), + provider_options, + 14, + ExpectedEPNodeAssignment::All, + QDQTolerance(), + logging::Severity::kERROR, + "", // context model file path, not required for this inference + session_option_pairs); + + // Check the Onnx skeleton file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + // Check the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_bin)); + + const std::unordered_map domain_to_version = {{"", 11}, {kMSDomain, 1}}; + auto& logging_manager = DefaultLoggingManager(); + onnxruntime::Model model("QNN_ctx_model", false, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, {}, + logging_manager.DefaultLogger()); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + std::vector shape = {1, 2, 3}; + NodeArg* graph_input = MakeTestInput(helper, TestInputDef(shape, false, {0.0f, 1.0f, 0.0f, 1.0f, 0.0f, 1.0f})); + auto* graph_output = helper.MakeOutput(shape); + Node& ep_context_node = helper.AddNode("EPContext", {graph_input}, {graph_output}, kMSDomain); + ep_context_node.AddAttribute("embed_mode", static_cast(0)); + ep_context_node.AddAttribute("ep_cache_context", context_bin.string()); + ep_context_node.AddAttribute("partition_name", "QNNExecutionProvider_QNN_1110111000111000111_1_0"); + ep_context_node.AddAttribute("source", "QNNExecutionProvider"); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + std::string model_data; + model.ToProto().SerializeToString(&model_data); + + // loads and run from Onnx skeleton file + Qnn context cache binary file + + SessionOptions so; + so.session_logid = "qnn_ctx_model_logger"; + RunOptions run_options; + run_options.run_tag = so.session_logid; + + InferenceSessionWrapper session_object{so, GetEnvironment()}; + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(QnnExecutionProviderWithOptions(provider_options))); + ASSERT_STATUS_OK(session_object.Load(model_data.data(), static_cast(model_data.size()))); + // Verify the return status with code INVALID_GRAPH + ASSERT_TRUE(session_object.Initialize().Code() == common::StatusCode::OK); + + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); + ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test From 76a2a487a12c7ec579f453a36932429164494ef6 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 22 Feb 2024 13:58:17 -0800 Subject: [PATCH 129/207] Bump ip from 1.1.8 to 1.1.9 in /js/react_native/e2e (#19583) Bumps [ip](https://github.com/indutny/node-ip) from 1.1.8 to 1.1.9.
[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=ip&package-manager=npm_and_yarn&previous-version=1.1.8&new-version=1.1.9)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) Dependabot will merge this PR once CI passes on it, as requested by @fs-eire. [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/microsoft/onnxruntime/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- js/react_native/e2e/yarn.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/react_native/e2e/yarn.lock b/js/react_native/e2e/yarn.lock index 9e20a286c4e27..6f05faf046098 100644 --- a/js/react_native/e2e/yarn.lock +++ b/js/react_native/e2e/yarn.lock @@ -3351,9 +3351,9 @@ invariant@^2.2.4: loose-envify "^1.0.0" ip@^1.1.5: - version "1.1.8" - resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48" - integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg== + version "1.1.9" + resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396" + integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ== is-accessor-descriptor@^0.1.6: version "0.1.6" From 5e5c36f6df95dfbb25787ea385f733f8c9ef691e Mon Sep 17 00:00:00 2001 From: AtomicVar Date: Fri, 23 Feb 2024 09:03:56 +0800 Subject: [PATCH 130/207] Fix citation author name issue (#19597) Use `name` rather than `given-names` to set the author name. ### Motivation and Context The old CITATION.cff uses `given-names` to set author names, which is not rendered properly by some BibTeX styles in LaTeX. The problem is that **`"ONNX Runtime developers"` is regarded as a human name**. How to fix: by using `name` to set the author name, the generated BibTeX entry will use `{}` to enclose `"ONNX Runtime developers"`, so it is displayed literally. --- CITATION.cff | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CITATION.cff b/CITATION.cff index 82bcac5a7b750..10b7290022aef 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -3,8 +3,7 @@ title: ONNX Runtime message: "Please use this information to cite ONNX Runtime in research or other publications." authors: - - affiliation: Microsoft Corporation - given-names: ONNX Runtime developers + - name: ONNX Runtime developers date-released: 2018-11-29 url: "https://onnxruntime.ai" repository-code: "https://github.com/microsoft/onnxruntime" From 4ab497603e915ca992b96ef1ec25bfcf8b9a2ad5 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Thu, 22 Feb 2024 17:04:59 -0800 Subject: [PATCH 131/207] Enable user to set QNN HTP performance mode for every session run (#19521) ### Description Currently, the QNN HTP performance mode is set during session creation and cannot be changed afterwards. There is a requirement to use a high performance mode for high-priority requests and to drop back to a low performance mode to save power while incoming requests are idle. The performance mode set through the QNN EP provider options is still kept at the session level and acts as the default; the QNN EP applies it once if the user sets it. In addition, two run options (qnn.htp_perf_mode and qnn.htp_perf_mode_post_run) allow changing the performance mode before and after each session run. The recommended usage is to set a high performance mode before the inference run so the result comes back as soon as possible, and to set a low performance mode after the inference to save power.
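For illustration, a minimal, hypothetical sketch of how an application could use these run-option keys from the ORT C++ API. The model path, input setup, and chosen provider option values are placeholders rather than part of this PR; the key strings are the ones added below in onnxruntime_run_options_config_keys.h.

```cpp
#include <onnxruntime_cxx_api.h>

#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_perf_mode_demo");

  // Session-level default performance mode via QNN EP provider options (applied once).
  Ort::SessionOptions session_options;
  std::unordered_map<std::string, std::string> qnn_options{
      {"backend_path", "QnnHtp.dll"},          // libQnnHtp.so on Linux/Android
      {"htp_performance_mode", "balanced"}};   // default used when no run option is given
  session_options.AppendExecutionProvider("QNN", qnn_options);

  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);  // placeholder model

  // Per-run override: burst for this inference, drop to a power saver mode afterwards.
  Ort::RunOptions run_options;
  run_options.AddConfigEntry("qnn.htp_perf_mode", "burst");
  run_options.AddConfigEntry("qnn.htp_perf_mode_post_run", "low_power_saver");
  run_options.AddConfigEntry("qnn.rpc_control_latency", "10");

  // auto outputs = session.Run(run_options, input_names, input_tensors, 1, output_names, 1);
  return 0;
}
```

The post-run key lets a caller return the HTP to a low-power state without an extra API call after the high-priority request completes.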
--- .../core/framework/execution_provider.h | 10 +- .../onnxruntime_run_options_config_keys.h | 12 + .../framework/stream_execution_context.cc | 4 +- .../providers/cann/cann_execution_provider.cc | 2 +- .../providers/cann/cann_execution_provider.h | 2 +- .../providers/cuda/cuda_execution_provider.cc | 4 +- .../providers/cuda/cuda_execution_provider.h | 5 +- .../src/ExecutionProvider.h | 4 +- .../providers/js/js_execution_provider.cc | 4 +- .../core/providers/js/js_execution_provider.h | 4 +- .../migraphx/migraphx_execution_provider.cc | 4 +- .../migraphx/migraphx_execution_provider.h | 4 +- .../qnn/builder/qnn_backend_manager.cc | 75 +++--- .../qnn/builder/qnn_backend_manager.h | 19 +- .../providers/qnn/qnn_execution_provider.cc | 198 +++++++++++++++- .../providers/qnn/qnn_execution_provider.h | 73 +++++- .../providers/rocm/rocm_execution_provider.cc | 4 +- .../providers/rocm/rocm_execution_provider.h | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 4 +- .../tensorrt/tensorrt_execution_provider.h | 4 +- onnxruntime/core/session/inference_session.cc | 12 +- .../cuda_execution_provider_test.cc | 13 +- .../test/providers/qnn/qnn_basic_test.cc | 217 ++++++++++++++++-- 23 files changed, 577 insertions(+), 105 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 31c988f500779..c1cc69edc17d8 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -33,6 +33,8 @@ class Node; #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +struct OrtRunOptions; + namespace onnxruntime { /** @@ -51,6 +53,8 @@ struct NodeComputeInfo { DestroyFunctionStateFunc release_state_func; }; +using RunOptions = OrtRunOptions; + enum class DataLayout { NCHW, NHWC, @@ -184,7 +188,7 @@ class IExecutionProvider { Run may not be finished on device This function should be regarded as the point after which a new Run would start to submit commands from CPU */ - virtual common::Status OnRunStart() { return Status::OK(); } + virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } /** Called when InferenceSession::Run ended @@ -192,7 +196,9 @@ class IExecutionProvider { may not be finished on device This function should be regarded as the point that all commands of current Run has been submmited by CPU */ - virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); } + virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { + return Status::OK(); + } /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index 1f5fcd50e185c..b0a17e175fef3 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor // Per default it will be set to '0' // Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream. static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers"; + +// Set HTP performance mode for QNN HTP backend before session run. 
+// options for HTP performance mode: "burst", "balanced", "default", "high_performance", +// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", +// "sustained_high_performance". Default to "default". +static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode"; + +// Set HTP performance mode for QNN HTP backend post session run. +static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run"; + +// Set RPC control latency for QNN HTP backend +static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; diff --git a/onnxruntime/core/framework/stream_execution_context.cc b/onnxruntime/core/framework/stream_execution_context.cc index 875e7f395bfa8..dd7f4d35b34bd 100644 --- a/onnxruntime/core/framework/stream_execution_context.cc +++ b/onnxruntime/core/framework/stream_execution_context.cc @@ -181,11 +181,13 @@ void RunSince(size_t stream_idx, StreamExecutionContext& ctx, SessionScope& sess } #ifdef USE_CANN + // Leave it to CANN EP to fill the gap if they want to use run_options + static onnxruntime::RunOptions run_options; // For CANN EP, it is necessary to explicitly create a corresponding Context for each thread in the thread pool, // which is different from CUDA Runtime API, but similar to CUDA Driver API. auto& execution_providers = ctx.GetSessionState().GetExecutionProviders(); for (auto& xp : execution_providers) { - auto status = xp->OnRunStart(); + auto status = xp->OnRunStart(run_options); if (!status.IsOK()) { ctx.SetStatus(status); return; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 752b742805a7c..9a242919665bb 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1045,7 +1045,7 @@ CANNExecutionProvider::~CANNExecutionProvider() { } // All threads share the same context and stream -Status CANNExecutionProvider::OnRunStart() { +Status CANNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { CANN_RETURN_IF_ERROR(aclrtSetDevice(info_.device_id)); return Status::OK(); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 63ae980869c65..d83bd88d6958f 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -33,7 +33,7 @@ class CANNExecutionProvider : public IExecutionProvider { explicit CANNExecutionProvider(const CANNExecutionProviderInfo& info); virtual ~CANNExecutionProvider(); - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; template Status Fill(Tensor* y, void* addr, aclrtStream stream) const { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 48a952e6dd98f..0dd568c5ecc05 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -386,7 +386,7 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart() { +Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set CUDA device when session::Run() in case it runs in a worker thread 
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -396,7 +396,7 @@ Status CUDAExecutionProvider::OnRunStart() { return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream) { +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 55f0b5570e0ee..5f62f313b86a2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -29,9 +29,9 @@ class CUDAExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; DataLayout GetPreferredLayout() const override; @@ -115,6 +115,7 @@ class CUDAExecutionProvider : public IExecutionProvider { PerThreadContext(OrtDevice::DeviceId device_id, cudaStream_t stream, size_t cuda_mem_limit, ArenaExtendStrategy arena_extend_strategy, CUDAExecutionProviderExternalAllocatorInfo external_alloc_info, OrtArenaCfg* arena_cfg); ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); cublasHandle_t CublasHandle() const { return cublas_handle_; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 5617bc7bdcac6..841d6244a983e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -270,7 +270,7 @@ namespace Dml return m_impl->OnSessionInitializationEnd(); } - virtual onnxruntime::Status Sync() const final override + onnxruntime::Status Sync() const final override { // Completely wait until the device has completed all preceding tasks. // The application could have called SynchronizeBoundOutputs(). @@ -278,7 +278,7 @@ namespace Dml return Status::OK(); } - virtual onnxruntime::Status OnRunEnd(bool /*sync_stream*/) final override + onnxruntime::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) final override { // Flush any pending work to the GPU, but don't block for completion, permitting it // to overlap other work. 
diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 799d4172f2b64..62c3981682cfc 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -756,7 +756,7 @@ std::unique_ptr JsExecutionProvider::GetDataTransfer JsExecutionProvider::~JsExecutionProvider() { } -Status JsExecutionProvider::OnRunStart() { +Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; EM_ASM({ Module.jsepCaptureBegin(); }); @@ -764,7 +764,7 @@ Status JsExecutionProvider::OnRunStart() { return Status::OK(); } -Status JsExecutionProvider::OnRunEnd(bool sync_stream) { +Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 91a3256ec2bd5..b4518c67d1e60 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -59,8 +59,8 @@ class JsExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; bool IsGraphCaptureEnabled() const override; bool IsGraphCaptured() const override; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 40e76a0a67782..50782569ee80a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1383,11 +1383,11 @@ Status MIGraphXExecutionProvider::Sync() const { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunStart() { +Status MIGraphXExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status MIGraphXExecutionProvider::OnRunEnd(bool) { +Status MIGraphXExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) { auto status = hipStreamQuery(stream_); if (status != hipSuccess) { diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index d582338c7e067..c3617f409e72c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -56,9 +56,9 @@ class MIGraphXExecutionProvider : public IExecutionProvider { #ifdef MIGRAPHX_STREAM_SYNC Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; #endif std::vector> diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc 
index ca34a1efa6ca7..e354bf6562722 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -634,11 +634,6 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ LOGS(logger, VERBOSE) << "CreateContext succeed."; } - if (htp_performance_mode_ != HtpPerformanceMode::kHtpDefault) { - ORT_RETURN_IF_ERROR(SetHtpPowerConfig()); - LOGS(logger, VERBOSE) << "SetHtpPowerConfig succeed."; - } - LOGS(logger, VERBOSE) << "QNN SetupBackend succeed"; backend_setup_completed_ = true; @@ -646,7 +641,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ return Status::OK(); } -Status QnnBackendManager::SetHtpPowerConfig() { +Status QnnBackendManager::CreateHtpPowerCfgId(uint32_t device_id, uint32_t core_id, uint32_t& htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -656,23 +651,37 @@ Status QnnBackendManager::SetHtpPowerConfig() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; // Get power client id - status = htp_perf_infra.createPowerConfigId(/*device_id=*/0, /*core_id=*/0, &htp_power_config_client_id_); + status = htp_perf_infra.createPowerConfigId(device_id, core_id, &htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != status, "createPowerConfigId failed."); + return Status::OK(); +} + +Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + constexpr const int kNumConfigs = 1; std::vector power_configs( kNumConfigs); QnnHtpPerfInfrastructure_PowerConfig_t& dcvs_config = power_configs[0]; dcvs_config.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3; QnnHtpPerfInfrastructure_DcvsV3_t& dcvs_v3 = dcvs_config.dcvsV3Config; - dcvs_v3.contextId = htp_power_config_client_id_; + dcvs_v3.contextId = htp_power_config_client_id; dcvs_v3.setSleepDisable = 0; dcvs_v3.sleepDisable = 0; dcvs_v3.setDcvsEnable = 1; dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; // choose performance mode - switch (htp_performance_mode_) { + switch (htp_performance_mode) { case HtpPerformanceMode::kHtpBurst: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMinLatency; @@ -771,25 +780,40 @@ Status QnnBackendManager::SetHtpPowerConfig() { dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; break; default: - ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode_)); + ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode)); break; } std::vector perf_power_configs_ptr = ObtainNullTermPtrVector(power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, 
perf_power_configs_ptr.data()); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for HTP performance mode."); - // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. - if (rpc_control_latency_ != 0) { + return Status::OK(); +} + +Status QnnBackendManager::SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency) { + if (rpc_control_latency != 0) { + QnnDevice_Infrastructure_t qnn_device_infra = nullptr; + auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); + ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); + + auto* htp_infra = static_cast(qnn_device_infra); + ORT_RETURN_IF(QNN_HTP_DEVICE_INFRASTRUCTURE_TYPE_PERF != htp_infra->infraType, + "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); + QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; + + // Set rpc control latency here, but note that v68 doesn't support rpc polling mode. constexpr int kNumRpcPollingPowerConfigs = 2; std::vector rpc_power_configs(kNumRpcPollingPowerConfigs); - QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency = rpc_power_configs[0]; + QnnHtpPerfInfrastructure_PowerConfig_t& rpc_control_latency_cfg = rpc_power_configs[0]; // v68 doesn't support this. QnnHtpPerfInfrastructure_PowerConfig_t& rpc_polling_time = rpc_power_configs[1]; - rpc_control_latency.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; + rpc_control_latency_cfg.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_CONTROL_LATENCY; rpc_polling_time.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_RPC_POLLING_TIME; - rpc_control_latency.rpcControlLatencyConfig = rpc_control_latency_; - perf_power_configs_ptr = ObtainNullTermPtrVector(rpc_power_configs); - status = htp_perf_infra.setPowerConfig(htp_power_config_client_id_, perf_power_configs_ptr.data()); + rpc_control_latency_cfg.rpcControlLatencyConfig = rpc_control_latency; + std::vector perf_power_configs_ptr = + ObtainNullTermPtrVector(rpc_power_configs); + status = htp_perf_infra.setPowerConfig(htp_power_config_client_id, perf_power_configs_ptr.data()); ORT_RETURN_IF(QNN_SUCCESS != status, "setPowerConfig failed for RPC control latency."); } @@ -810,11 +834,7 @@ void QnnBackendManager::Split(std::vector& split_string, } } -Status QnnBackendManager::DestroyHTPPowerConfigID() { - if (htp_performance_mode_ == HtpPerformanceMode::kHtpDefault) { - return Status::OK(); - } - +Status QnnBackendManager::DestroyHTPPowerConfigID(uint32_t htp_power_config_id) { QnnDevice_Infrastructure_t qnn_device_infra = nullptr; auto status = qnn_interface_.deviceGetInfrastructure(&qnn_device_infra); ORT_RETURN_IF(QNN_SUCCESS != status, "backendGetPerfInfrastructure failed."); @@ -824,7 +844,7 @@ Status QnnBackendManager::DestroyHTPPowerConfigID() { "HTP infra type = ", htp_infra->infraType, ", which is not perf infra type."); QnnHtpDevice_PerfInfrastructure_t& htp_perf_infra = htp_infra->perfInfra; - Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_client_id_); + Qnn_ErrorHandle_t destroy_ret = htp_perf_infra.destroyPowerConfigId(htp_power_config_id); ORT_RETURN_IF(QNN_SUCCESS != destroy_ret, "destroyPowerConfigId failed."); return Status::OK(); } @@ -834,12 +854,7 @@ void QnnBackendManager::ReleaseResources() { return; } - auto result = DestroyHTPPowerConfigID(); - if 
(Status::OK() != result) { - ORT_THROW("Failed to DestroyHTPPowerConfigID."); - } - - result = ReleaseContext(); + auto result = ReleaseContext(); if (Status::OK() != result) { ORT_THROW("Failed to ReleaseContext."); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 36375522b5a0a..ff97c4c3a991c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -33,8 +33,6 @@ class QnnBackendManager { public: QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level, - uint32_t rpc_control_latency, - HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, std::string&& qnn_saver_path, uint32_t device_id, @@ -42,8 +40,6 @@ class QnnBackendManager { uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), - rpc_control_latency_(rpc_control_latency), - htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), qnn_saver_path_(qnn_saver_path), device_id_(device_id), @@ -92,7 +88,13 @@ class QnnBackendManager { Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); - Status SetHtpPowerConfig(); + Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); + + Status SetHtpPowerConfig(uint32_t htp_power_config_client_id, + HtpPerformanceMode htp_performance_mode); + + Status SetRpcControlLatency(uint32_t htp_power_config_client_id, + uint32_t rpc_control_latency); const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } @@ -141,6 +143,8 @@ class QnnBackendManager { const std::string& GetSdkVersion() { return sdk_build_version_; } + Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -150,8 +154,6 @@ class QnnBackendManager { Status UnloadLib(void* handle); - Status DestroyHTPPowerConfigID(); - void* LibFunction(void* handle, const char* symbol, std::string& error_msg); template @@ -232,15 +234,12 @@ class QnnBackendManager { QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; std::vector op_package_paths_; - uint32_t rpc_control_latency_ = 0; - HtpPerformanceMode htp_performance_mode_; ContextPriority context_priority_; std::string sdk_build_version_ = ""; #ifdef _WIN32 std::set mod_handles_; #endif const std::string qnn_saver_path_; - uint32_t htp_power_config_client_id_ = 0; uint32_t device_id_ = 0; QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 9a6540a3efea5..3d9cfd92b7922 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -7,6 +7,7 @@ #include "core/framework/compute_capability.h" #include "core/graph/graph_viewer.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" #include "core/platform/env.h" @@ -18,11 +19,36 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" #include 
"core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/framework/run_options.h" namespace onnxruntime { constexpr const char* QNN = "QNN"; +static std::unique_ptr>> s_run_on_unload_; + +void RunOnUnload(std::function function) { + OrtMutex mutex; + std::lock_guard guard(mutex); + if (!s_run_on_unload_) { + s_run_on_unload_ = std::make_unique>>(); + } + s_run_on_unload_->push_back(std::move(function)); +} + +struct OnUnload { + ~OnUnload() { + if (!s_run_on_unload_) + return; + + for (auto& function : *s_run_on_unload_) + function(); + + s_run_on_unload_.reset(); + } + +} g_on_unload; + static void ParseProfilingLevel(std::string profiling_level_string, qnn::ProfilingLevel& profiling_level) { std::transform(profiling_level_string.begin(), @@ -193,18 +219,18 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string RPC_CONTROL_LANTENCY = "rpc_control_latency"; - uint32_t rpc_control_latency = 0; auto latency_pos = provider_options_map.find(RPC_CONTROL_LANTENCY); if (latency_pos != provider_options_map.end()) { - rpc_control_latency = static_cast(std::stoul(latency_pos->second)); - LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + default_rpc_control_latency_ = static_cast(std::stoul(latency_pos->second)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << default_rpc_control_latency_; } - qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + // default_htp_performance_mode from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run static const std::string HTP_PERFORMANCE_MODE = "htp_performance_mode"; auto htp_performance_mode_pos = provider_options_map.find(HTP_PERFORMANCE_MODE); if (htp_performance_mode_pos != provider_options_map.end()) { - ParseHtpPerformanceMode(htp_performance_mode_pos->second, htp_performance_mode); + ParseHtpPerformanceMode(htp_performance_mode_pos->second, default_htp_performance_mode_); } htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; @@ -241,15 +267,14 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_DEVICE_ID = "device_id"; - uint32_t device_id = 0; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { int value = std::stoi(dev_id_pos->second); if (value < 0) { LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value - << "', only >= 0 allowed. Set to " << device_id << "."; + << "', only >= 0 allowed. 
Set to " << device_id_ << "."; } else { - device_id = static_cast(value); + device_id_ = static_cast(value); } } @@ -276,15 +301,23 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, - rpc_control_latency, - htp_performance_mode, context_priority, std::move(qnn_saver_path), - device_id, + device_id_, htp_arch, soc_model); } +QNNExecutionProvider::~QNNExecutionProvider() { + // clean up thread local context caches + std::lock_guard lock(context_state_.mutex); + for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { + const auto cache = cache_weak.lock(); + if (!cache) continue; + ORT_IGNORE_RETURN_VALUE(cache->erase(this)); + } +} + bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const { const std::string& op_type = node_unit.OpType(); @@ -725,4 +758,147 @@ const InlinedVector QNNExecutionProvider::GetEpContextNodes() const return ep_context_nodes; } + +QNNExecutionProvider::PerThreadContext::PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, + uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency) + : qnn_backend_manager_(qnn_backend_manager) { + Status rt = qnn_backend_manager_->CreateHtpPowerCfgId(device_id, core_id, htp_power_config_id_); + is_htp_power_config_id_valid_ = rt.IsOK(); + // default_htp_performance_mode and default_rpc_control_latency are from QNN EP option. + // set it once only for each thread as default so user don't need to set it for every session run + if (is_htp_power_config_id_valid_) { + if (qnn::HtpPerformanceMode::kHtpDefault != default_htp_performance_mode) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetHtpPowerConfig(htp_power_config_id_, + default_htp_performance_mode)); + } + if (default_rpc_control_latency > 0) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->SetRpcControlLatency(htp_power_config_id_, + default_rpc_control_latency)); + } + } +} + +QNNExecutionProvider::PerThreadContext::~PerThreadContext() { + if (is_htp_power_config_id_valid_) { + ORT_IGNORE_RETURN_VALUE(qnn_backend_manager_->DestroyHTPPowerConfigID(htp_power_config_id_)); + } +} + +QNNExecutionProvider::PerThreadContext& QNNExecutionProvider::GetPerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + // try to use cached context + auto cached_context_it = per_thread_context_cache->find(this); + if (cached_context_it != per_thread_context_cache->end()) { + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + return *cached_context; + } + + // get context and update cache + std::shared_ptr context; + { + std::lock_guard lock(context_state_.mutex); + + // get or create a context + if (context_state_.retired_context_pool.empty()) { + uint32_t core_id = 0; + context = std::make_shared(qnn_backend_manager_.get(), device_id_, core_id, + default_htp_performance_mode_, default_rpc_control_latency_); + } else { + context = context_state_.retired_context_pool.back(); + context_state_.retired_context_pool.pop_back(); + } + + // insert into active_contexts, should not already be present + const auto active_contexts_insert_result = context_state_.active_contexts.insert(context); + ORT_ENFORCE(active_contexts_insert_result.second); + + // insert into caches_to_update_on_destruction, may already be 
present + ORT_IGNORE_RETURN_VALUE(context_state_.caches_to_update_on_destruction.insert(per_thread_context_cache)); + } + + per_thread_context_cache->insert(std::make_pair(this, context)); + + return *context; +} + +void QNNExecutionProvider::ReleasePerThreadContext() const { + const auto& per_thread_context_cache = PerThreadContextCache(); + + auto cached_context_it = per_thread_context_cache->find(this); + ORT_ENFORCE(cached_context_it != per_thread_context_cache->end()); + auto cached_context = cached_context_it->second.lock(); + ORT_ENFORCE(cached_context); + + { + std::lock_guard lock(context_state_.mutex); + context_state_.active_contexts.erase(cached_context); + context_state_.retired_context_pool.push_back(cached_context); + } + + per_thread_context_cache->erase(cached_context_it); +} + +Status QNNExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfMode, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + std::string rpc_latency = ""; + uint32_t rpc_control_latency = 0; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnRpcControlLatency, rpc_latency)) { + rpc_control_latency = static_cast(std::stoul(rpc_latency)); + LOGS_DEFAULT(VERBOSE) << "rpc_control_latency: " << rpc_control_latency; + } + + if (GetPerThreadContext().IsHtpPowerConfigIdValid()) { + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + if (rpc_control_latency > 0) { + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetRpcControlLatency(GetPerThreadContext().GetHtpPowerConfigId(), + rpc_control_latency)); + } + } + + return Status::OK(); +} + +Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& run_options) { + auto backend_type = qnn_backend_manager_->GetQnnBackendType(); + if (qnn::QnnBackendType::HTP != backend_type && qnn::QnnBackendType::DSP != backend_type) { + return Status::OK(); + } + + std::string htp_perf_mode = ""; + qnn::HtpPerformanceMode htp_performance_mode = qnn::HtpPerformanceMode::kHtpDefault; + if (run_options.config_options.TryGetConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, htp_perf_mode)) { + // set power mode + ParseHtpPerformanceMode(htp_perf_mode, htp_performance_mode); + } + + if (qnn::HtpPerformanceMode::kHtpDefault != htp_performance_mode) { + if (!GetPerThreadContext().IsHtpPowerConfigIdValid()) { + return Status::OK(); + } + ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetHtpPowerConfig(GetPerThreadContext().GetHtpPowerConfigId(), + htp_performance_mode)); + } + + return Status::OK(); +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 0bcaa39b22f6d..43b5e7bff827e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -12,14 +12,19 @@ #include "core/providers/qnn/builder/qnn_model.h" #include 
"core/providers/qnn/builder/qnn_configs_helper.h" #include "HTP/QnnHtpGraph.h" +#include +#include +#include namespace onnxruntime { +void RunOnUnload(std::function function); + // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: explicit QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options); - virtual ~QNNExecutionProvider() = default; + virtual ~QNNExecutionProvider(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QNNExecutionProvider); // we implement the Compile that takes FusedNodeAndGraph instances @@ -40,6 +45,10 @@ class QNNExecutionProvider : public IExecutionProvider { const InlinedVector GetEpContextNodes() const override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger) const; @@ -72,6 +81,68 @@ class QNNExecutionProvider : public IExecutionProvider { int32_t vtcm_size_in_mb_ = 0; std::unique_ptr qnn_ep_context_model_; ModelMetadefIdGenerator metadef_id_generator_; + uint32_t device_id_ = 0; + qnn::HtpPerformanceMode default_htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; + uint32_t default_rpc_control_latency_ = 0; + + class PerThreadContext final { + public: + PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, + uint32_t device_id, uint32_t core_id, + qnn::HtpPerformanceMode default_htp_performance_mode, + uint32_t default_rpc_control_latency); + ~PerThreadContext(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(PerThreadContext); + + bool IsHtpPowerConfigIdValid() { return is_htp_power_config_id_valid_; } + + uint32_t GetHtpPowerConfigId() { return htp_power_config_id_; } + + private: + bool is_htp_power_config_id_valid_ = false; + uint32_t htp_power_config_id_ = 0; + qnn::QnnBackendManager* qnn_backend_manager_; + }; + + using PerThreadContextMap = std::unordered_map>; + + struct ContextCacheHolder { + ContextCacheHolder() { + RunOnUnload([&, weak_p_ = std::weak_ptr(p)] { + if (auto lock = weak_p_.lock()) + p.reset(); + }); + } + + std::shared_ptr p = std::make_shared(); + }; + + static const std::shared_ptr& PerThreadContextCache() { + thread_local const ContextCacheHolder per_thread_context_cache; + return per_thread_context_cache.p; + } + + struct PerThreadContextState { + // contexts that are currently active + std::set, std::owner_less>> active_contexts; + // contexts available for reuse + std::vector> retired_context_pool; + // weak references to thread local caches from which this QNNExecutionProvider instance's entry should be removed + // upon destruction + std::set, std::owner_less>> + caches_to_update_on_destruction; + // synchronizes access to PerThreadContextState members + OrtMutex mutex; + }; + + // The execution provider maintains the PerThreadContexts in this structure. + // Synchronization is required to update the contained structures. + // On the other hand, access to an individual PerThreadContext is assumed to be from a single thread at a time, + // so synchronization is not required for that. 
+ mutable PerThreadContextState context_state_; + + PerThreadContext& GetPerThreadContext() const; + void ReleasePerThreadContext() const; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index ee3578326ac6d..3fd5423681b81 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -353,7 +353,7 @@ Status ROCMExecutionProvider::Sync() const { return Status::OK(); } -Status ROCMExecutionProvider::OnRunStart() { +Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { @@ -363,7 +363,7 @@ Status ROCMExecutionProvider::OnRunStart() { return Status::OK(); } -Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { +Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { GetPerThreadContext().CaptureEnd(); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index 37d5f7b42210f..da671d9e863bb 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -28,9 +28,9 @@ class ROCMExecutionProvider : public IExecutionProvider { Status Sync() const override; - Status OnRunStart() override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; const void* GetExecutionHandle() const noexcept override { // The ROCM interface does not return anything interesting. 
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c0bf29e486c88..81346671f2aad 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1818,11 +1818,11 @@ std::unique_ptr TensorrtExecutionProvider::GetDataTransfer() cons return onnxruntime::CreateGPUDataTransfer(); } -Status TensorrtExecutionProvider::OnRunStart() { +Status TensorrtExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); } -Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { +Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index e86f997b6597a..26f6b2dcc3020 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -233,8 +233,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; - Status OnRunStart() override; - Status OnRunEnd(bool sync_stream) override; + Status OnRunStart(const onnxruntime::RunOptions& run_options) override; + Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; ProviderOptions GetProviderOptions() const override { return TensorrtExecutionProviderInfo::ToProviderOptions(info_); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b045f30a59797..efd7db4ea7629 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2289,8 +2289,8 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2326,7 +2326,7 @@ Status InferenceSession::PartialRun(onnxruntime::RunOptions& run_options, // info all execution providers InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { - auto status = xp->OnRunEnd(/*sync_stream*/ false); + auto status = xp->OnRunEnd(/*sync_stream*/ false, run_options); ORT_CHECK_AND_SET_RETVAL(status); } @@ -2448,8 +2448,8 @@ Status InferenceSession::Run(const RunOptions& run_options, // TODO: only call OnRunStart for all providers in-use for (auto& xp : execution_providers_) { // call OnRunStart and add to exec_providers_to_stop if successful - auto start_func = [&xp, &exec_providers_to_stop]() { - auto status = xp->OnRunStart(); + auto start_func = [&xp, &exec_providers_to_stop, &run_options]() { + auto status = xp->OnRunStart(run_options); if (status.IsOK()) exec_providers_to_stop.push_back(xp.get()); @@ -2490,7 +2490,7 @@ Status InferenceSession::Run(const RunOptions& run_options, // info all execution providers 
InferenceSession:Run ended for (auto* xp : exec_providers_to_stop) { bool synchronize_execution_providers = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigDisableSynchronizeExecutionProviders, "0") == "0"; - auto status = xp->OnRunEnd(synchronize_execution_providers); + auto status = xp->OnRunEnd(synchronize_execution_providers, run_options); ORT_CHECK_AND_SET_RETVAL(status); } diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index a70e439cdf755..5505d689381c9 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -22,6 +22,8 @@ TEST(TestDeferredRelease, WithArena) { CUDAExecutionProvider ep(info); AllocatorPtr gpu_alloctor = ep.CreatePreferredAllocators()[0]; + RunOptions run_opts; + run_opts.run_tag = "log1"; // Allocator for call cudaMallocHost and cudaFreeHost // For details, see CUDAPinnedAllocator in cuda_allocator.cc. AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; @@ -31,7 +33,7 @@ TEST(TestDeferredRelease, WithArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. auto pinned_buffer = IAllocator::MakeUniquePtr(cpu_pinned_alloc, n_bytes); @@ -44,7 +46,7 @@ TEST(TestDeferredRelease, WithArena) { cpu_pinned_alloc->GetStats(&stats); ASSERT_EQ(stats.num_allocs, n_allocs); ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } TEST(TestDeferredRelease, WithoutArena) { @@ -52,6 +54,9 @@ TEST(TestDeferredRelease, WithoutArena) { CUDAExecutionProviderInfo info; CUDAExecutionProvider ep(info); + RunOptions run_opts; + run_opts.run_tag = "log1"; + OrtDevice pinned_device{OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, DEFAULT_CPU_ALLOCATOR_DEVICE_ID}; // Create allocator without BFCArena AllocatorCreationInfo pinned_memory_info( @@ -70,7 +75,7 @@ TEST(TestDeferredRelease, WithoutArena) { // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; - ORT_THROW_IF_ERROR(ep.OnRunStart()); + ORT_THROW_IF_ERROR(ep.OnRunStart(run_opts)); for (size_t i = 0; i < n_allocs; ++i) { // Allocate 10MB CUDA pinned memory. 
auto pinned_buffer = IAllocator::MakeUniquePtr(cuda_pinned_alloc, n_bytes); @@ -79,7 +84,7 @@ TEST(TestDeferredRelease, WithoutArena) { } ORT_THROW_IF_ERROR(stream.CleanUpOnRunEnd()); - ORT_THROW_IF_ERROR(ep.OnRunEnd(true)); + ORT_THROW_IF_ERROR(ep.OnRunEnd(true, run_opts)); } } // namespace test diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 4e1aef2c40b2b..8f07c2ce77e77 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -7,6 +7,7 @@ #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU #include "core/session/inference_session.h" @@ -332,19 +333,23 @@ static void CreateModelInMemory(std::unique_ptr& result, static void RunSessionAndVerify(InferenceSession& session, const RunOptions& run_options, const NameMLValMap& feeds, const std::vector& output_names, const std::vector>& output_shapes, - const std::vector>& expected_values) { - std::vector fetches; - auto status = session.Run(run_options, feeds, output_names, &fetches); - ASSERT_TRUE(status.IsOK()); - - for (size_t i = 0; i < fetches.size(); i++) { - auto& tensor = fetches[i].Get(); - TensorShape expected_shape(output_shapes[i]); - ASSERT_EQ(expected_shape, tensor.Shape()); - - gsl::span actual = tensor.DataAsSpan(); - gsl::span expected(expected_values[i].data(), expected_values[i].size()); - ASSERT_EQ(expected, actual); + const std::vector>& expected_values, + int loop_count = 10) { + // Let it run for a while + for (int it = 0; it < loop_count; ++it) { + std::vector fetches; + auto status = session.Run(run_options, feeds, output_names, &fetches); + ASSERT_TRUE(status.IsOK()); + + for (size_t i = 0; i < fetches.size(); i++) { + auto& tensor = fetches[i].Get(); + TensorShape expected_shape(output_shapes[i]); + ASSERT_EQ(expected_shape, tensor.Shape()); + + gsl::span actual = tensor.DataAsSpan(); + gsl::span expected(expected_values[i].data(), expected_values[i].size()); + ASSERT_EQ(expected, actual); + } } } @@ -404,11 +409,11 @@ TEST_F(QnnCPUBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; - + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); } for (auto& th : threads) { @@ -484,11 +489,191 @@ TEST_F(QnnHTPBackendTests, MultithreadSessionRun) { std::vector threads; constexpr int num_threads = 5; + constexpr int loop_count = 10; for (int i = 0; i < num_threads; i++) { threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, model->builder.feeds_, model->builder.output_names_, - output_shapes, output_values)); + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with run option to set power config +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgSessionRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values 
= {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with EP option to set default power config +TEST_F(QnnHTPBackendTests, MultithreadDefaultHtpPowerCfgFromEpOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + for (int i = 0; i < num_threads; i++) { + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, 
model->builder.output_names_, + output_shapes, output_values, loop_count)); + } + + for (auto& th : threads) { + th.join(); + } +} + +// Tests running a single session in multiple threads on the HTP backend with +// EP option to set default power config + run option to set power config for each run +TEST_F(QnnHTPBackendTests, MultithreadHtpPowerCfgDefaultAndRunOption) { + std::unique_ptr model; + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector shape = {1, 3, 2}; + std::vector> output_shapes = {shape}; + std::vector> output_values = {{3.0f, 6.0f, 9.0f, 12.0f, 15.0f, 18.0f}}; + + CreateModelInMemory(model, + QDQBuildAdd3Tensors(TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data), + TestInputDef(shape, false, input_data)), + "add3.qdq"); + + SessionOptions session_opts; + session_opts.session_logid = "logger0"; + + InferenceSession session_obj{session_opts, GetEnvironment()}; + onnxruntime::ProviderOptions options; + +#if defined(_WIN32) + options["backend_path"] = "QnnHtp.dll"; +#else + options["backend_path"] = "libQnnHtp.so"; +#endif + options["htp_performance_mode"] = "burst"; + + auto qnn_ep = QnnExecutionProviderWithOptions(options, &session_opts); + EXPECT_TRUE(session_obj.RegisterExecutionProvider(std::move(qnn_ep)).IsOK()); + + auto status = session_obj.Load(model->model_data.data(), static_cast(model->model_data.size())); + ASSERT_TRUE(status.IsOK()); + status = session_obj.Initialize(); + ASSERT_TRUE(status.IsOK()); + + std::vector threads; + constexpr int num_threads = 5; + constexpr int loop_count = 10; + + std::vector perf_modes{ + "burst", "balanced", "default", "high_performance", "high_power_saver", + "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver"}; + + size_t post_i = perf_modes.size() - 1; + ASSERT_TRUE(post_i > num_threads); + for (int i = 0; i < num_threads; ++i, --post_i) { + RunOptions run_opts; + run_opts.run_tag = session_opts.session_logid; + auto rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfMode, perf_modes[i].c_str()); + ASSERT_TRUE(rt.IsOK()); + rt = run_opts.config_options.AddConfigEntry(kOrtRunOptionsConfigQnnPerfModePostRun, perf_modes[post_i].c_str()); + ASSERT_TRUE(rt.IsOK()); + + threads.push_back(std::thread(RunSessionAndVerify, std::ref(session_obj), run_opts, + model->builder.feeds_, model->builder.output_names_, + output_shapes, output_values, loop_count)); } for (auto& th : threads) { From 29b1106033e291947debb49c3fd03feb479c4b1b Mon Sep 17 00:00:00 2001 From: Segev Finer Date: Fri, 23 Feb 2024 04:53:50 +0200 Subject: [PATCH 132/207] [node] Switch to setImmediate to avoid starving the Node.js event loop (#19610) ### Description Switch to setImmediate to avoid starving the Node.js event loop There should really be a true async version though, running computationally intensive things on the event loop will stop everything else from happening while it is running, e.g. a web server from answering requests. This can be done by wrapping `RunAsync` behind a [`napi::Promise`](https://github.com/nodejs/node-addon-api/blob/main/doc/promises.md) to run on the onnxruntime thread pool or [`AsyncWorker`]( https://github.com/nodejs/node-addon-api/blob/main/doc/async_worker.md) for the Node.js/libuv thread pool. 
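For illustration, here is a minimal sketch of what the suggested `AsyncWorker` route could look like; `RunInference` and the string payload are placeholders standing in for the real native session call, not the actual onnxruntime-node binding code.

```cpp
// Sketch only: offload a blocking native call to the libuv thread pool with
// Napi::AsyncWorker instead of deferring it on the JS event loop.
#include <napi.h>
#include <string>
#include <utility>

// Hypothetical stand-in for the blocking InferenceSession::Run() binding.
std::string RunInference(const std::string& request) {
  return "result for " + request;
}

class RunWorker : public Napi::AsyncWorker {
 public:
  RunWorker(Napi::Function& callback, std::string request)
      : Napi::AsyncWorker(callback), request_(std::move(request)) {}

  // Runs on a worker thread; must not touch JS values here.
  void Execute() override { result_ = RunInference(request_); }

  // Back on the main thread: hand the result to the JS callback.
  void OnOK() override {
    Napi::HandleScope scope(Env());
    Callback().Call({Env().Null(), Napi::String::New(Env(), result_)});
  }

 private:
  std::string request_;
  std::string result_;
};

// JS-callable entry point: queues the worker and returns immediately.
Napi::Value RunAsync(const Napi::CallbackInfo& info) {
  std::string request = info[0].As<Napi::String>().Utf8Value();
  Napi::Function cb = info[1].As<Napi::Function>();
  auto* worker = new RunWorker(cb, std::move(request));
  worker->Queue();  // the worker deletes itself after OnOK/OnError
  return info.Env().Undefined();
}
```

With this shape, `Execute()` runs off the event loop, so a web server embedding the addon keeps serving requests while inference is in flight; the callback (or a promise resolved from it) fires once the result is ready.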
### Motivation and Context Without this, if you run inference in a tight loop without anything else async/deferred in between, `process.nextTick` will starve the event loop and not let anything else run; `setImmediate` at least lets the event loop spin between calls to `run`. See https://dev.to/ynmanware/setimmediate-settimeout-and-process-nexttick-3mfd Contributed on behalf of [Swimm](https://swimm.io/) --- js/node/lib/backend.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index e8eb0e9babf5a..927953b4f1dd6 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler { async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(this.#inferenceSession.run(feeds, fetches, options)); } catch (e) { @@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend { async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { return new Promise((resolve, reject) => { - process.nextTick(() => { + setImmediate(() => { try { resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {})); } catch (e) { From ae92d593c0e2b06decbea64797f9145bc10f34af Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 23 Feb 2024 11:05:16 +0800 Subject: [PATCH 133/207] ONNX Gelu Op in Opset 20 (#19560)
### ONNX Gelu Op in Opset 20
Refactor code to support MSDomain Gelu and the ONNX Gelu-opset20 op:
1. Move the CPU GELU implementation from `onnxruntime/contrib_ops/cpu/activations.h/cc` to `onnxruntime/core/providers/cpu/tensor/gelu.h/cc`, as the implementation for the approximate attribute value 'none' (a scalar sketch of both Gelu formulas follows this list).
2. Duplicate some logic from `onnxruntime/contrib_ops/cpu/bert/bias_gelu.cc` into `onnxruntime/core/providers/cpu/tensor/gelu.h/cc`, as the implementation for the approximate attribute value 'tanh'.
3. Register the ONNX domain Gelu CPU kernel for opset 20 in `onnxruntime/core/providers/cpu/cpu_execution_provider.cc`.
4. Move `onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h/cu` to `onnxruntime/core/providers/cuda/tensor/gelu_impl.h` and `onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu` respectively, as the implementation for the approximate attribute value 'tanh'.
5. Implement the logic for the approximate attribute value 'none' in `onnxruntime/core/providers/cuda/tensor/gelu_impl.cu`.
6. Register the ONNX domain Gelu CUDA kernel for opset 20 in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
7. ROCm EP related changes.
8. Enrich the tests for ONNX domain Gelu in `onnxruntime/test/providers/cpu/activation/activation_op_test.cc`.
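For reference, the two code paths selected by the approximate attribute reduce to the following scalar math. This is a standalone illustrative sketch of the formulas only; the actual kernels added in this commit split the tensor into 4096-element chunks, parallelize them over the thread pool, and call MlasComputeErf / MlasComputeTanh rather than per-element std:: functions.

```cpp
// Scalar reference of the two Gelu variants (illustration only, not ORT's kernels).
#include <cmath>
#include <cstdio>

// approximate == "none": exact Gelu, 0.5 * x * (1 + erf(x / sqrt(2))).
float GeluExact(float x) {
  constexpr float kSqrt1_2 = 0.70710678118654752f;  // 1 / sqrt(2)
  return 0.5f * x * (1.0f + std::erf(x * kSqrt1_2));
}

// approximate == "tanh": FastGelu-style approximation,
// 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
float GeluTanh(float x) {
  constexpr float kB = 0.7978845608028654f;  // sqrt(2 / pi)
  constexpr float kC = 0.044715f;
  return 0.5f * x * (1.0f + std::tanh(kB * (x + kC * x * x * x)));
}

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
    std::printf("x=% .2f  exact=% .6f  tanh=% .6f\n", x, GeluExact(x), GeluTanh(x));
  }
  return 0;
}
```

The tanh form trades a small amount of accuracy for cheaper math, which is why it is only selected when the model explicitly sets approximate='tanh'.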
--- cmake/onnxruntime_rocm_hipify.cmake | 4 - .../InferenceTest.netcore.cs | 2 +- docs/OperatorKernels.md | 2 + .../core/providers/cuda/cuda_resource.h | 2 +- onnxruntime/contrib_ops/cpu/activations.cc | 10 +- onnxruntime/contrib_ops/cpu/activations.h | 41 ------- .../cuda/activation/activations.cc | 1 - .../contrib_ops/cuda/activation/activations.h | 11 -- .../cuda/activation/activations_impl.cu | 14 --- .../cuda/activation/activations_impl.h | 2 - .../contrib_ops/cuda/bert/fast_gelu.cc | 20 +++- onnxruntime/contrib_ops/cuda/bert/fast_gelu.h | 2 +- .../contrib_ops/rocm/bert/fast_gelu.cc | 59 ---------- onnxruntime/contrib_ops/rocm/bert/fast_gelu.h | 24 ---- .../providers/cpu/cpu_execution_provider.cc | 2 + onnxruntime/core/providers/cpu/tensor/gelu.cc | 108 ++++++++++++++++++ onnxruntime/core/providers/cpu/tensor/gelu.h | 18 +++ .../providers/cuda/cuda_execution_provider.cc | 10 ++ .../core/providers/cuda/tensor/gelu.cc | 89 +++++++++++++++ onnxruntime/core/providers/cuda/tensor/gelu.h | 28 +++++ .../cuda/tensor/gelu_approximate_impl.cu} | 17 ++- .../core/providers/cuda/tensor/gelu_impl.cu | 48 ++++++++ .../providers/cuda/tensor/gelu_impl.h} | 7 +- .../test/contrib_ops/activation_op_test.cc | 13 ++- .../test/onnx/microbenchmark/activation.cc | 3 +- .../cpu/activation/activation_op_test.cc | 48 ++++++-- .../cpu/activation/activation_op_test.h | 7 +- 27 files changed, 395 insertions(+), 197 deletions(-) delete mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc delete mode 100644 onnxruntime/contrib_ops/rocm/bert/fast_gelu.h create mode 100644 onnxruntime/core/providers/cpu/tensor/gelu.cc create mode 100644 onnxruntime/core/providers/cpu/tensor/gelu.h create mode 100644 onnxruntime/core/providers/cuda/tensor/gelu.cc create mode 100644 onnxruntime/core/providers/cuda/tensor/gelu.h rename onnxruntime/{contrib_ops/cuda/bert/fast_gelu_impl.cu => core/providers/cuda/tensor/gelu_approximate_impl.cu} (88%) create mode 100644 onnxruntime/core/providers/cuda/tensor/gelu_impl.cu rename onnxruntime/{contrib_ops/cuda/bert/fast_gelu_impl.h => core/providers/cuda/tensor/gelu_impl.h} (80%) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 85a9bf50460d3..1bb70e9c2ed27 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -20,10 +20,6 @@ set(contrib_ops_excluded_files "bert/fastertransformer_decoder_attention/*" "bert/multihead_attention.cc" "bert/multihead_attention.h" - "bert/fast_gelu_impl.cu" - "bert/fast_gelu_impl.h" - "bert/fast_gelu.cc" - "bert/fast_gelu.h" "bert/relative_attn_bias.cc" "bert/relative_attn_bias.h" "bert/relative_attn_bias_impl.cu" diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs index 715aed7e1d64f..7f3d5d6624b07 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs @@ -145,7 +145,7 @@ private void TestCUDAProviderOptions() private void CanRunInferenceOnAModelWithTensorRT() { string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx"); - + int deviceId = 0; string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID"); if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 
8ff2135c6b1f6..46149c577a106 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -127,6 +127,7 @@ Do not modify directly.* |GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| |||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**indices** = tensor(int64)| +|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(float)| |Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[9, 10]|**T** = tensor(double), tensor(float)| @@ -606,6 +607,7 @@ Do not modify directly.* |GatherND|*in* data:**T**
*in* indices:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| |||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| |||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)
**indices** = tensor(int64)| +|Gelu|*in* X:**T**
*out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)| |Gemm|*in* A:**T**
*in* B:**T**
*in* C:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index 1fef077860be3..00e7dec5727d1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -19,4 +19,4 @@ enum CudaResource : int { enable_skip_layer_norm_strict_mode_t, prefer_nhwc_t, use_tf32_t, -}; \ No newline at end of file +}; diff --git a/onnxruntime/contrib_ops/cpu/activations.cc b/onnxruntime/contrib_ops/cpu/activations.cc index 556699192d2eb..3e0533dd8b9e5 100644 --- a/onnxruntime/contrib_ops/cpu/activations.cc +++ b/onnxruntime/contrib_ops/cpu/activations.cc @@ -2,7 +2,7 @@ // Licensed under the MIT License. #include "core/providers/cpu/activation/activations.h" -#include "activations.h" +#include "contrib_ops/cpu/activations.h" namespace onnxruntime { namespace contrib { @@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType()), ThresholdedRelu); -ONNX_OPERATOR_KERNEL_EX( - Gelu, - kMSDomain, - 1, - kCpuExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), - Gelu); - ONNX_OPERATOR_KERNEL_EX( QuickGelu, kMSDomain, diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index aed4c2229215d..7e64235d3fc3d 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -54,47 +54,6 @@ namespace contrib { DEFINE_ELE_KERNEL(ScaledTanh); DEFINE_ELE_KERNEL(ParametricSoftplus); -template -class Gelu : public OpKernel { - public: - Gelu(const OpKernelInfo& info) : OpKernel(info) { - } - - Status Compute(OpKernelContext* context) const override { - const Tensor* input = context->Input(0); - const T* input_data = input->Data(); - - Tensor* output = context->Output(0, input->Shape()); - T* output_data = output->MutableData(); - - concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); - int64_t elem_count = input->Shape().Size(); - constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. - int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; - concurrency::ThreadPool::TryBatchParallelFor( - tp, static_cast(task_count), - [&](ptrdiff_t task_idx) { - const auto start = task_idx * length_per_task; - const T* p_input = input_data + start; - T* p_output = output_data + start; - int64_t count = std::min(length_per_task, elem_count - start); - - for (int64_t i = 0; i < count; i++) { - T value = p_input[i]; - p_output[i] = value * static_cast(M_SQRT1_2); - } - - MlasComputeErf(p_output, p_output, narrow(count)); - - for (int64_t i = 0; i < count; i++) { - p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); - } - }, - 0); - return Status::OK(); - } -}; - // Implement a new one instead of inheriting from ElementWiseRangedTransform so that we can call // MlasComputeLogistic instead of using Eigen for better perf. 
template diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.cc b/onnxruntime/contrib_ops/cuda/activation/activations.cc index 1a86c5dbece5a..6303858b9bd48 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.cc +++ b/onnxruntime/contrib_ops/cuda/activation/activations.cc @@ -49,7 +49,6 @@ namespace cuda { UNARY_ACTIVATION_OP_HFD(Affine, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ParametricSoftplus, 1, kOnnxDomain); UNARY_ACTIVATION_OP_HFD(ScaledTanh, 1, kOnnxDomain); -UNARY_ACTIVATION_OP_HFD(Gelu, 1, kMSDomain); UNARY_ACTIVATION_OP_HFD(QuickGelu, 1, kMSDomain); REGISTER_ACTIVATION_KERNEL(ThresholdedRelu, 1, kOnnxDomain, MLFloat16) diff --git a/onnxruntime/contrib_ops/cuda/activation/activations.h b/onnxruntime/contrib_ops/cuda/activation/activations.h index ab339f276c2bd..fc9a71b0b7fa1 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations.h @@ -66,17 +66,6 @@ class ScaledTanh final : public UnaryElementwise { float beta_; }; -template -class Gelu final : public UnaryElementwise { - public: - Gelu(const OpKernelInfo& info) : UnaryElementwise(info) {} - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - MAKE_FUNC_CTX_NULL() -}; - template class QuickGelu final : public UnaryElementwise { public: diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu index 0c856815fd437..36f33fbb24c18 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.cu @@ -36,20 +36,6 @@ struct OP_ScaledTanh : public CtxScaledTanh { } }; -template -struct OP_Gelu : public CtxGelu { - __device__ __inline__ T operator()(const T& a) const { - return _Gelu(a); - } -}; - -template <> -struct OP_Gelu : public CtxGelu { - __device__ __inline__ half operator()(const half& a) const { - return static_cast(_Gelu(static_cast(a))); - } -}; - template struct OP_QuickGelu : public CtxQuickGelu { __device__ __inline__ T operator()(const T& a) const { diff --git a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h index 5d18283a395e3..782d4bf59a5ad 100644 --- a/onnxruntime/contrib_ops/cuda/activation/activations_impl.h +++ b/onnxruntime/contrib_ops/cuda/activation/activations_impl.h @@ -11,14 +11,12 @@ namespace cuda { typedef onnxruntime::cuda::CtxAlphaBeta CtxAffine; typedef onnxruntime::cuda::CtxAlphaBeta CtxParametricSoftplus; typedef onnxruntime::cuda::CtxAlphaBeta CtxScaledTanh; -typedef onnxruntime::cuda::CtxNull CtxGelu; typedef onnxruntime::cuda::CtxAlpha CtxQuickGelu; #define UNARY_CONTRIB_ACTIVATION_OPS() \ UNARY_ACTIVATION_OP_NAME(ScaledTanh) \ UNARY_ACTIVATION_OP_NAME(Affine) \ UNARY_ACTIVATION_OP_NAME(ParametricSoftplus) \ - UNARY_ACTIVATION_OP_NAME(Gelu) \ UNARY_ACTIVATION_OP_NAME(QuickGelu) #define UNARY_ACTIVATION_OP_NAME(name) UNARY_ACTIVATION_IMPL_DECLARATION(name); diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index 892f5c181a607..e8974a29476b6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -4,9 +4,14 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cudnn_common.h" #include "fast_gelu.h" -#include "fast_gelu_impl.h" +#include "core/providers/cuda/tensor/gelu_impl.h" #include 
"contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "transformer_common.h" +#ifdef USE_ROCM +#include "contrib_ops/rocm/bert/elementwise.h" +#endif +#ifdef USE_CUDA +#include "contrib_ops/cuda/bert/transformer_common.h" +#endif namespace onnxruntime { namespace contrib { @@ -31,8 +36,10 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { +#ifdef USE_CUDA const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); +#endif } template @@ -50,6 +57,14 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { int64_t bias_length = (nullptr == bias) ? 0 : bias->Shape().Size(); typedef typename ToCudaType::MappedType CudaT; +#ifdef USE_ROCM + return LaunchElementwiseKernel( + GetTuningContext(), context->GetComputeStream(), + reinterpret_cast(input->Data()), static_cast(input_length), + (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), + reinterpret_cast(output->MutableData())); +#endif +#ifdef USE_CUDA return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), @@ -58,6 +73,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, reinterpret_cast(output->MutableData()), use_half2_); +#endif } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index 3e642a70afef5..d563556593e6e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,7 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: - bool use_half2_; + bool use_half2_; // Only applicable to CUDA kernel (not ROCM). }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc deleted file mode 100644 index 9cb414e4e8980..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "contrib_ops/rocm/bert/fast_gelu.h" - -#include "core/providers/rocm/rocm_common.h" -#include "core/providers/rocm/miopen_common.h" -#include "contrib_ops/cpu/bert/bias_gelu_helper.h" -#include "contrib_ops/rocm/bert/elementwise.h" -#include "contrib_ops/rocm/bert/transformer_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - FastGelu, \ - kMSDomain, \ - 1, \ - T, \ - kRocmExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - FastGelu); - -REGISTER_KERNEL_TYPED(float) -REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) - -using namespace ONNX_NAMESPACE; - -template -Status FastGelu::ComputeInternal(OpKernelContext* context) const { - ORT_RETURN_IF_ERROR(bias_gelu_helper::CheckInputs(context)); - - const Tensor* input = context->Input(0); - const Tensor* bias = context->Input(1); - Tensor* output = context->Output(0, input->Shape()); - - int64_t input_length = input->Shape().Size(); - if (input_length == 0) { - return Status::OK(); - } - int64_t bias_length = (nullptr == bias) ? 
0 : bias->Shape().Size(); - typedef typename ToHipType::MappedType HipT; - - const HipT* input_buffer = reinterpret_cast(input->Data()); - const HipT* bias_buffer = (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr; - return LaunchElementwiseKernel( - GetTuningContext(), context->GetComputeStream(), - input_buffer, static_cast(input_length), - bias_buffer, static_cast(bias_length), - reinterpret_cast(output->MutableData())); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h b/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h deleted file mode 100644 index 42bfe5a0b0246..0000000000000 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu.h +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/common/common.h" -#include "core/providers/rocm/rocm_kernel.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -using namespace onnxruntime::rocm; - -template -class FastGelu final : public RocmKernel { - public: - FastGelu(const OpKernelInfo& op_kernel_info) : RocmKernel(op_kernel_info) {} - Status ComputeInternal(OpKernelContext* ctx) const override; -}; - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 813fdc54ecd0d..48e4617b33b4d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1035,6 +1035,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Gelu); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); @@ -2562,6 +2563,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.cc b/onnxruntime/core/providers/cpu/tensor/gelu.cc new file mode 100644 index 0000000000000..d55973eda180f --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gelu.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/framework/op_kernel.h" +#include "core/util/math_cpuonly.h" +#include "core/mlas/inc/mlas.h" + +#include "core/platform/threadpool.h" +#include +#include "core/providers/cpu/element_wise_ranged_transform.h" +#include "core/providers/cpu/tensor/gelu.h" + +using onnxruntime::narrow; +using namespace onnxruntime::common; + +namespace onnxruntime { + +// May revisit the implementations to support inplace computation, if needed. 
+ +ONNX_CPU_OPERATOR_KERNEL( + Gelu, + 20, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +ONNX_OPERATOR_KERNEL_EX( + Gelu, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + Gelu); +} +#endif + +template +Status Gelu::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const T* input_data = input->Data(); + + Tensor* output = context->Output(0, input->Shape()); + T* output_data = output->MutableData(); + + concurrency::ThreadPool* tp = context->GetOperatorThreadPool(); + int64_t elem_count = input->Shape().Size(); + constexpr int64_t length_per_task = 4096; // this number comes from FastGelu. + int64_t task_count = (elem_count + length_per_task - 1) / length_per_task; + + if (approximation_algorithm_ == "tanh") { + // FastGelu allows optional bias. Here we split input data into chunks. Each chunk + // has N elements (except the last chunk), and use thread pool to parallel chunks. + // N = 4096 is selected based on performance test results on input shape 1x128x768. + // FastGelu uses approximation for Gelu. The formula is 0.5 * (1 + Tanh(x * (C * x * x + B))) * x. + static constexpr float B = 0.7978845608028654f; // sqrt(2.0 / M_PI) + static constexpr float C = 0.035677408136300125f; // 0.044715 * sqrt(2.0 / M_PI) + + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const T* p_input = input_data + start; + T* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + + for (int64_t i = 0; i < count; i++) { + T value = p_input[i]; + p_output[i] = value * (static_cast(C) * value * value + static_cast(B)); + } + + MlasComputeTanh(p_output, p_output, narrow(count)); + + for (int64_t i = 0; i < count; i++) { + p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); + } + }, + 0); + return Status::OK(); + } else if (approximation_algorithm_ == "none") { + concurrency::ThreadPool::TryBatchParallelFor( + tp, static_cast(task_count), + [&](ptrdiff_t task_idx) { + const auto start = task_idx * length_per_task; + const T* p_input = input_data + start; + T* p_output = output_data + start; + int64_t count = std::min(length_per_task, elem_count - start); + + for (int64_t i = 0; i < count; i++) { + T value = p_input[i]; + p_output[i] = value * static_cast(M_SQRT1_2); + } + + MlasComputeErf(p_output, p_output, narrow(count)); + + for (int64_t i = 0; i < count; i++) { + p_output[i] = 0.5f * p_input[i] * (p_output[i] + 1.0f); + } + }, + 0); + return Status::OK(); + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/gelu.h b/onnxruntime/core/providers/cpu/tensor/gelu.h new file mode 100644 index 0000000000000..13238028d878a --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/gelu.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +namespace onnxruntime { + +template +class Gelu final : public OpKernel { + public: + explicit Gelu(const OpKernelInfo& info) : OpKernel(info) { + approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + } + Status Compute(OpKernelContext* ctx) const override; + + private: + std::string approximation_algorithm_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 0dd568c5ecc05..be2530aec49fa 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1329,6 +1329,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape); #endif +// Opset 20 +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -2222,6 +2227,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 20 + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.cc b/onnxruntime/core/providers/cuda/tensor/gelu.cc new file mode 100644 index 0000000000000..67b2fad373a7f --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/tensor/gelu.h" +#include "core/providers/cuda/tensor/gelu_impl.h" + +namespace onnxruntime { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gelu, \ + kOnnxDomain, \ + 20, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .MayInplace(0, 0), \ + Gelu); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(double) + +template +Status Gelu::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() < 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 0 is expected to have 1 or more dimensions, got ", input_dims.size()); + } + + Tensor* output = context->Output(0, input->Shape()); + + int64_t input_length = input->Shape().Size(); + if (input_length == 0) { + return Status::OK(); + } + + typedef typename ToCudaType::MappedType CudaT; + + if (approximation_algorithm_ == "tanh") { + return LaunchFastGeluKernel(GetDeviceProp(), + Stream(context), + static_cast(input_length), + 0 /* no bias */, + reinterpret_cast(input->Data()), + nullptr /* no bias */, + reinterpret_cast(output->MutableData()), + use_half2_); + } else if (approximation_algorithm_ == "none") { + return LaunchGeluKernel(Stream(context), + reinterpret_cast(input->Data()), + reinterpret_cast(output->MutableData()), + static_cast(input_length)); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported approximation_algorithm: ", approximation_algorithm_); +} + +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib::cuda { +#define REGISTER_CONTRIB_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Gelu, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .MayInplace(0, 0), \ + onnxruntime::cuda::Gelu); + +REGISTER_CONTRIB_KERNEL_TYPED(float) +REGISTER_CONTRIB_KERNEL_TYPED(MLFloat16) +REGISTER_CONTRIB_KERNEL_TYPED(double) + +#undef REGISTER_CONTRIB_KERNEL_TYPED +} // namespace contrib::cuda +#endif + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gelu.h b/onnxruntime/core/providers/cuda/tensor/gelu.h new file mode 100644 index 0000000000000..1c8189ab24121 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/math/unary_elementwise_ops.h" + +namespace onnxruntime { +namespace cuda { + +template +class Gelu final : public UnaryElementwise { + public: + Gelu(const OpKernelInfo& info) : UnaryElementwise(info) { + approximation_algorithm_ = info.GetAttrOrDefault("approximate", "none"); + } + + Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + const bool use_half2_{true}; + + std::string approximation_algorithm_; +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu similarity index 88% rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu rename to onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu index c9498eb1bcd7b..3292650584de8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu @@ -24,12 +24,9 @@ limitations under the License. #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/shared_inc/cuda_call.h" -#include "contrib_ops/cuda/bert/fast_gelu_impl.h" - -using namespace onnxruntime::cuda; +#include "core/providers/cuda/tensor/gelu_impl.h" namespace onnxruntime { -namespace contrib { namespace cuda { // constants for approximating the normal cdf @@ -75,6 +72,17 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int return CUDA_CALL(cudaGetLastError()); } +template <> +Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, + const double* input, const double* bias, double* output, bool /*use_half2*/) { + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + FastGeluKernel<<>>(A, B, C, input_length, bias_length, + input, bias, output); + + return CUDA_CALL(cudaGetLastError()); +} + template <> Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const half* input, const half* bias, half* output, bool use_half2) { @@ -114,5 +122,4 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } } // namespace cuda -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu new file mode 100644 index 0000000000000..3f96da38b37bb --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.cu @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include "core/providers/cuda/tensor/gelu_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cu_inc/unary_elementwise_impl.cuh" + +namespace onnxruntime { +namespace cuda { + +template +struct OP_Gelu { + __device__ __inline__ T operator()(const T& a) const { + return _Gelu(a); + } +}; + +template <> +struct OP_Gelu { + __device__ __inline__ half operator()(const half& a) const { + return static_cast(_Gelu(static_cast(a))); + } +}; + +template +Status LaunchGeluKernel( + cudaStream_t stream, + const T* input_data, + T* output_data, + size_t count) { + UnaryElementWiseImpl(stream, input_data, output_data, OP_Gelu(), count); + + return CUDA_CALL(cudaGetLastError()); +} + +#define SPECIALIZED_GELU_IMPL(T) \ + template Status LaunchGeluKernel(cudaStream_t stream, const T* input_data, T* output_data, \ + size_t count); + +SPECIALIZED_GELU_IMPL(float); +SPECIALIZED_GELU_IMPL(half); +SPECIALIZED_GELU_IMPL(double); + +#undef SPECIALIZED_GELU_IMPL + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h similarity index 80% rename from onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h rename to onnxruntime/core/providers/cuda/tensor/gelu_impl.h index ba78310f5dfc2..2ea0d3441fda3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/gelu_impl.h @@ -1,17 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #pragma once + #include "core/common/common.h" namespace onnxruntime { -namespace contrib { namespace cuda { +template +Status LaunchGeluKernel(cudaStream_t stream, const T* input, T* output, size_t count); + template Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, const T* input, const T* bias, T* output, bool use_half2); } // namespace cuda -} // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/activation_op_test.cc b/onnxruntime/test/contrib_ops/activation_op_test.cc index b1e54ec605a39..2a56991ec5af4 100644 --- a/onnxruntime/test/contrib_ops/activation_op_test.cc +++ b/onnxruntime/test/contrib_ops/activation_op_test.cc @@ -22,7 +22,8 @@ namespace test { TEST_F(ActivationOpTest, ThresholdedRelu_version_1_to_9) { float alpha = 0.1f; TestActivationOp( - "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, true, 1); + "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, {{"alpha", alpha}}, {}, + true, 1); } TEST_F(ActivationOpTest, ScaledTanh) { @@ -46,13 +47,13 @@ TEST_F(ActivationOpTest, ParametricSoftplus) { else return alpha * logf(expf(bx) + 1); }, - {{"alpha", alpha}, {"beta", beta}}, false); // Disable TensorRT due to result mismatch + {{"alpha", alpha}, {"beta", beta}}, {}, false); // Disable TensorRT due to result mismatch } TEST_F(ActivationOpTest, Gelu) { TestActivationOp( "Gelu", input_values, [](float x) { return x * 0.5f * (1.0f + std::erf(x * static_cast(M_SQRT1_2))); }, {}, - false, 1, kMSDomain); + {}, false, 1, kMSDomain); } #if defined(USE_DNNL) @@ -115,7 +116,7 @@ TEST_F(ActivationOpTest, QuickGelu) { y = tmp >= 0 ? y : 1 - y; return x * y; }, - {{"alpha", alpha}}, false, 1, kMSDomain); + {{"alpha", alpha}}, {}, false, 1, kMSDomain); } // Silu = x*sigmoid(x), i.e., alpha = 1.0f. 
@@ -129,7 +130,7 @@ TEST_F(ActivationOpTest, QuickGelu) { y = tmp >= 0 ? y : 1 - y; return x * y; }, - {{"alpha", alpha}}, false, 1, kMSDomain); + {{"alpha", alpha}}, {}, false, 1, kMSDomain); } // Negative alpha. @@ -143,7 +144,7 @@ TEST_F(ActivationOpTest, QuickGelu) { y = tmp >= 0 ? y : 1 - y; return x * y; }, - {{"alpha", alpha}}, false, 1, kMSDomain); + {{"alpha", alpha}}, {}, false, 1, kMSDomain); } } diff --git a/onnxruntime/test/onnx/microbenchmark/activation.cc b/onnxruntime/test/onnx/microbenchmark/activation.cc index cf859facf4765..69ee72996365e 100644 --- a/onnxruntime/test/onnx/microbenchmark/activation.cc +++ b/onnxruntime/test/onnx/microbenchmark/activation.cc @@ -11,6 +11,7 @@ #include "core/framework/node_index_info.h" #include "core/framework/execution_frame.h" #include "contrib_ops/cpu/activations.h" +#include "core/providers/cpu/tensor/gelu.h" #include "core/providers/cpu/activation/activations.h" #include #include @@ -182,7 +183,7 @@ static void RunSingleNode(const std::string& op_name, const std::string& domain, } static void BM_GeluCompute(benchmark::State& state) { - RunSingleNode>("Gelu", kMSDomain, {}, state); + RunSingleNode>("Gelu", kMSDomain, {}, state); } BENCHMARK(BM_GeluCompute) diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc index ddb0a6620619c..acd513172f95d 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.cc +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.cc @@ -116,13 +116,13 @@ TEST_F(ActivationOpTest, Relu) { "Relu", input_values_double, [](double x) { return std::max(x, 0.0); }, - {}, + {}, {}, /*is_tensorrt_supported=*/false); TestActivationOp( "Relu", input_values_int8, [](int8_t x) { return std::max(x, static_cast(0)); }, - {}, + {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 14); #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -133,7 +133,7 @@ TEST_F(ActivationOpTest, Relu) { if (x.ToFloat() > 0.0f) return x; return MLFloat16(); }, - {}, + {}, {}, /*is_tensorrt_supported=*/false, /*opset_version= */ 11); #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -402,7 +402,7 @@ TEST_F(ActivationOpTest, Celu) { // TODO: Investigate why gcc 4 fails to compile without the explicit cast [alpha](float x) { return std::max(0.0f, x) + std::min(0.0f, alpha * (static_cast(exp(x / alpha)) - 1)); }, // Disable on TensorRT as it seems like it doesn't yet support Celu - {{"alpha", alpha}}, false, 12); + {{"alpha", alpha}}, {}, false, 12); } TEST_F(ActivationOpTest, LeakyRelu) { @@ -410,7 +410,7 @@ TEST_F(ActivationOpTest, LeakyRelu) { TestActivationOp("LeakyRelu", input_values, [alpha](float x) { return (x >= 0) ? x : alpha * x; }, - {{"alpha", alpha}}); + {{"alpha", alpha}}, {}); } #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED @@ -442,7 +442,7 @@ TEST_F(ActivationOpTest, ThresholdedRelu) { "ThresholdedRelu", input_values, [alpha](float x) { return (x >= alpha) ? x : 0; }, - {{"alpha", alpha}}, true, 10); + {{"alpha", alpha}}, {}, true, 10); } TEST_F(ActivationOpTest, Selu) { @@ -452,7 +452,7 @@ TEST_F(ActivationOpTest, Selu) { TestActivationOp("Selu", input_values, [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + {{"alpha", alpha}, {"gamma", gamma}}, {}); } TEST_F(ActivationOpTest, Selu_Attributes) { @@ -462,7 +462,7 @@ TEST_F(ActivationOpTest, Selu_Attributes) { TestActivationOp("Selu", input_values, [](float x) { return x <= 0 ? 
gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + {{"alpha", alpha}, {"gamma", gamma}}, {}); } TEST_F(ActivationOpTest, Selu_GH10726) { @@ -472,7 +472,7 @@ TEST_F(ActivationOpTest, Selu_GH10726) { TestActivationOp("Selu", {{1.f, -1.f}}, [](float x) { return x <= 0 ? gamma * (alpha * exp(x) - alpha) : gamma * x; }, - {{"alpha", alpha}, {"gamma", gamma}}); + {{"alpha", alpha}, {"gamma", gamma}}, {}); } TEST_F(ActivationOpTest, PRelu) { @@ -625,7 +625,7 @@ TEST_F(ActivationOpNoInfTest, Softsign) { return result; }, - {}, false); // Disable TensorRT because result mismatches + {}, {}, false); // Disable TensorRT because result mismatches } #if defined(ENABLE_TRAINING_OPS) @@ -695,5 +695,33 @@ TEST(LeakyReluGradInferenceTest, Basic) { } #endif +// Remove DNNL from running this test because DNNL Gelu op seems not check domain for kernel implementation. +// It will run the DNNL Gelu op which only be part of standard of Gelu-20 op. +#if !defined(USE_DNNL) && !defined(USE_QNN) +TEST_F(ActivationOpTest, ONNX_Gelu) { + TestActivationOp( + "Gelu", + input_values, + [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, {}, + {{"approximate", "none"}}, true, 20); + + TestActivationOp( + "Gelu", + input_values, + [](float x) { return 0.5 * x * (1 + erf(x * M_SQRT1_2)); }, + {}, + {/*default value of approximate attribute is none */}, true, 20); + + TestActivationOp( + "Gelu", + input_values, + [](float x) { + return 0.5 * x * (1 + tanh(sqrt(2 / M_PI) * (x + 0.044715 * x * x * x))); + }, + {}, + {{"approximate", "tanh"}}, true, 20); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/activation/activation_op_test.h b/onnxruntime/test/providers/cpu/activation/activation_op_test.h index b5ec1402584fb..984b8f4437a3b 100644 --- a/onnxruntime/test/providers/cpu/activation/activation_op_test.h +++ b/onnxruntime/test/providers/cpu/activation/activation_op_test.h @@ -17,13 +17,16 @@ namespace test { template inline void TestActivationOp(const char* szOp, const std::vector>& input_vals_vec, std::function expected_func, - const std::unordered_map attribs = {}, + const std::unordered_map float_attribs = {}, + const std::unordered_map string_attribs = {}, bool is_tensorrt_supported = true, int opset_version = 7, const char* domain = kOnnxDomain) { for (const std::vector& input_vals : input_vals_vec) { OpTester test(szOp, opset_version, domain); - for (auto attr : attribs) test.AddAttribute(attr.first, attr.second); + for (auto attr : float_attribs) test.AddAttribute(attr.first, attr.second); + for (auto attr : string_attribs) test.AddAttribute(attr.first, attr.second); + std::vector dims{(int64_t)input_vals.size()}; std::vector expected_vals; From 5e432a3ae69dbbed603420493c52ba48b3726471 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 23 Feb 2024 04:47:15 +0100 Subject: [PATCH 134/207] Add support for NHWC GridSample in the CUDA EP and enable grid_sample_test for all EPs (#19562) I've added NHWC GridSample support to the CUDA EP to reduce the number of layout transforms. I've also enabled the full set of GridSampleTests for all EPs and added the GridSample OpSet 16 kernel to the registered kernels. ### Motivation and Context This is the first PR in a series of enhancements to the CUDA EP that improve NHWC support and avoid costly layout transforms between NHWC and NCHW nodes, which are layout sensitive. Also, testing was quite rudimentary for the CUDA EP, while it was thorough for the CPU path.
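As background for the layout-transform motivation above, here is a minimal sketch of how the same logical (n, c, y, x) element is addressed in the two layouts. It is illustrative only; the kernel in this patch computes the equivalent offsets in its PixelOffset helper.

```cpp
// Linear offsets of element (n, c, y, x) in a dense 4-D tensor for NCHW vs. NHWC.
#include <cstdint>
#include <cstdio>

int64_t OffsetNCHW(int64_t n, int64_t c, int64_t y, int64_t x,
                   int64_t C, int64_t H, int64_t W) {
  return n * C * H * W + c * H * W + y * W + x;  // channels vary slowest (after batch)
}

int64_t OffsetNHWC(int64_t n, int64_t c, int64_t y, int64_t x,
                   int64_t C, int64_t H, int64_t W) {
  return n * H * W * C + y * W * C + x * C + c;  // channels vary fastest
}

int main() {
  const int64_t C = 3, H = 4, W = 5;
  // Same logical element, two different linear addresses.
  std::printf("NCHW offset: %lld\n", static_cast<long long>(OffsetNCHW(1, 2, 3, 4, C, H, W)));
  std::printf("NHWC offset: %lld\n", static_cast<long long>(OffsetNHWC(1, 2, 3, 4, C, H, W)));
  return 0;
}
```

Registering an NHWC GridSample kernel lets neighboring nodes that already produce NHWC data feed it directly instead of forcing Transpose nodes around it.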
I've regenerated grid_sample_test.cc enabling tests for other platforms as well. Those tests resurfaced #10607 again which is fixed as well. --- docs/OperatorKernels.md | 1 + .../contrib_ops/cuda/cuda_contrib_kernels.cc | 7 + onnxruntime/contrib_ops/cuda/grid_sample.cc | 35 ++-- onnxruntime/contrib_ops/cuda/grid_sample.h | 2 +- .../contrib_ops/cuda/grid_sample_impl.cu | 101 ++++++---- .../contrib_ops/cuda/grid_sample_impl.h | 2 +- .../layout_transformation.cc | 2 + .../providers/cuda/cuda_execution_provider.cc | 2 + .../providers/cuda/shared_inc/cuda_utils.h | 26 +++ .../providers/cpu/tensor/grid_sample_test.cc | 172 ++++++++---------- .../cpu/tensor/grid_sample_test_gen.py | 2 +- onnxruntime/test/util/default_providers.cc | 16 ++ .../test/util/include/default_providers.h | 3 + 13 files changed, 223 insertions(+), 148 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 46149c577a106..b0ed68d595c42 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -619,6 +619,7 @@ Do not modify directly.* |||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)| |GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| +|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)| |HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| |Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index be8c0dc86c135..57e951d3a68ff 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -203,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze); #endif +#ifdef ENABLE_CUDA_NHWC_OPS +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample); +#endif + template <> KernelCreateInfo BuildKernelCreateInfo() { KernelCreateInfo info; @@ -408,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif +#ifdef ENABLE_CUDA_NHWC_OPS + BuildKernelCreateInfo, +#endif }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc index 4c2999c279e0a..2500de39d3536 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.cc +++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc @@ -9,22 +9,23 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL_TYPED(T) \ +#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ GridSample, \ - kMSDomain, \ - 1, \ + DOMAIN, \ + VERSION, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ - GridSample); + onnxruntime::contrib::cuda::GridSample); -REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain) +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain) -template -GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { +template +GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { std::string mode_str = info.GetAttrOrDefault("mode", "bilinear"); std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros"); align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0)); @@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) { } } -template 
-Status GridSample::ComputeInternal(OpKernelContext* context) const { +template +Status GridSample::ComputeInternal(OpKernelContext* context) const { const Tensor* X = context->Input(0); const auto& dims_input = X->Shape().GetDims(); const Tensor* Grid = context->Input(1); @@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]); ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2"); + using Ch = Channels; + TensorShapeVector dims_output(4); - dims_output[0] = dims_input[0]; - dims_output[1] = dims_input[1]; - dims_output[2] = dims_grid[1]; - dims_output[3] = dims_grid[2]; + dims_output[Ch::N] = dims_input[Ch::N]; + dims_output[Ch::C] = dims_input[Ch::C]; + dims_output[Ch::H] = dims_grid[1 /* Grid::H */]; + dims_output[Ch::W] = dims_grid[2 /* Grid::W */]; Tensor* Y = context->Output(0, dims_output); // Return early if the output tensor is going to be of size 0 if (Y->Shape().Size() == 0) { @@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { typedef typename ToCudaType::MappedType CudaT; CudaT* Y_data = reinterpret_cast(Y->MutableData()); - GridSampleImpl( + GridSampleImpl( Stream(context), reinterpret_cast(X->Data()), reinterpret_cast(Grid->Data()), @@ -89,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const { } } // namespace cuda } // namespace contrib + +namespace cuda { +REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain) +} // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h index 08ca58c7cc458..16581bfe77482 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample.h @@ -12,7 +12,7 @@ namespace cuda { using namespace onnxruntime::cuda; -template +template class GridSample final : public CudaKernel { public: explicit GridSample(const OpKernelInfo& info); diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu index 8a391eca7e86a..b23da635bc83d 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu @@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) { return static_cast(fx); } -template +template __device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x, - int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { + int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) { T pixel = 0.0f; + + auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t { + return Layout == LAYOUT_NCHW + ? 
(bIdx * C * H * W + cIdx * H * W + y * W + x) + : (bIdx * H * W * C + y * W * C + x * C + cIdx); + }; + if (padding_mode == 0) { // zeros if (x >= 0 && x < W && y >= 0 && y < H) { - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } - } else if (padding_mode == 1) { //border + } else if (padding_mode == 1) { // border x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x)); y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y)); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + pixel = input_data[PixelOffset(x, y)]; } else { // Reflection - x = (int64_t) GsReflect(x, border[0], border[2]); - y = (int64_t) GsReflect(y, border[1], border[3]); - pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x]; + x = (int64_t)GsReflect(x, border[0], border[2]); + y = (int64_t)GsReflect(y, border[1], border[3]); + pixel = input_data[PixelOffset(x, y)]; } return pixel; } -__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) -{ +__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) { float cubic_alpha = -0.75f; x = abs(x); coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha); @@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) { return pixel; } -template +template __global__ void _GridSampleKernel( const T* input_data, const T* grid_data, @@ -110,16 +116,32 @@ __global__ void _GridSampleKernel( { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out); // extract batch index, channel index, y index, x index for current thread - int BIdx = idx / (C * H_out * W_out ); - int tmpBCnt = BIdx * (C * H_out * W_out); + int BIdx, yIdx, xIdx, cIdx; + if constexpr (Layout == LAYOUT_NCHW) { + BIdx = idx / (C * H_out * W_out); + int tmpBCnt = BIdx * (C * H_out * W_out); + + cIdx = (idx - tmpBCnt) / (H_out * W_out); + int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); - int cIdx = (idx - tmpBCnt) / (H_out * W_out); - int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out); + yIdx = (idx - tmpCCnt) / W_out; + int tmpHCnt = tmpCCnt + yIdx * W_out; - int yIdx = (idx - tmpCCnt) / W_out; - int tmpHCnt = tmpCCnt + yIdx * W_out; + xIdx = (idx - tmpHCnt); + } else { + static_assert(Layout == LAYOUT_NHWC, "Unsupported layout"); - int xIdx = (idx - tmpHCnt); + BIdx = idx / (H_out * W_out * C); + int tmpBCnt = BIdx * (H_out * W_out * C); + + yIdx = (idx - tmpBCnt) / (W_out * C); + int tmpHCnt = tmpBCnt + yIdx * (W_out * C); + + xIdx = (idx - tmpHCnt) / C; + int tmpWCnt = tmpHCnt + xIdx * C; + + cIdx = (idx - tmpWCnt); + } int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx; T grid_X = grid_data[grid_idx * 2 + 0]; @@ -147,8 +169,9 @@ __global__ void _GridSampleKernel( if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max || grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound if (padding_mode == 1) { // border - grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); - grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); + // Clamping must not be done here, see #10607 + // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f)); + // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f)); } else if (padding_mode == 2) { // reflection grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max); grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max); @@ -175,10 +198,10 @@ __global__ void _GridSampleKernel( w_lb = w_b * w_l; w_rb = w_b * w_r; - T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, 
padding_mode, N, C, H_in, W_in, border); - T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); - T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); - T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); + T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border); + T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border); + T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border); + T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border); T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v; output_data[outIdx] = interpoV; return; @@ -186,7 +209,8 @@ __global__ void _GridSampleKernel( if (mode == 1) { // nearest int x_n = grid_x_imgSpace; int y_n = grid_y_imgSpace; - output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); + output_data[outIdx] = + PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border); return; } if (mode == 2) { // bicubic @@ -195,7 +219,8 @@ __global__ void _GridSampleKernel( T p[4][4] = {}; // [H][W] for (int64_t h = 0; h < 4; h++) { for (int64_t w = 0; w < 4; w++) { - p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); + p[h][w] = + PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border); } } T dx = grid_x_imgSpace - x0 - 1; @@ -204,7 +229,7 @@ __global__ void _GridSampleKernel( } } -template +template void GridSampleImpl( cudaStream_t stream, const T* input_data, @@ -216,17 +241,23 @@ void GridSampleImpl( const int64_t H_out, const int64_t W_out, T* output_data) { - int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock)); - _GridSampleKernel<<>>( - input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data); + using Ch = Channels; + + int blocksPerGrid = static_cast( + ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock)); + _GridSampleKernel<<>>( + input_data, grid_data, mode, padding_mode, align_corners, + dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W], + H_out, W_out, output_data); } -#define SPECIALIZED_IMPL(T) \ - template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ - const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ - const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); +#define SPECIALIZED_IMPL(T, IsNHWC) \ + template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \ + const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \ + const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data); -SPECIALIZED_IMPL(float) +SPECIALIZED_IMPL(float, false) // NCHW +SPECIALIZED_IMPL(float, true) // NHWC } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h index 6df86ce161908..62cd66a48fa84 100644 --- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h +++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h @@ -8,7 +8,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template void 
GridSampleImpl( cudaStream_t stream, const T* input_data, diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 4505d4afdf1e0..a8717b99a8750 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a } #if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS +// TODO(mtavenrath) generate list from registered kernels using nhwc domain const std::unordered_set& GetCUDALayoutSensitiveOps() { static std::unordered_set cuda_nhwc_ops = []() { return std::unordered_set{ @@ -41,6 +42,7 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() { "MaxPool", "GlobalAveragePool", "AveragePool", + "GridSample", }; }(); return cuda_nhwc_ops; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index be2530aec49fa..00783bcbc2665 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1256,6 +1256,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample); // Opset 17 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); @@ -2148,6 +2149,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 17 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h index fa987866c002f..54c024793ff0b 100644 --- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h +++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h @@ -168,5 +168,31 @@ struct NumericLimits { } }; +// TODO Where to put this? 
good places might be +// core/framework/tensor_shape.h +// core/util/matrix_layout.h + +constexpr bool LAYOUT_NCHW = false; +constexpr bool LAYOUT_NHWC = true; + +template +struct Channels; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t H = 1; + static constexpr size_t W = 2; + static constexpr size_t C = 3; +}; + +template <> +struct Channels { + static constexpr size_t N = 0; + static constexpr size_t C = 1; + static constexpr size_t H = 2; + static constexpr size_t W = 3; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc index 0f097622abff0..5c89d6ea7bd75 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc @@ -6,6 +6,33 @@ namespace onnxruntime { namespace test { + +std::vector> GetExecutionProviders(int opset_version) { + ORT_UNUSED_PARAMETER(opset_version); + + std::vector> execution_providers; + + execution_providers.emplace_back(DefaultCpuExecutionProvider()); +#ifdef USE_CUDA + if (opset_version < 20) { + execution_providers.emplace_back(DefaultCudaExecutionProvider()); +#ifdef ENABLE_CUDA_NHWC_OPS + execution_providers.push_back(DefaultCudaNHWCExecutionProvider()); +#endif + } + +#endif + return execution_providers; +} + +template +void RunTests(T& test, std::vector>&& execution_providers) { + for (size_t idx = 0; idx < execution_providers.size(); ++idx) { + test.ConfigEp(std::move(execution_providers[idx])).RunWithConfig(); + } + execution_providers.clear(); +} + // DO NOT edit following tests. They are generated by: // onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { @@ -25,8 +52,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { @@ -46,8 +72,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { @@ -67,8 +92,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { @@ -88,8 +112,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) 
{ @@ -109,8 +132,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) { @@ -130,8 +152,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { @@ -151,8 +172,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { @@ -172,8 +192,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { @@ -193,8 +212,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { @@ -214,8 +232,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { @@ -235,8 +252,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) { @@ -256,8 +272,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { @@ -277,8 +292,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) { 
test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { @@ -298,8 +312,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { @@ -319,8 +332,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { @@ -340,8 +352,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { @@ -361,8 +372,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) { @@ -382,8 +392,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(16)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { @@ -403,8 +412,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { @@ -424,8 +432,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { @@ -445,8 +452,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, 
Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { @@ -466,8 +472,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { @@ -487,8 +492,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { @@ -508,8 +512,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { @@ -529,8 +532,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { @@ -550,8 +552,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { @@ -571,8 +572,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { @@ -592,8 +592,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) { @@ -613,8 +612,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, 
GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) { @@ -634,8 +632,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { @@ -655,8 +652,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { @@ -676,8 +672,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { @@ -697,8 +692,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { @@ -718,8 +712,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { @@ -739,8 +732,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { @@ -760,8 +752,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { @@ -781,8 +772,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { @@ -802,8 +792,7 
@@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { @@ -823,8 +812,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { @@ -844,8 +832,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) { @@ -865,8 +852,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) { @@ -886,8 +872,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { @@ -907,8 +892,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { @@ -928,8 +912,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { @@ -949,8 +932,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { @@ -970,8 +952,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) { 
test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { @@ -991,8 +972,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) { test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) { @@ -1012,8 +992,8 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) test.AddAttribute("padding_mode", padding_mode); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", Y_shape, Y_data); - test.ConfigEp(DefaultCpuExecutionProvider()) - .RunWithConfig(); + RunTests(test, GetExecutionProviders(20)); } + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py index e4d58e79243ef..c60e55617774f 100644 --- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py +++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py @@ -76,6 +76,6 @@ print('test.AddAttribute("padding_mode", padding_mode);') print('test.AddAttribute("align_corners", align_corners);') print('test.AddOutput("Y", Y_shape, Y_data);') - print("test.Run();") + print(f"RunTests(test, GetExecutionProviders({opset_version}));") print("}") print("\n") diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 40b40136af1af..b404c12db3582 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -8,6 +8,9 @@ #ifdef USE_COREML #include "core/providers/coreml/coreml_provider_factory.h" #endif +#if defined(ENABLE_CUDA_NHWC_OPS) +#include +#endif #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/session_options.h" @@ -118,6 +121,19 @@ std::unique_ptr DefaultCudaExecutionProvider() { return nullptr; } +#ifdef ENABLE_CUDA_NHWC_OPS +std::unique_ptr DefaultCudaNHWCExecutionProvider() { +#if defined(USE_CUDA) + OrtCUDAProviderOptionsV2 provider_options{}; + provider_options.do_copy_in_default_stream = true; + provider_options.prefer_nhwc = true; + if (auto factory = CudaProviderFactoryCreator::Create(&provider_options)) + return factory->CreateProvider(); +#endif + return nullptr; +} +#endif + std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options) { #ifdef USE_CUDA if (auto factory = CudaProviderFactoryCreator::Create(provider_options)) diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 9f78e0a0d4eb2..738fc66d775c6 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -35,6 +35,9 @@ namespace test { // unique_ptr providers with default values for session registration std::unique_ptr DefaultCpuExecutionProvider(bool enable_arena = true); std::unique_ptr DefaultCudaExecutionProvider(); +#ifdef ENABLE_CUDA_NHWC_OPS +std::unique_ptr DefaultCudaNHWCExecutionProvider(); +#endif std::unique_ptr 
CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options); std::unique_ptr DefaultDnnlExecutionProvider(); std::unique_ptr DnnlExecutionProviderWithOptions(const OrtDnnlProviderOptions* provider_options); From ae3d73c9818c34af42c785ff2bd9558007ba315f Mon Sep 17 00:00:00 2001 From: satyajandhyala Date: Fri, 23 Feb 2024 00:21:15 -0800 Subject: [PATCH 135/207] [JS/WebGPU] Fix Split and Where to handle corner cases. (#19613) ### Description 1. Fix Where operator to handle Boolean input less than 4 bytes. 2. Fix JSEP test harness to use tensor names consistently. ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/where.ts | 3 ++- js/web/test/data/ops/where.jsonc | 34 ++++++++++++++++++++++++ js/web/test/test-runner.ts | 4 +-- 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts index cfee07a9239d7..a6375847fc42f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -27,7 +27,7 @@ const createWhereOpProgramShader = const expressionA = `a_data[index_a${x}][component_a${x}]`; const expressionB = `b_data[index_b${x}][component_b${x}]`; // eslint-disable-next-line no-bitwise - const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`; return ` let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)}; @@ -38,6 +38,7 @@ const createWhereOpProgramShader = let index_c${x} = offset_c${x} / 4u; let component_a${x} = offset_a${x} % 4u; let component_b${x} = offset_b${x} % 4u; + let component_c${x} = offset_c${x} % 4u; ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); `; }; diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc index 047fd6fd7511b..990120dd3708e 100644 --- a/js/web/test/data/ops/where.jsonc +++ b/js/web/test/data/ops/where.jsonc @@ -168,5 +168,39 @@ ] } ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1", + "inputs": [ + { + "data": [true, false], + "dims": [1, 1, 2, 1], + "type": "bool" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 4], + "type": "float32" + }, + { + "data": [5, 6, 7, 8, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 2, 3, 4, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index ecc7d4b4a09a5..a4adf5c4ce144 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -627,8 +627,8 @@ export async function runModelTestSet( try { const feeds: Record = {}; const outputsMetaInfo: Record = {}; - testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor); + testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor); const [start, end, outputs] = await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { From f4306004321efe9a0e65a19a707bf2266ffd7b16 Mon Sep 17 
00:00:00 2001 From: cao lei Date: Fri, 23 Feb 2024 06:02:05 -0800 Subject: [PATCH 136/207] Enable streams for DML EP. This change is to revert PR 19481 since the bug 19480 is fixed by PR 19515 (#19609) ### Description Enable streams for DML EP. This change is to revert PR 19481 since the bug 19480 is fixed by PR 19515 ### Motivation and Context Enable streams for DML EP. This change is to revert PR 19481 since the bug 19480 is fixed by PR 19515 --- cmake/adjust_global_compile_flags.cmake | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index a56864ebf4644..8161ea574b8cc 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -92,13 +92,8 @@ if (onnxruntime_MINIMAL_BUILD) endif() endif() -# Enable stream for all the non-minimal build, except for DML. There's currently a bug -# in the allocation planner when reusing buffers and more than one streams are used that -# make it possible (although rarely) to reach a reference count of 0 for a buffer that is -# still being used. Since DML doesn't benefit from multiple streams, disabling it is the -# safest option for now. -# https://github.com/microsoft/onnxruntime/issues/19480 -if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML) +# Enable stream for all the non-minimal build +if (NOT onnxruntime_MINIMAL_BUILD) add_compile_definitions(ORT_ENABLE_STREAM) endif()
From efbe2b84556c195e7d7f3353321eb3f410a1e645 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Fri, 23 Feb 2024 17:45:17 +0100 Subject: [PATCH 137/207] Fix cuDNN v9 build by replacing removed cuDNN v6 RNN API usage by cuDNN v8 RNN API and reenable RNN tests for CUDA EP (#19419) Replace the deprecated cuDNN RNN API with the cuDNN v8 RNN API and re-enable the RNN tests for the CUDA EP. ### Motivation and Context The deprecated cuDNN RNN API may be removed soon, and all RNN tests are currently disabled for the CUDA EP RNN implementation because they fail. With this change the deprecated API is no longer used and the new implementation passes the tests.
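For readers who have not used the v8 interface before, the sketch below outlines the descriptor-based inference sequence this patch migrates to: a single cudnnSetRNNDescriptor_v8 call configures the cell, RNN data descriptors carry the per-batch sequence lengths, and one cudnnRNNForward call replaces cudnnRNNForwardInference / cudnnRNNForwardInferenceEx. This is a simplified illustration only, not the code in cudnn_rnn_base.cc: it assumes a unidirectional float LSTM, a pre-configured dropout descriptor, and a pre-packed weight blob; the function and variable names are invented for the example, and error checking and the ORT wrappers (CudnnRNN, CudnnDataTensor, GetScratchBuffer) are omitted.

// Sketch (not the actual ORT code): minimal cuDNN v8 RNN inference call sequence for a
// unidirectional, single-precision LSTM. Assumes handle, dropout_desc, the device buffers
// x/y/hx/hy/cx/cy, and the packed weight blob are already set up by the caller.
#include <cudnn.h>
#include <cuda_runtime.h>

cudnnStatus_t LstmForwardInferenceSketch(
    cudnnHandle_t handle, cudnnDropoutDescriptor_t dropout_desc,   // assumed already configured
    int input_size, int hidden_size, int num_layers, int seq_length, int batch_size,
    const int* seq_len_host, const int* seq_len_device,             // per-batch sequence lengths
    const void* x, void* y, const void* hx, void* hy, const void* cx, void* cy,
    size_t weight_space_bytes, const void* weight_space) {          // packed W/R/B blob
  // One call configures the whole RNN. The auxFlags value CUDNN_RNN_PADDED_IO_ENABLED replaces
  // the separate cudnnSetRNNPaddingMode() call, and the math type replaces cudnnSetRNNMatrixMathType().
  cudnnRNNDescriptor_t rnn_desc;
  cudnnCreateRNNDescriptor(&rnn_desc);
  cudnnSetRNNDescriptor_v8(rnn_desc, CUDNN_RNN_ALGO_STANDARD, CUDNN_LSTM, CUDNN_RNN_DOUBLE_BIAS,
                           CUDNN_UNIDIRECTIONAL, CUDNN_LINEAR_INPUT, CUDNN_DATA_FLOAT,
                           CUDNN_DATA_FLOAT, CUDNN_DEFAULT_MATH, input_size, hidden_size,
                           /*projSize=*/hidden_size, num_layers, dropout_desc,
                           CUDNN_RNN_PADDED_IO_ENABLED);

  // Sequence-major, padded data descriptors carry the per-batch sequence lengths directly.
  float fill = 0.0f;
  cudnnRNNDataDescriptor_t x_desc, y_desc;
  cudnnCreateRNNDataDescriptor(&x_desc);
  cudnnCreateRNNDataDescriptor(&y_desc);
  cudnnSetRNNDataDescriptor(x_desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
                            seq_length, batch_size, input_size, seq_len_host, &fill);
  cudnnSetRNNDataDescriptor(y_desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED,
                            seq_length, batch_size, hidden_size, seq_len_host, &fill);

  // Hidden/cell state descriptors: [num_layers * num_directions, batch, hidden], packed strides.
  cudnnTensorDescriptor_t h_desc, c_desc;
  cudnnCreateTensorDescriptor(&h_desc);
  cudnnCreateTensorDescriptor(&c_desc);
  const int dims[3] = {num_layers, batch_size, hidden_size};
  const int strides[3] = {batch_size * hidden_size, hidden_size, 1};
  cudnnSetTensorNdDescriptor(h_desc, CUDNN_DATA_FLOAT, 3, dims, strides);
  cudnnSetTensorNdDescriptor(c_desc, CUDNN_DATA_FLOAT, 3, dims, strides);

  // Query scratch sizes for inference, then run the single fused forward call that replaces
  // cudnnRNNForwardInference / cudnnRNNForwardInferenceEx.
  size_t workspace_bytes = 0, reserve_bytes = 0;
  cudnnGetRNNTempSpaceSizes(handle, rnn_desc, CUDNN_FWD_MODE_INFERENCE, x_desc,
                            &workspace_bytes, &reserve_bytes);
  void* workspace = nullptr;
  void* reserve = nullptr;
  cudaMalloc(&workspace, workspace_bytes);
  cudaMalloc(&reserve, reserve_bytes);

  cudnnStatus_t status = cudnnRNNForward(handle, rnn_desc, CUDNN_FWD_MODE_INFERENCE,
                                         seq_len_device, x_desc, x, y_desc, y,
                                         h_desc, hx, hy, c_desc, cx, cy,
                                         weight_space_bytes, weight_space,
                                         workspace_bytes, workspace, reserve_bytes, reserve);

  cudaFree(workspace);
  cudaFree(reserve);
  cudnnDestroyRNNDataDescriptor(x_desc);
  cudnnDestroyRNNDataDescriptor(y_desc);
  cudnnDestroyTensorDescriptor(h_desc);
  cudnnDestroyTensorDescriptor(c_desc);
  cudnnDestroyRNNDescriptor(rnn_desc);
  return status;
}

In inference mode the reported reserve-space size is typically zero; it is still queried and passed through so the same call pattern works for both inference and training modes, which mirrors what the patch below does.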
--- .../core/providers/cuda/cudnn_common.h | 4 +- .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 350 +++++++++--------- .../core/providers/cuda/rnn/cudnn_rnn_base.h | 55 +-- onnxruntime/core/providers/cuda/rnn/rnn.cc | 3 +- onnxruntime/core/providers/cuda/rnn/rnn.h | 1 + .../core/providers/cuda/rnn/rnn_impl.cu | 91 +---- .../core/providers/cuda/rnn/rnn_impl.h | 14 +- .../test/providers/cpu/rnn/rnn_op_test.cc | 24 +- 8 files changed, 240 insertions(+), 302 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cudnn_common.h b/onnxruntime/core/providers/cuda/cudnn_common.h index fdd14dedad47e..2cbeb13696270 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.h +++ b/onnxruntime/core/providers/cuda/cudnn_common.h @@ -24,12 +24,12 @@ class CudnnTensor final { operator cudnnTensorDescriptor_t() const { return tensor_; } + Status CreateTensorIfNeeded(); + template static cudnnDataType_t GetDataType(); private: - Status CreateTensorIfNeeded(); - cudnnTensorDescriptor_t tensor_; }; diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index 99c1f48e21c74..b61b104790fe5 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -9,40 +9,49 @@ namespace onnxruntime { namespace cuda { template -void CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* reorganized_w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const { +Status CudnnRnnBase::SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t reorganized_w_data_size, + const void* reorganized_w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const { int numDims; - std::vector matDims(3); + std::array matDims; + std::array strideA; cudnnDataType_t dt; - cudnnTensorFormat_t tf; T* mem_offset; - if (is_matrix) { - cudnnGetRNNLinLayerMatrixParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } else { - cudnnGetRNNLinLayerBiasParams(handle, rnn_desc, pseudo_layer, x_desc, w_desc, reorganized_w_data, lin_layer_id, filter_desc, (void**)&mem_offset); - } + CudnnTensor tensor_desc_matrix, tensor_desc_bias; + ORT_RETURN_IF_ERROR(tensor_desc_bias.CreateTensorIfNeeded()); + ORT_RETURN_IF_ERROR(tensor_desc_matrix.CreateTensorIfNeeded()); - cudnnGetFilterNdDescriptor(filter_desc, 3, &dt, &tf, &numDims, matDims.data()); + T *mem_offset_matrix, *mem_offset_bias; + CUDNN_RETURN_IF_ERROR(cudnnGetRNNWeightParams( + handle, rnn_desc, pseudo_layer, reorganized_w_data_size, reorganized_w_data, + lin_layer_id, tensor_desc_matrix, (void**)&mem_offset_matrix, tensor_desc_bias, (void**)&mem_offset_bias)); + CUDNN_RETURN_IF_ERROR(cudnnGetTensorNdDescriptor( + is_matrix ? tensor_desc_matrix : tensor_desc_bias, 3, &dt, &numDims, matDims.data(), strideA.data())); + + mem_offset = is_matrix ? 
mem_offset_matrix : mem_offset_bias; int count = matDims[0] * matDims[1] * matDims[2]; + + if (strideA[0] != count) { + return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::INVALID_ARGUMENT, "Stride is not packed"); + } CUDA_CALL_THROW(cudaMemcpyAsync(mem_offset, pos + offset, count * sizeof(T), cudaMemcpyDeviceToDevice, cuda_stream)); + offset += count; + + return Status::OK(); } template Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t reorganized_w_data_size, void* reorganized_w_data, const T* W_data, const T* R_data, @@ -51,18 +60,22 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, int w_offset = 0; int r_offset = 0; int bias_offset = 0; - CudnnFilterDescriptor filter_desc; for (int layer = 0; layer < RNN_NUM_LAYERS * num_directions_; ++layer) { for (size_t idx = 0; idx < W_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias( + cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], W_data, w_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + W_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } for (size_t idx = 0; idx < R_lin_layer_id_.size(); ++idx) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], R_data, r_offset, true, cuda_stream)); if (B_data != nullptr) { - SetWeightBias(cudnn_handle, rnn_desc, layer, x_desc, w_desc, filter_desc, reorganized_w_data, R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream); + ORT_RETURN_IF_ERROR(SetWeightBias(cudnn_handle, rnn_desc, layer, reorganized_w_data_size, reorganized_w_data, + R_lin_layer_id_[idx], B_data, bias_offset, false, cuda_stream)); } } } @@ -72,6 +85,7 @@ Status CudnnRnnBase::SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, template Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& reorganized_w_data_size_in_bytes, IAllocatorUniquePtr& reorganized_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const { @@ -91,19 +105,16 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons TensorShapeVector dims_w({w_size, 1, 1}); ORT_RETURN_IF_ERROR(target_w_desc.Set(dims_w, CudnnTensor::GetDataType())); - TensorShapeVector fake_dims_x({1, input_size, 1}); - CudnnTensor fake_x_desc; - ORT_RETURN_IF_ERROR(fake_x_desc.Set(fake_dims_x, CudnnTensor::GetDataType())); - // Prepare the weight data - reorganized_w_data = GetScratchBuffer(w_size * sizeof(T), ort_stream); + reorganized_w_data_size_in_bytes = w_size * sizeof(T); + reorganized_w_data = GetScratchBuffer(reorganized_w_data_size_in_bytes, ort_stream); // In many cases, this allocation is bigger than needed, leaving part of - // the buffer 
unintialized. non-zero garbage data leads to wrong result + // the buffer uninitialized. non-zero garbage data leads to wrong result // in call to cudnnRNNForwardInference() // TODO! refine allocation size for each case. cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; - cudaMemsetAsync(reorganized_w_data.get(), 0, w_size * sizeof(T), cuda_stream); + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(reorganized_w_data.get(), 0, reorganized_w_data_size_in_bytes, cuda_stream)); const T* W_data = W->Data(); const T* R_data = R->Data(); @@ -111,8 +122,9 @@ Status CudnnRnnBase::ReorganizeWeights(const Tensor* W, const Tensor* R, cons auto* ort_cuda_stream = dynamic_cast(ort_stream); cudnnHandle_t cudnn_handle = ort_cuda_stream ? ort_cuda_stream->cudnn_handle_ : DefaultCudnnHandle(); - ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, fake_x_desc, target_w_desc, - reorganized_w_data.get(), W_data, R_data, B_data, cuda_stream)); + ORT_RETURN_IF_ERROR(SetCudnnRnnWeightBias(cudnn_handle, rnn_desc, + reorganized_w_data_size_in_bytes, reorganized_w_data.get(), + W_data, R_data, B_data, cuda_stream)); return Status::OK(); } @@ -128,22 +140,31 @@ Status CudnnRnnBase::CacheCudnnRnnWeights(const OpKernelInfo& info) { bool get_R = info.TryGetConstantInput(RNN_Input_Index::R, &R); bool get_B = info.TryGetConstantInput(RNN_Input_Index::B, &B); + bool has_bias = B != nullptr; + if (get_W && get_R) { CudnnRNN tmp_rnn_desc; - ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(DefaultCudnnHandle(), + auto proj_size = hidden_size_; + ORT_RETURN_IF_ERROR(tmp_rnn_desc.Set(W->Shape()[2], // input_size hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); if (get_B) { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, B, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } else { - ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, w_data_cache_, w_desc_cache_, tmp_rnn_desc, nullptr)); + ORT_RETURN_IF_ERROR(ReorganizeWeights(W, R, nullptr, + w_data_cache_size_in_bytes_, w_data_cache_, w_desc_cache_, + tmp_rnn_desc, nullptr)); } cudaStreamSynchronize(nullptr); + weight_cached_ = true; } @@ -158,17 +179,72 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { ORT_ENFORCE(nullptr != X); // optional inputs - const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); // [batch_size] - const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); // initial hidden. [num_directions_, batch_size, hidden_size_] + // [batch_size] + const Tensor* sequence_lens = ctx->Input(RNN_Input_Index::sequence_lens); + // initial hidden. [num_directions_, batch_size, hidden_size_] + const Tensor* initial_h = ctx->Input(RNN_Input_Index::initial_h); const Tensor* initial_c(nullptr); if (rnn_mode_ == CUDNN_LSTM) { - initial_c = ctx->Input(RNN_Input_Index::initial_c); // initial cell. [num_directions_, batch_size, hidden_size_] + // initial cell. [num_directions_, batch_size, hidden_size_] + initial_c = ctx->Input(RNN_Input_Index::initial_c); } + size_t proj_size = hidden_size_; int64_t seq_length = X->Shape()[0]; int64_t batch_size = X->Shape()[1]; int64_t input_size = X->Shape()[2]; + // we thread a single input as sequence_lens of length 1, require to expand to [batch_size]? 
+ std::vector sequence_lengths_temp; + if (!sequence_lens) { + sequence_lengths_temp.resize(batch_size, gsl::narrow_cast(seq_length)); + } + + const int32_t* sequence_lens_data = (sequence_lens == nullptr) + ? sequence_lengths_temp.data() + : sequence_lens->Data(); + + // cuDNN doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 + // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence + int64_t zero_seq_count = 0; + std::vector zero_seq_index_cache(batch_size, 0); + + CudaAsyncBuffer sequence_lens_buffer(this, batch_size); + int32_t* seq_len_array = sequence_lens_buffer.CpuPtr(); + + // 0-len sequences are not supported by cuDNN. + // Replace them by sequences of len 1 and mask them out with SetZeroSequences + for (int i = 0; i < batch_size; ++i) { + if (0 == sequence_lens_data[i]) { + seq_len_array[i] = 1; + zero_seq_index_cache[zero_seq_count] = i; + ++zero_seq_count; + } else { + seq_len_array[i] = sequence_lens_data[i]; + } + } + + // Calculate the zero position cache for reverse direction if it's bidirectional + // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since + // we hacked the 0 sequence to 1 + if (zero_seq_count && num_directions_ > 1) { + zero_seq_index_cache.resize(zero_seq_count * num_directions_); + for (int64_t i = 0; i < zero_seq_count; ++i) { + zero_seq_index_cache[static_cast(zero_seq_count) + i] = + static_cast(batch_size + zero_seq_index_cache[i]); + } + zero_seq_count *= num_directions_; + } + + // Prior to cuDNN 8.9.1 the sequence lens buffer must be passed to cudnnRNNForward and thus is must + // be copied to the GPU always. + ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // Starting with cuDNN 8.9.1 the sequence lens buffer is ignored by cudnnRNNForward and thus it must + // be copied to the GPU only for the ReverseBySequence kernels. 
+ // if (reverse_) { + // ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); + // } + // optional outputs TensorShapeVector dims_Y({seq_length, num_directions_, batch_size, hidden_size_}); TensorShapeVector dims_hxy({RNN_NUM_LAYERS * num_directions_, batch_size, hidden_size_}); @@ -177,25 +253,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { Tensor* Y_h = ctx->Output(Output_Index::Y_h, dims_hxy); Tensor* Y_c = ctx->Output(Output_Index::Y_c, dims_yc); - std::vector dims_x({batch_size, input_size, 1}); - std::vector dims_y({batch_size, hidden_size_ * num_directions_, 1}); - - CudnnTensor x_desc_temp; - ORT_RETURN_IF_ERROR(x_desc_temp.Set(dims_x, CudnnTensor::GetDataType())); - CudnnTensor y_desc_temp; - ORT_RETURN_IF_ERROR(y_desc_temp.Set(dims_y, CudnnTensor::GetDataType())); - std::vector x_desc(seq_length, x_desc_temp); - std::vector y_desc(seq_length, y_desc_temp); - - CudnnTensor hx_desc; - CudnnTensor cx_desc; - CudnnTensor y_h_desc; - CudnnTensor y_c_desc; - ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_h_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - ORT_RETURN_IF_ERROR(y_c_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - IAllocatorUniquePtr x_reversed_data; const T* x_data = X->Data(); if (reverse_) { @@ -203,6 +260,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { x_reversed_data = GetScratchBuffer(seq_length * batch_size * input_size, ctx->GetComputeStream()); ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(input_size), reinterpret_cast(x_data), @@ -226,115 +284,82 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { y_data = y_alloc_data.get(); } - const int32_t* sequence_lens_data = (sequence_lens == nullptr) ? nullptr : sequence_lens->Data(); + const Tensor* B = ctx->Input(RNN_Input_Index::B); + bool has_bias = B != nullptr; CudnnRNN rnn_desc; - ORT_RETURN_IF_ERROR(rnn_desc.Set(GetCudnnHandle(ctx), + ORT_RETURN_IF_ERROR(rnn_desc.Set(input_size, hidden_size_, + proj_size, RNN_NUM_LAYERS, cudnn_dropout_desc_, cudnn_direction_mode_, rnn_mode_, - CudnnTensor::GetDataType(), - GetDeviceProp())); + has_bias, + CudnnTensor::GetDataType())); // Prepare the weight data + size_t w_data_size_in_bytes = 0; IAllocatorUniquePtr w_data; CudnnFilterDescriptor w_desc; if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); const Tensor* B = ctx->Input(RNN_Input_Index::B); - ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); + ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, + rnn_desc, ctx->GetComputeStream())); } - // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences - CUDNN_RETURN_IF_ERROR(cudnnSetRNNPaddingMode(rnn_desc, CUDNN_RNN_PADDED_IO_ENABLED)); + CudnnDataTensor x_desc1; + ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + input_size, seq_len_array)); + CudnnDataTensor y_desc1; + ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, + ((rnn_mode_ == CUDNN_LSTM) ? 
proj_size : hidden_size_) * num_directions_, + seq_len_array)); - size_t workspace_bytes; - CUDNN_RETURN_IF_ERROR(cudnnGetRNNWorkspaceSize(GetCudnnHandle(ctx), rnn_desc, gsl::narrow_cast(seq_length), x_desc.data(), &workspace_bytes)); - auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); - int64_t zero_seq_count = 0; - std::vector zero_seq_index_cache(batch_size, 0); - int64_t zero_seq_index_cache_size = 0; - - if (CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_ || nullptr == sequence_lens_data) { - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInference(GetCudnnHandle(ctx), - rnn_desc, - gsl::narrow_cast(seq_length), - x_desc.data(), - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc.data(), - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - workspace_cuda.get(), - workspace_bytes)); - } else { - // cudnn doesn't support 0 sequence inside the batch, find the 0 sequence and set it to 1 - // there's a ZeroMask kernel to reset the result to 0 for the 0 sequence - std::vector seq_len_array(sequence_lens_data, sequence_lens_data + batch_size); - for (int i = 0; i < batch_size; ++i) { - if (0 == seq_len_array[i]) { - seq_len_array[i] = 1; - zero_seq_index_cache[zero_seq_count] = i; - ++zero_seq_count; - } - } + CudnnTensor cx_desc; + ORT_RETURN_IF_ERROR(cx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); - // Calculate the zero position cache for reverse direction if it's bidirectional - // The cache is for Y_h or Y_c, and the 1st sequence for Y, no need to do it for other sequence in Y since - // we hacked the 0 sequence to 1 - if (zero_seq_count && num_directions_ > 1) { - zero_seq_index_cache_size = zero_seq_count * num_directions_; - zero_seq_index_cache.resize(zero_seq_index_cache_size); - for (int64_t i = 0; i < zero_seq_count; ++i) { - zero_seq_index_cache[static_cast(zero_seq_count) + i] = static_cast(batch_size + zero_seq_index_cache[i]); - } - } + CudnnTensor hx_desc; + ORT_RETURN_IF_ERROR(hx_desc.Set(dims_hxy, CudnnTensor::GetDataType())); + + // reserveSpaceSize is not required cudnnRNNForward, but returned by cudnnGetRNNTempSpaceSizes + size_t workspace_bytes, reservespace_bytes; - CudnnDataTensor x_desc1; - ORT_RETURN_IF_ERROR(x_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, input_size, seq_len_array.data())); - CudnnDataTensor y_desc1; - ORT_RETURN_IF_ERROR(y_desc1.Set(CudnnTensor::GetDataType(), seq_length, batch_size, hidden_size_ * num_directions_, seq_len_array.data())); - - CUDNN_RETURN_IF_ERROR(cudnnRNNForwardInferenceEx(GetCudnnHandle(ctx), - rnn_desc, - x_desc1, - x_data_input, - hx_desc, - hx_data, - cx_desc, - cx_data, - weight_cached_ ? w_desc_cache_ : w_desc, - weight_cached_ ? w_data_cache_.get() : w_data.get(), - y_desc1, - y_data, - y_h_desc, - y_h_data, - y_c_desc, - y_c_data, - nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace_cuda.get(), - workspace_bytes)); - - // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, no need the following code to retrive Y_h from Y data. 
- if (nullptr == Y) { + CUDNN_RETURN_IF_ERROR(cudnnGetRNNTempSpaceSizes(GetCudnnHandle(ctx), rnn_desc, CUDNN_FWD_MODE_INFERENCE, + x_desc1, &workspace_bytes, &reservespace_bytes)); + auto workspace_cuda = GetScratchBuffer(workspace_bytes, ctx->GetComputeStream()); + auto reservespace_cuda = GetScratchBuffer(reservespace_bytes, ctx->GetComputeStream()); + + CUDNN_RETURN_IF_ERROR(cudnnRNNForward(GetCudnnHandle(ctx), + rnn_desc, + CUDNN_FWD_MODE_INFERENCE, + sequence_lens_buffer.GpuPtr(), // should be zero starting with cudnn 8.9.1 + x_desc1, + x_data_input, + y_desc1, + y_data, // output + hx_desc, + hx_data, // input + y_h_data, // output + cx_desc, cx_data, y_c_data, + weight_cached_ ? w_data_cache_size_in_bytes_ : w_data_size_in_bytes, + weight_cached_ ? w_data_cache_.get() : w_data.get(), + workspace_bytes, + workspace_cuda.get(), + reservespace_bytes, + reservespace_cuda.get())); + + // Early terminate for this case since Y data is not required, and Y_h is obtained correctly, + // no need the following code to retrieve Y_h from Y data. + if (nullptr == Y) { + // Mask on output for 0 sequence batches + if (zero_seq_count > 0) { // Mask on output for 0 sequence batches - if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); - } - return Status::OK(); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } + return Status::OK(); } IAllocatorUniquePtr y_reorganized_data; @@ -345,6 +370,7 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // reverse output data ReverseBySequence(Stream(ctx), gsl::narrow_cast(seq_length), + sequence_lens_buffer.GpuPtr(), gsl::narrow_cast(batch_size), gsl::narrow_cast(hidden_size_), reinterpret_cast(y_data), @@ -361,8 +387,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { } if (Y != nullptr) { - // User specified this optional output, so need to copy the reversed data to orignial place - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), cudaMemcpyDeviceToDevice, Stream(ctx))); + // User specified this optional output, so need to copy the reversed data to original place + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(y_data, y_reorganized_data.get(), output_size * sizeof(T), + cudaMemcpyDeviceToDevice, Stream(ctx))); } else { y_data = y_reorganized_data.get(); } @@ -370,23 +397,9 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { // Mask on output for 0 sequence batches if (zero_seq_count > 0) { - SetZeroSequences(zero_seq_index_cache_size, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); + SetZeroSequences(zero_seq_count, zero_seq_index_cache, y_data, y_h_data, y_c_data, ctx->GetComputeStream()); } - if ((CUDNN_RNN_RELU == rnn_mode_ || CUDNN_RNN_TANH == rnn_mode_) && sequence_lens_data != nullptr && y_h_data != nullptr && y_data != nullptr) { - CudaAsyncBuffer sequence_lens_buffer(this, batch_size); - memcpy(sequence_lens_buffer.CpuPtr(), sequence_lens_data, batch_size * sizeof(int32_t)); - ORT_RETURN_IF_ERROR(sequence_lens_buffer.CopyToGpu(ctx->GetComputeStream())); - RnnMaskImpl(Stream(ctx), - gsl::narrow_cast(num_directions_), - gsl::narrow_cast(seq_length), - gsl::narrow_cast(batch_size), - gsl::narrow_cast(hidden_size_), - sequence_lens_buffer.GpuPtr(), - reinterpret_cast(y_data), - reinterpret_cast(y_h_data), - output_size); - } return Status::OK(); } @@ -399,7 +412,8 @@ void 
CudnnRnnBase::SetZeroSequences(const int64_t zero_seq_index_cache_size, onnxruntime::Stream* ort_stream) const { typedef typename ToCudaType::MappedType CudaT; CudaAsyncBuffer zero_seq_index_cache_async_buffer(this, zero_seq_index_cache_size); - memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), zero_seq_index_cache_size * sizeof(int32_t)); + memcpy(zero_seq_index_cache_async_buffer.CpuPtr(), zero_seq_index_cache.data(), + zero_seq_index_cache_size * sizeof(int32_t)); ORT_THROW_IF_ERROR(zero_seq_index_cache_async_buffer.CopyToGpu(ort_stream)); cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; MaskZeroSequences(cuda_stream, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h index 1c9483b2afd38..0fa01d3486e99 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.h @@ -38,26 +38,28 @@ class CudnnRNN { } } - Status Set(const cudnnHandle_t& cudnnHandle, int64_t hidden_size, int num_layers, + Status Set(int64_t input_size, int64_t hidden_size, int64_t proj_size, int num_layers, cudnnDropoutDescriptor_t cudnn_dropout_desc, cudnnDirectionMode_t cudnn_direction_model, - cudnnRNNMode_t rnn_mode, cudnnDataType_t dataType, const cudaDeviceProp& prop) { + cudnnRNNMode_t rnn_mode, bool has_bias, cudnnDataType_t dataType) { if (!cudnn_rnn_desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateRNNDescriptor(&cudnn_rnn_desc_)); - CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v6(cudnnHandle, - cudnn_rnn_desc_, + CUDNN_RETURN_IF_ERROR(cudnnSetRNNDescriptor_v8(cudnn_rnn_desc_, + CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC + rnn_mode, + has_bias ? CUDNN_RNN_DOUBLE_BIAS : CUDNN_RNN_NO_BIAS, + cudnn_direction_model, + CUDNN_LINEAR_INPUT, + dataType, + dataType, + dataType == CUDNN_DATA_HALF ? 
CUDNN_TENSOR_OP_MATH : CUDNN_DEFAULT_MATH, + gsl::narrow_cast(input_size), gsl::narrow_cast(hidden_size), + gsl::narrow_cast(proj_size), // projected size num_layers, cudnn_dropout_desc, - CUDNN_LINEAR_INPUT, // We can also skip the input matrix transformation - cudnn_direction_model, - rnn_mode, - CUDNN_RNN_ALGO_STANDARD, // CUDNN_RNN_ALGO_PERSIST_STATIC, CUDNN_RNN_ALGO_PERSIST_DYNAMIC - dataType)); - - if (prop.major >= 7 && dataType == CUDNN_DATA_HALF) { - cudnnSetRNNMatrixMathType(cudnn_rnn_desc_, CUDNN_TENSOR_OP_MATH); - } + // CUDNN_RNN_DATA_LAYOUT_SEQ_MAJOR_UNPACKED works with CUDNN_RNN_PADDED_IO_ENABLED, so that it will auto fill 0 for the shorter sequences + CUDNN_RNN_PADDED_IO_ENABLED)); return Status::OK(); } @@ -119,8 +121,7 @@ class CudnnRnnBase : public CudaKernel { private: Status SetCudnnRnnWeightBias(const cudnnHandle_t cudnn_handle, const cudnnRNNDescriptor_t rnn_desc, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, + size_t w_data_size, void* w_data, const T* W_data, const T* R_data, @@ -128,23 +129,22 @@ class CudnnRnnBase : public CudaKernel { cudaStream_t cuda_stream) const; Status ReorganizeWeights(const Tensor* W, const Tensor* R, const Tensor* B, + size_t& target_w_data_size_in_bytes, IAllocatorUniquePtr& target_w_data, CudnnFilterDescriptor& target_w_desc, CudnnRNN& rnn_desc, onnxruntime::Stream* ort_stream) const; - void SetWeightBias(const cudnnHandle_t handle, - const cudnnRNNDescriptor_t rnn_desc, - const int pseudo_layer, - const cudnnTensorDescriptor_t x_desc, - const cudnnFilterDescriptor_t w_desc, - const cudnnFilterDescriptor_t filter_desc, - const void* w_data, - const int lin_layer_id, - const T* pos, - int& offset, - bool is_matrix, - cudaStream_t cuda_stream) const; + Status SetWeightBias(const cudnnHandle_t handle, + const cudnnRNNDescriptor_t rnn_desc, + const int pseudo_layer, + size_t w_data_size, + const void* w_data, + const int lin_layer_id, + const T* pos, + int& offset, + bool is_matrix, + cudaStream_t cuda_stream) const; void SetZeroSequences(const int64_t zero_seq_index_cache_size, const std::vector zero_seq_index_cache, @@ -167,6 +167,7 @@ class CudnnRnnBase : public CudaKernel { cudnnRNNMode_t rnn_mode_; // w_desc_cache_ & w_data_cache_ are changed in Constructor if we can get the weights as constant input CudnnFilterDescriptor w_desc_cache_; + size_t w_data_cache_size_in_bytes_; IAllocatorUniquePtr w_data_cache_; bool weight_cached_; int64_t layout_; diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.cc b/onnxruntime/core/providers/cuda/rnn/rnn.cc index 4bd22340ef2bb..ed8be63679707 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.cc +++ b/onnxruntime/core/providers/cuda/rnn/rnn.cc @@ -1,8 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "core/providers/shared_library/provider_api.h" #include "rnn.h" + +#include "core/providers/shared_library/provider_api.h" #include "rnn_impl.h" #include "core/providers/cuda/cudnn_common.h" diff --git a/onnxruntime/core/providers/cuda/rnn/rnn.h b/onnxruntime/core/providers/cuda/rnn/rnn.h index e4e50046b3725..6221afb003b22 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn.h @@ -4,6 +4,7 @@ #pragma once #include "cudnn_rnn_base.h" + #include "core/providers/cuda/cuda_common.h" #include diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu index d485855ddb417..94c8036be6cdf 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.cu @@ -8,22 +8,32 @@ namespace onnxruntime { namespace cuda { template -__global__ void _ReverseBySequenceKernel(const int32_t seq_length, +__global__ void _ReverseBySequenceKernel(const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t block_size, const fast_divmod div_batch_block, + const fast_divmod div_input_or_hidden_size, const T* data, T* reversed_data, const CUDA_LONG N) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); int seq_id, offset; div_batch_block.divmod(id, seq_id, offset); - int org_id = (seq_length - seq_id - 1) * block_size + offset; - reversed_data[id] = data[org_id]; + int batch, batch_offset; + div_input_or_hidden_size.divmod(offset, batch, batch_offset); + int seq_id_org = seq_lengths[batch] - seq_id - 1; + if (seq_id_org >= 0) { + int org_id = seq_id_org * block_size + offset; + reversed_data[id] = data[org_id]; + } else { + reversed_data[id] = T{}; + } } template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t *seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -32,9 +42,10 @@ void ReverseBySequence(cudaStream_t stream, // kerneral int32_t block_size = batch_size * input_or_hidden_size; fast_divmod div_batch_block(block_size); + fast_divmod div_input_or_hidden_size(input_or_hidden_size); int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); _ReverseBySequenceKernel<<>>( - seq_length, block_size, div_batch_block, data, reversed_data, (CUDA_LONG)N); + max_seq_length, seq_lengths, block_size, div_batch_block, div_input_or_hidden_size, data, reversed_data, (CUDA_LONG)N); } template @@ -82,60 +93,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, data, reordered_data, (CUDA_LONG)N); } -template -__global__ void _RnnMaskKernel(const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - const fast_divmod div_seq_block, - const fast_divmod div_dir_block, - const fast_divmod div_batch_block, - T* y_output_data, - T* y_h_output_data, - const CUDA_LONG N) { - CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); - - int seq_id, direction_id, batch_id, offset; - div_seq_block.divmod(id, seq_id, offset); - div_dir_block.divmod(offset, direction_id, offset); - div_batch_block.divmod(offset, batch_id, offset); - int32_t batch_seq_length = sequence_lens[batch_id]; - - if (batch_id >= batch_size || batch_seq_length == seq_length) { - return; - } - - if (seq_id >= batch_seq_length) { - y_output_data[id] = 0; - return; - } - - if ((y_h_output_data != nullptr) && - ((direction_id == 0 && (seq_id + 1) == batch_seq_length) || (direction_id == 1 && seq_id == 0))) { - int hy_idx = 
direction_id * batch_size * hidden_size + batch_id * hidden_size + offset; - y_h_output_data[hy_idx] = y_output_data[id]; - } -} - -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N) { - fast_divmod div_seq_block(batch_size * hidden_size * num_directions); - fast_divmod div_dir_block(batch_size * hidden_size); - fast_divmod div_batch_block(hidden_size); - int blocksPerGrid = (int)(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); - _RnnMaskKernel<<>>( - seq_length, batch_size, hidden_size, sequence_lens, div_seq_block, - div_dir_block, div_batch_block, y_output_data, y_h_output_data, (CUDA_LONG)N); -} - template __global__ void _MaskZeroSequences(const int32_t hidden_size, T* y_output_data, @@ -180,17 +137,9 @@ void MaskZeroSequences(cudaStream_t stream, } #define SPECIALIZED_RNN_IMPL(T) \ - template void RnnMaskImpl(cudaStream_t stream, \ - const int32_t num_directions, \ - const int32_t seq_length, \ - const int32_t batch_size, \ - const int32_t hidden_size, \ - const int32_t* sequence_lens, \ - T* y_output_data, \ - T* y_h_output_data, \ - const size_t N); \ - template void ReverseBySequence(cudaStream_t stream, \ - const int32_t seq_length, \ + template void ReverseBySequence(cudaStream_t stream, \ + const int32_t max_seq_length, \ + const int32_t* seq_lengths, \ const int32_t batch_size, \ const int32_t hidden_size, \ const T* data, \ @@ -203,7 +152,7 @@ void MaskZeroSequences(cudaStream_t stream, const T* data, \ T* reordered_data, \ const size_t N); \ -template void MaskZeroSequences(cudaStream_t stream, \ +template void MaskZeroSequences(cudaStream_t stream, \ const int32_t hidden_size, \ T* y_output_data, \ T* y_h_output_data, \ diff --git a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h index 9844e04ff6ec5..ba876011f6b67 100644 --- a/onnxruntime/core/providers/cuda/rnn/rnn_impl.h +++ b/onnxruntime/core/providers/cuda/rnn/rnn_impl.h @@ -10,7 +10,8 @@ namespace cuda { template void ReverseBySequence(cudaStream_t stream, - const int32_t seq_length, + const int32_t max_seq_length, + const int32_t* seq_lengths, const int32_t batch_size, const int32_t input_or_hidden_size, const T* data, @@ -26,17 +27,6 @@ void ReorderBidirectionalDataInSequence(cudaStream_t stream, T* reordered_data, const size_t N); -template -void RnnMaskImpl(cudaStream_t stream, - const int32_t num_directions, - const int32_t seq_length, - const int32_t batch_size, - const int32_t hidden_size, - const int32_t* sequence_lens, - T* y_output_data, - T* y_h_output_data, - const size_t N); - template void MaskZeroSequences(cudaStream_t stream, const int32_t hidden_size, diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index b9875b9553a55..1a31743e2f7e7 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -120,15 +120,11 @@ TEST(RNNTest, RNN_bidirectional_bias_initial_zigged_batch) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // TensorRT failed on RNN tests - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. 
-#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_zigged_batch) { -#else TEST(RNNTest, RNN_bidirectional_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, seq_length = 5; @@ -275,15 +271,11 @@ TEST(RNNTest, RNN_reverse_direction_zigged_batch) { std::vector Y_h_data({0.87014002F, 0.09402763F, -0.54269236F, 0.64809889F, -0.19472955F, -0.24271242F}); test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } // Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_forward_direction_zigged_batch) { -#else TEST(RNNTest, RNN_forward_direction_zigged_batch) { -#endif OpTester test("RNN"); int64_t num_directions = 1, input_size = 2, hidden_size = 3, seq_length = 5; @@ -357,12 +349,7 @@ TEST(RNNTest, RNN_forward_direction_zigged_batch) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_0) { -#else TEST(RNNTest, RNN_bidirectional_0) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; @@ -424,12 +411,7 @@ TEST(RNNTest, RNN_bidirectional_0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } -// Doesn't work with CUDA 11.4 on Windows. Need investigation. -#if defined(USE_CUDA) && defined(_WIN32) -TEST(RNNTest, DISABLED_RNN_bidirectional_1) { -#else TEST(RNNTest, RNN_bidirectional_1) { -#endif OpTester test("RNN"); int64_t num_directions = 2, input_size = 2, hidden_size = 2, batch_size = 1, seq_length = 1; @@ -597,7 +579,7 @@ TEST(RNNTest, DISABLED_RNN_default_attributes_and_forward_direction) { } } -TEST(RNNTest, DISABLED_RNN_reverse_direction) { +TEST(RNNTest, RNN_reverse_direction) { int64_t num_directions = 1, input_size = 2, hidden_size = 3, batch_size = 1, seq_length = 5; // In case of useDefault, attributes, inputs or outputs are not set. From aec2389ad0463d218b8cf3b1e245d4c34e98364a Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 23 Feb 2024 12:52:47 -0800 Subject: [PATCH 138/207] [js/webgpu] allows a ProgramInfo's RunData to use zero sized output (#19614) ### Description This PR allows zero-sized output. To make the implementation simple, it does not support partial zero-sized tensor. Which means, either all outputs are zero-sized, or an error will be reported. 
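As an illustration, the broadcasting rule this change adopts for zero-sized dimensions can be sketched in Python (the real change is in TypeScript, in `BroadcastUtil` in `js/web/lib/wasm/jsep/util.ts`; `broadcast_shape` below is only a hypothetical helper for the example): a dimension of length 0 broadcasts only against a dimension of 0 or 1, and the resulting dimension is 0.

```
def broadcast_shape(a, b):
    # Right-aligned broadcast of two shapes, allowing zero-sized dims
    # (rough sketch of the rule, not the actual TypeScript implementation).
    rank = max(len(a), len(b))
    out = [0] * rank
    for i in range(1, rank + 1):
        a_len = a[-i] if i <= len(a) else 1
        b_len = b[-i] if i <= len(b) else 1
        if a_len != b_len and a_len > 1 and b_len > 1:
            return None  # not broadcastable
        if a_len and b_len:
            out[-i] = max(a_len, b_len)
        else:
            # one side is 0: the other side must be 0 or 1
            if max(a_len, b_len) > 1:
                return None
            out[-i] = 0
    return out

assert broadcast_shape([2, 0], [2, 1]) == [2, 0]
```
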
added 2 tests: - op test of `Add` with input T[2,0] T[2,1], and - test_split_zero_size_splits --- js/web/lib/wasm/jsep/backend-webgpu.ts | 32 ++++++++++++++++++++++---- js/web/lib/wasm/jsep/init.ts | 3 ++- js/web/lib/wasm/jsep/util.ts | 11 ++++++++- js/web/test/data/ops/add.jsonc | 22 ++++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 +- js/web/test/test-runner.ts | 10 ++++++-- 6 files changed, 71 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 98990a6fe477b..3e3a191ec3ead 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -385,11 +385,16 @@ export class WebGpuBackend { // create info for inputs const inputDatas: GpuData[] = []; for (let i = 0; i < inputTensorViews.length; ++i) { - const gpuData = this.gpuDataManager.get(inputTensorViews[i].data); + const data = inputTensorViews[i].data; + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (data === 0) { + continue; + } + const gpuData = this.gpuDataManager.get(data); if (!gpuData) { - throw new Error(`no GPU data for input: ${inputTensorViews[i].data}`); + throw new Error(`no GPU data for input: ${data}`); } - inputDatas[i] = gpuData; + inputDatas.push(gpuData); } const {outputs, dispatchGroup, programUniforms} = program.getRunData(inputTensorViews); @@ -419,6 +424,11 @@ export class WebGpuBackend { const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + outputTensorViews.push(tensorView); + // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. + if (tensorView.data === 0) { + continue; + } const gpuData = this.gpuDataManager.get(tensorView.data); if (!gpuData) { throw new Error(`no GPU data for output: ${tensorView.data}`); @@ -434,10 +444,24 @@ export class WebGpuBackend { } persistentData.push(gpuData); } - outputTensorViews.push(tensorView); outputDatas.push(gpuData); } + // when there are any zero-sized tensor in the inputs or outputs, we should report error unless all outputs are + // zero-sized tensors. + if (inputDatas.length !== inputTensorViews.length || outputDatas.length !== outputTensorViews.length) { + // if all outputs are zero-sized tensors, there is no need to run the program. + if (outputDatas.length === 0) { + TRACE_FUNC_END(program.name); + return outputTensorViews; + } + // if some outputs are zero-sized tensors, report an error. + // + // TODO: so far we don't see any use case that outputs include both zero-sized tensors and non-zero-sized tensors. + // If we see such use case, we need to make a change here to support it. + throw new Error( + `Program ${program.name} has zero-sized tensor(s) in inputs or outputs. This is not supported now.`); + } // load uniforms // TODO: add cache for uniform (is it necessary?) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 786ae41646554..b64abf9cc5424 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -104,7 +104,8 @@ class ComputeContextImpl implements ComputeContext { throw new Error(`Unsupported data type: ${dataType}`); } const bufferSize = elementSize * ShapeUtil.size(dims); - return new TensorViewImpl(this.module, dataType, this.backend.gpuDataManager.create(bufferSize).id, dims); + const gpuDataId = bufferSize > 0 ? 
this.backend.gpuDataManager.create(bufferSize).id : 0; + return new TensorViewImpl(this.module, dataType, gpuDataId, dims); }; return this.backend.run(program, mappedInputs, outputIndices, createKernelOutput, createTemporaryOutput); } diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index c0517ce363644..9a1d5463f7843 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -56,7 +56,16 @@ export class BroadcastUtil { if (aLen !== bLen && aLen > 1 && bLen > 1) { return undefined; } - cdims[crank - i] = Math.max(aLen, bLen); + const max = Math.max(aLen, bLen); + if (aLen && bLen) { + cdims[crank - i] = Math.max(aLen, bLen); + } else { + // when either aLen or bLen is 0, the other should be either 0 or 1, otherwise it is not broadcastable. + if (max > 1) { + return undefined; + } + cdims[crank - i] = 0; + } } return cdims; diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc index e5b4ff2b53148..dd15134861ef0 100644 --- a/js/web/test/data/ops/add.jsonc +++ b/js/web/test/data/ops/add.jsonc @@ -157,6 +157,28 @@ "type": "float32" } ] + }, + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] } ] } diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index b43b1ac37e37d..88555a27be82e 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1231,7 +1231,7 @@ "test_split_variable_parts_1d", "test_split_variable_parts_2d", "test_split_variable_parts_default_axis", - // // "test_split_zero_size_splits", + "test_split_zero_size_splits", "test_sqrt_example", "test_sqrt", "test_squeeze_negative_axes", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index a4adf5c4ce144..7c03e5b915fd7 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -573,7 +573,9 @@ export async function sessionRun(options: { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (feeds[name].size > 0) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -582,7 +584,11 @@ export async function sessionRun(options: { for (const name in options.outputsMetaInfo) { if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { const {type, dims} = options.outputsMetaInfo[name]; - fetches[name] = createGpuTensorForOutput(type, dims); + if (dims.some(d => d === 0)) { + fetches[name] = new ort.Tensor(type, [], dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } From bb43a0f1338b05e93fcbbe5c5cb53ebf017625ba Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Fri, 23 Feb 2024 15:45:30 -0800 Subject: [PATCH 139/207] [js/webgpu] minor fixes to make tinyllama work (#19564) --- js/web/lib/wasm/jsep/webgpu/ops/concat.ts | 4 +++- js/web/lib/wasm/jsep/webgpu/ops/gather.ts | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index b06c9fb496d15..b142a82e551a7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -154,7 +154,9 @@ const createConcatProgramInfo = (inputs: readonly TensorView[], axis: number): P export const concat = (context: 
ComputeContext, attributes: ConcatAttributes): void => { validateInputs(context.inputs); - context.compute(createConcatProgramInfo(context.inputs, attributes.axis)); + // 0 length tensors are valid for concat, remove them + const nonEmptyInputs = context.inputs.filter(input => ShapeUtil.size(input.dims) > 0); + context.compute(createConcatProgramInfo(nonEmptyInputs, attributes.axis), {inputs: nonEmptyInputs}); }; export const parseConcatAttributes = (attributes: Record): ConcatAttributes => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 5c31e6dd86c00..d48bb909f7f8f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -55,7 +55,7 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath if (idx${x} < 0) { idx${x} = idx${x} + uniforms.axisDimLimit; } - var dataIndices${x} = ${data.type.indices}(0); + var dataIndices${x} : ${data.type.indices}; `; for (let i = 0, j = 0; i < inputRank; i++) { if (i === axis) { From 46c4d7fe4ad457d517fe92db7681c38849c51beb Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Fri, 23 Feb 2024 18:20:22 -0800 Subject: [PATCH 140/207] Disable gemm activation for non-float data types (#19612) ### Description Disable gemm activation for non-float data types ### Motivation and Context When a float16 model contains a Gemm+Relu subgraph, the gemm_activation_fusion will kick in and cause the two nodes to be eliminated and replaced with a FusedGemm. This however is only registered for the float data type. This causes model load failures. Disable the fusion for non-float data types. --------- Co-authored-by: Sheil Kumar --- onnxruntime/core/optimizer/gemm_activation_fusion.cc | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/core/optimizer/gemm_activation_fusion.cc b/onnxruntime/core/optimizer/gemm_activation_fusion.cc index c62887da09fdc..50be2cbd48f7b 100644 --- a/onnxruntime/core/optimizer/gemm_activation_fusion.cc +++ b/onnxruntime/core/optimizer/gemm_activation_fusion.cc @@ -56,6 +56,13 @@ Status GemmActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l continue; } + NodeArg* node_output = node.MutableOutputDefs()[0]; + auto data_type = node_output->TypeAsProto()->tensor_type().elem_type(); + if (data_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + // FusedGemm is only registered for float data type in fused_gemm.cc! + continue; + } + const Node& next_node = *(node.OutputNodesBegin()); if (!IsFusableActivation(next_node) || next_node.GetExecutionProviderType() != node.GetExecutionProviderType()) { continue; From c12a20bef95df5437189687b94e7ba2f1bad1505 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 24 Feb 2024 14:06:30 +1000 Subject: [PATCH 141/207] Add helper to run CIs for a branch using `az pipelines`. (#16843) ### Description Add helper to run CIs for a branch using `az pipelines`. This can be used to easily kick off multiple CIs for a branch prior to creating a PR. Update run_CIs_for_external_pr.py so the CI list can be shared. Request json output from `gh pr view` so the current state is more easily parsed. 
### Motivation and Context --- tools/python/run_CIs_for_branch.py | 116 +++++++++++++++++++++++ tools/python/run_CIs_for_external_pr.py | 120 +++++++++++++----------- 2 files changed, 181 insertions(+), 55 deletions(-) create mode 100644 tools/python/run_CIs_for_branch.py diff --git a/tools/python/run_CIs_for_branch.py b/tools/python/run_CIs_for_branch.py new file mode 100644 index 0000000000000..c507cae0d9f43 --- /dev/null +++ b/tools/python/run_CIs_for_branch.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import json +import os +import subprocess +import sys +import typing + +from run_CIs_for_external_pr import get_pipeline_names +from util.platform_helpers import is_windows + + +def _parse_args(): + parser = argparse.ArgumentParser( + os.path.basename(__file__), + formatter_class=argparse.RawDescriptionHelpFormatter, + description="""Run the CIs used to validate PRs for the specified branch. + + If specified, the `--include` filter is applied first, followed by any `--exclude` filter. + + Requires the Azure CLI with DevOps extension to be installed. + Azure CLI: https://learn.microsoft.com/en-us/cli/azure/install-azure-cli + DevOps extension: https://github.com/Azure/azure-devops-cli-extension + + Configuration: + Login:`az login` + Configure ORT repo as default: + `az devops configure --defaults organization=https://dev.azure.com/onnxruntime project=onnxruntime` + + Example usage: + List all CIs + `python run_CIs_for_branch.py --dry-run my/BranchName` + Run all CIs + `python run_CIs_for_branch.py my/BranchName` + Run only Linux CIs + `python run_CIs_for_branch.py --include linux my/BranchName` + Exclude training CIs + `python run_CIs_for_branch.py --exclude training my/BranchName` + Run non-training Linux CIs + `python run_CIs_for_branch.py --include linux --exclude training my/BranchName` + """, + ) + + parser.add_argument("-i", "--include", type=str, help="Include CIs that match this string. Case insensitive.") + parser.add_argument("-e", "--exclude", type=str, help="Exclude CIs that match this string. 
Case insensitive.") + parser.add_argument("--dry-run", action="store_true", help="Print selected CIs but do not run them.") + parser.add_argument("branch", type=str, help="Specify the branch to run.") + + args = parser.parse_args() + return args + + +def _run_az_pipelines_command(command: typing.List[str]): + try: + az = "az.cmd" if is_windows() else "az" + az_output = subprocess.run([az, "pipelines", *command], capture_output=True, text=True, check=True) + except subprocess.CalledProcessError as cpe: + print(cpe) + print(cpe.stderr) + sys.exit(-1) + + return az_output + + +def main(): + args = _parse_args() + branch = args.branch + + # To debug available pipelines: + # az_out = az_pipelines = _run_az_pipelines_command(["list"]) + # pipeline_info = json.loads(az_out.stdout) + # print(pipeline_info) + + pipelines = get_pipeline_names() + pipelines_to_run = [] + if args.include: + value = args.include.lower().strip() + for p in pipelines: + if value in p.lower(): + print(f"Including {p}") + pipelines_to_run.append(p) + else: + pipelines_to_run = pipelines + + if args.exclude: + value = args.exclude.lower().strip() + cur_pipelines = pipelines_to_run + pipelines_to_run = [] + for p in cur_pipelines: + if value in p.lower(): + print(f"Excluding {p}") + else: + pipelines_to_run.append(p) + + print("Pipelines to run:") + for p in pipelines_to_run: + print(f"\t{p}") + + if args.dry_run: + sys.exit(0) + + for pipeline in pipelines_to_run: + az_out = _run_az_pipelines_command(["run", "--branch", branch, "--name", pipeline]) + run_output = json.loads(az_out.stdout) + if "id" in run_output: + build_url = f"https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId={run_output['id']}" + print(f"{pipeline} build results: {build_url}&view=results") + else: + raise ValueError("Build id was not found in az output:\n" + run_output) + + +if __name__ == "__main__": + main() diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index df4e70b1e51fe..dcafe898b3bdf 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -3,13 +3,54 @@ # Licensed under the MIT License. import argparse +import json import os import subprocess import sys import typing -def parse_args(): +def get_pipeline_names(): + # Current pipelines. These change semi-frequently and may need updating. + # There is no easy way to get the list of "required" pipelines using `azp` before they are run, + # so we need to maintain this list manually. 
+ # NOTE: This list is also used by run_CIs_for_branch.py + pipelines = [ + # windows + "Windows ARM64 QNN CI Pipeline", + "Windows x64 QNN CI Pipeline", + "Windows CPU CI Pipeline", + "Windows GPU CI Pipeline", + "Windows GPU TensorRT CI Pipeline", + "ONNX Runtime Web CI Pipeline", + # linux + "Linux CPU CI Pipeline", + "Linux CPU Minimal Build E2E CI Pipeline", + "Linux GPU CI Pipeline", + "Linux GPU TensorRT CI Pipeline", + "Linux OpenVINO CI Pipeline", + "Linux QNN CI Pipeline", + # mac + "MacOS CI Pipeline", + # training + "orttraining-amd-gpu-ci-pipeline", + "orttraining-linux-ci-pipeline", + "orttraining-linux-gpu-ci-pipeline", + "orttraining-ortmodule-distributed", + # checks + "onnxruntime-binary-size-checks-ci-pipeline", + # big models + "Big Models", + # not currently required, but running ensures we're hitting all mobile platforms + "Android CI Pipeline", + "iOS CI Pipeline", + "ONNX Runtime React Native CI Pipeline", + ] + + return pipelines + + +def _parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), formatter_class=argparse.RawDescriptionHelpFormatter, @@ -25,7 +66,7 @@ def parse_args(): return args -def run_gh_pr_command(command: typing.List[str], check=True): +def run_gh_pr_command(command: typing.List[str], check: bool = True): try: return subprocess.run(["gh", "pr", *command], capture_output=True, text=True, check=check) except subprocess.CalledProcessError as cpe: @@ -35,23 +76,25 @@ def run_gh_pr_command(command: typing.List[str], check=True): def main(): - args = parse_args() + args = _parse_args() pr_id = args.pr # validate PR - gh_out = run_gh_pr_command(["view", pr_id]) - info = gh_out.stdout.split("\n") - for line in info: - pieces = line.split("\t") - if len(pieces) != 2: - continue - - if pieces[0] == "state:": - if pieces[1] != "OPEN": - print(f"PR {pr_id} is not OPEN. Currently in state {pieces[1]}.") - sys.exit(-1) - - print("Check passed pipelines") + print("Checking PR is open") + gh_out = run_gh_pr_command(["view", "--json", "state", pr_id]) + info = json.loads(gh_out.stdout) + if "state" not in info: + print(f"Could not get current state from `gh pr view` response of\n{gh_out.stdout}") + sys.exit(-1) + + if info["state"] != "OPEN": + print(f"PR {pr_id} is not OPEN. Currently in state {info['state']}.") + sys.exit(0) + + # This will return CIs that have run previously but not passed. We filter the CIs to run based on this, so it's + # fine for the initial response to have no info in it. + # `gh pr checks` exits with non-zero exit code when failures in pipeline exist, so we set `check` to False. + print("Checking for pipelines that have passed.") gh_out = run_gh_pr_command(["checks", pr_id, "--required"], check=False) # output format is a tab separated list of columns: # (pipeline name) "\t" (status) "\t" (ran time) "\t" (url) @@ -61,54 +104,21 @@ def main(): if len(columns) == 4 and columns[1] == "pass" ] - print("Adding azp run commands") - - # Current pipelines. These change semi-frequently and may need updating. - # - # Note: there is no easy way to get the list for azp "required" pipelines before they starts. - # we need to maintain this list manually. 
- # - pipelines = [ - # windows - "Windows ARM64 QNN CI Pipeline", - "Windows x64 QNN CI Pipeline", - "Windows CPU CI Pipeline", - "Windows GPU CI Pipeline", - "Windows GPU TensorRT CI Pipeline", - "ONNX Runtime Web CI Pipeline", - # linux - "Linux CPU CI Pipeline", - "Linux CPU Minimal Build E2E CI Pipeline", - "Linux GPU CI Pipeline", - "Linux GPU TensorRT CI Pipeline", - "Linux OpenVINO CI Pipeline", - "Linux QNN CI Pipeline", - # mac - "MacOS CI Pipeline", - # training - "orttraining-amd-gpu-ci-pipeline", - "orttraining-linux-ci-pipeline", - "orttraining-linux-gpu-ci-pipeline", - "orttraining-ortmodule-distributed", - # checks - "onnxruntime-python-checks-ci-pipeline", - "onnxruntime-binary-size-checks-ci-pipeline", - # big models - "Big Models", - # not currently required, but running ensures we're hitting all mobile platforms - "Android CI Pipeline", - "iOS CI Pipeline", - "ONNX Runtime React Native CI Pipeline", - ] + pipelines = get_pipeline_names() # remove pipelines that have already run successfully pipelines = [p for p in pipelines if p not in checked_pipelines] + print("Pipelines to run:") + for p in pipelines: + print("\t" + p) + # azp run is limited to 10 pipelines at a time max_pipelines_per_comment = 10 start = 0 num_pipelines = len(pipelines) + print("Adding azp run commands") while start < num_pipelines: end = start + max_pipelines_per_comment if end > num_pipelines: From 9ccdc4961ad76355289ed3a36ccb8307e8dc7789 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 23 Feb 2024 22:31:57 -0800 Subject: [PATCH 142/207] Stop using apiset in OneCore build: use onecoreuap.lib instead of onecoreuap_apiset.lib (#19632) ### Description Stop using apiset in OneCore build: use onecoreuap.lib instead of onecoreuap_apiset.lib in onecore build. ### Motivation and Context 1. Now all Windows Editions come with Reverse Forwarders. We should just use the normal onecore libs. 2. Many new Windows APIs are only available in [windows umbrella libraries](https://learn.microsoft.com/en-us/windows/win32/apiindex/windows-umbrella-libraries). So these libraries are not specific for Windows CoreOS or Onecore. 3. Going forward we should use "IsApiSetImplemented" to guard our API usages: https://learn.microsoft.com/en-us/windows/win32/apiindex/detect-api-set-availability . After this change, our built binaries can pass apivalidator's check. ``` C:\local\apivalidator>apivalidator.exe -BinaryPath:C:\src\onnxruntime\b\Debug\Debug\onnxruntime.dll -SupportedApiXmlFiles:onecoreuap_DDIs.xml ApiValidation: Summary: "C:\src\onnxruntime\b\Debug\Debug\onnxruntime.dll" is Universal ApiValidation: All binaries are Universal ``` So it will give an easy way to test ONNX Runtime's compatibility to Windows versions. --- cmake/CMakeLists.txt | 6 ++---- cmake/wcos_rules_override.cmake | 4 ++-- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index c9be4aa65d0cc..ed9043f2adc4a 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1729,14 +1729,12 @@ if(onnxruntime_BUILD_KERNEL_EXPLORER) endif() # When GDK_PLATFORM is set then WINAPI_FAMILY is defined in gdk_toolchain.cmake (along with other relevant flags/definitions). -if (WIN32 AND NOT GDK_PLATFORM) +if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) if (NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) # On onecore, link to the onecore build of the MSVC runtime get_filename_component(msvc_path "${CMAKE_C_COMPILER}/../../../.." 
ABSOLUTE) link_directories(BEFORE "${msvc_path}/lib/onecore/${onnxruntime_target_platform}") - # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, which in turn links to reverse forwarders. - # We ignore that entry and use onecore_apiset.lib instead, since system components must not rely on reverse forwarders. - add_link_options("/NODEFAULTLIB:onecore.lib") + # The .lib files in the MSVC runtime have a DEFAULITLIB entry for onecore.lib, but it shold not cause any conflict with onecoreuap.lib endif() endif() diff --git a/cmake/wcos_rules_override.cmake b/cmake/wcos_rules_override.cmake index f3d8093629a42..ec2303b073d5e 100644 --- a/cmake/wcos_rules_override.cmake +++ b/cmake/wcos_rules_override.cmake @@ -1,2 +1,2 @@ -set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) -set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap_apiset.lib) +set(CMAKE_C_STANDARD_LIBRARIES_INIT onecoreuap.lib) +set(CMAKE_CXX_STANDARD_LIBRARIES_INIT onecoreuap.lib) From 0edb03580823c9d9e97ba1a6ea941fcd70a2500b Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Sat, 24 Feb 2024 10:09:07 -0800 Subject: [PATCH 143/207] [js/web] fix suite test list for zero sized tensor (#19638) ### Description Fixes build break brought by #19614 Currently WebGL backend does not support zero sized tensor. This change split test data into 2 parts, and only enable zero sized tensor tests for WebGPU. --- js/web/test/data/ops/add.jsonc | 22 - js/web/test/data/ops/add_zero-sized.jsonc | 31 + js/web/test/data/ops/concat_zero-sized.jsonc | 561 +++++++++++++++++++ js/web/test/suite-test-list.jsonc | 2 + 4 files changed, 594 insertions(+), 22 deletions(-) create mode 100644 js/web/test/data/ops/add_zero-sized.jsonc create mode 100644 js/web/test/data/ops/concat_zero-sized.jsonc diff --git a/js/web/test/data/ops/add.jsonc b/js/web/test/data/ops/add.jsonc index dd15134861ef0..e5b4ff2b53148 100644 --- a/js/web/test/data/ops/add.jsonc +++ b/js/web/test/data/ops/add.jsonc @@ -157,28 +157,6 @@ "type": "float32" } ] - }, - { - "name": "T[2,0] T[2,1]", - "inputs": [ - { - "data": [], - "dims": [2, 0], - "type": "float32" - }, - { - "data": [1, 2], - "dims": [2, 1], - "type": "float32" - } - ], - "outputs": [ - { - "data": [], - "dims": [2, 0], - "type": "float32" - } - ] } ] } diff --git a/js/web/test/data/ops/add_zero-sized.jsonc b/js/web/test/data/ops/add_zero-sized.jsonc new file mode 100644 index 0000000000000..37e08cd7f20ac --- /dev/null +++ b/js/web/test/data/ops/add_zero-sized.jsonc @@ -0,0 +1,31 @@ +[ + { + "name": "Add with no attributes", + "operator": "Add", + "attributes": [], + "cases": [ + { + "name": "T[2,0] T[2,1]", + "inputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [], + "dims": [2, 0], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/concat_zero-sized.jsonc b/js/web/test/data/ops/concat_zero-sized.jsonc new file mode 100644 index 0000000000000..7be8e8c1cc602 --- /dev/null +++ b/js/web/test/data/ops/concat_zero-sized.jsonc @@ -0,0 +1,561 @@ +[ + { + "name": "Concat 2D axis=0", + "operator": "Concat", + "attributes": [{ "name": "axis", "data": -2, "type": "int" }], + "cases": [ + { + "name": "X", + "inputs": [ + { + "data": [], + "dims": [1, 4, 0, 64], + "type": "float32" + }, + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 + ], + "dims": [1, 4, 36, 64], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 88555a27be82e..e96a0aa045bc8 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1334,6 +1334,7 @@ "acos.jsonc", "add.jsonc", "add_int32.jsonc", + "add_zero-sized.jsonc", //"and.jsonc", "asin.jsonc", "attention.jsonc", @@ -1343,6 +1344,7 @@ "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", + "concat_zero-sized.jsonc", "cast.jsonc", "conv.jsonc", "cos.jsonc", 
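For context, the zero-sized test cases registered above exercise operators where one input has a dimension of size 0; the identical [1, 4, 36, 64] input and output blocks are consistent with concatenating a zero-sized tensor onto a non-empty one, which leaves the non-empty data unchanged. A minimal sketch of such a case in the same JSONC op-test layout as the files added above (the exact key names are assumed from that layout, and the values are illustrative rather than taken from the actual test files):

```jsonc
// Illustrative sketch only: a tiny Concat case with one zero-sized input.
[
  {
    "name": "Concat with a zero-sized input",
    "operator": "Concat",
    "attributes": [{ "name": "axis", "data": 1, "type": "int" }],
    "cases": [
      {
        "name": "concat [1,0] with [1,2] along axis 1",
        "inputs": [
          { "data": [], "dims": [1, 0], "type": "float32" },
          { "data": [2, 2], "dims": [1, 2], "type": "float32" }
        ],
        "outputs": [
          // The zero-sized operand contributes no elements, so the output equals the non-empty input.
          { "data": [2, 2], "dims": [1, 2], "type": "float32" }
        ]
      }
    ]
  }
]
```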
From c980149c857facc2463668a11944af3c6c12365b Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sun, 25 Feb 2024 05:00:53 +0800 Subject: [PATCH 144/207] Add log for random exception in Linux GPU Test Stage. (#19569) ### Description 1. check GPU status in docker 2. use stages to make test stage can leverage existing building artifacts ### Motivation and Context To investigate the root cause of the random exception `CUDA failure 100: no CUDA-capable device is detected` --- .../azure-pipelines/linux-gpu-ci-pipeline.yml | 351 ++++++++++-------- 1 file changed, 198 insertions(+), 153 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 24319184dd0b8..822bc559d992d 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -34,6 +34,17 @@ parameters: values: - 11.8 - 12.2 + + - name: SpecificArtifact + displayName: Use Specific Artifact + type: boolean + default: false + + - name: BuildId + displayName: Specific Artifact's BuildId + type: string + default: '0' + resources: repositories: - repository: manylinux @@ -61,163 +72,197 @@ variables: ${{ if eq(parameters.CudaVersion, '12.2') }}: value: 'onnxruntimecuda12build' -jobs: -- job: Linux_Build - timeoutInMinutes: 120 - variables: - skipComponentGovernanceDetection: true - CCACHE_DIR: $(Pipeline.Workspace)/ccache - workspace: - clean: all - pool: onnxruntime-Ubuntu2204-AMD-CPU - - steps: - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() - - - checkout: self - clean: true - submodules: none - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: " - --network=host - --build-arg BASEIMAGE=$(docker_base_image) - --build-arg TRT_VERSION=$(linux_trt_version) - --build-arg BUILD_UID=$( id -u ) - " - Repository: $(Repository) - - - task: Cache@2 - inputs: - key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"' - path: $(CCACHE_DIR) - restoreKeys: | - "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)" - "ccache" - cacheHitVar: CACHE_RESTORED - displayName: Cach Task - - - script: | - sudo mkdir -p $(Pipeline.Workspace)/ccache - condition: ne(variables.CACHE_RESTORED, 'true') - displayName: Create Cache Dir - - - script: | - set -e -x - mkdir -p $HOME/.onnx - docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ - --volume /data/onnx:/data/onnx:ro \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume $(Pipeline.Workspace)/ccache:/cache \ - -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ - -e NIGHTLY_BUILD \ - -e BUILD_BUILDNUMBER \ - -e CCACHE_DIR=/cache \ - $(Repository) \ - /bin/bash -c " - set -ex; \ - env; \ - ccache -s; \ - /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ - 
--config Release --update --build \ - --skip_submodule_sync \ - --build_shared_lib \ - --parallel --use_binskim_compliant_compile_flags \ - --build_wheel \ - --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \ - --enable_cuda_profiling --enable_cuda_nhwc_ops \ - --enable_pybind --build_java \ - --use_cache \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ - ccache -sv; \ - ccache -z" - workingDirectory: $(Build.SourcesDirectory) - displayName: Build Onnxruntime - - - task: CmdLine@2 - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - rm -f $(Build.BinariesDirectory)/Release/models - find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete - cd $(Build.BinariesDirectory)/Release - find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt - - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test - timeoutInMinutes: 180 - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: onnxruntime-Linux-GPU-A10 - dependsOn: - - Linux_Build - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - buildType: 'current' - artifactName: 'drop-linux' - targetPath: '$(Build.BinariesDirectory)/Release' - - - checkout: self - clean: true - submodules: none - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - Context: tools/ci_build/github/linux/docker - DockerBuildArgs: " - --network=host - --build-arg BASEIMAGE=$(docker_base_image) - --build-arg TRT_VERSION=$(linux_trt_version) - --build-arg BUILD_UID=$( id -u ) - " - Repository: $(Repository) - - - task: CmdLine@2 - inputs: - script: | +stages: +- stage: Linux_Build + jobs: + - job: Linux_Build + timeoutInMinutes: 120 + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Ubuntu2204-AMD-CPU + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: $(Repository) + + - task: Cache@2 + inputs: + key: '"ccache" | "${{parameters.CudaVersion}}" |"$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + restoreKeys: | + "ccache" | "${{parameters.CudaVersion}}" | "$(Build.SourceBranch)" + "ccache" + cacheHitVar: CACHE_RESTORED + displayName: Cach Task + + - script: | + sudo mkdir -p $(Pipeline.Workspace)/ccache + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - script: | set -e -x mkdir -p $HOME/.onnx - docker run --gpus all --rm \ - --volume 
$(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory)/Release:/build/Release \ + docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ - --volume /data/onnx:/data/onnx \ - -e NVIDIA_TF32_OVERRIDE=0 \ + --volume $(Pipeline.Workspace)/ccache:/cache \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + -e CCACHE_DIR=/cache \ $(Repository) \ /bin/bash -c " set -ex; \ - cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \ - ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ - /tmp/python3 -m pip install -r /tmp/requirements.txt; \ - /tmp/python3 -m pip install /build/Release/dist/*.whl; \ - cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \ - cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \ - cd /tmp; \ - /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --config Release --test --skip_submodule_sync --build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \ - --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ - --enable_pybind --build_java --ctest_path '' " - - - template: templates/clean-agent-build-directory-step.yml + env; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release --update --build \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel --use_binskim_compliant_compile_flags \ + --build_wheel \ + --enable_onnx_tests --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda-${{parameters.CudaVersion}} --cudnn_home=/usr/local/cuda-${{parameters.CudaVersion}} \ + --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_pybind --build_java \ + --use_cache \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ + ccache -sv; \ + ccache -z" + workingDirectory: $(Build.SourcesDirectory) + displayName: Build Onnxruntime + + - task: CmdLine@2 + inputs: + script: | + rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 + rm -f $(Build.BinariesDirectory)/Release/models + find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' 
-delete + cd $(Build.BinariesDirectory)/Release + find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline Artifact' + inputs: + artifactName: 'drop-linux' + targetPath: '$(Build.BinariesDirectory)/Release' + + - template: templates/explicitly-defined-final-tasks.yml + +- stage: Linux_Test + dependsOn: + - Linux_Build + jobs: + - job: Linux_Test + timeoutInMinutes: 180 + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: onnxruntime-Linux-GPU-A10 + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + ArtifactName: 'drop-linux' + StepName: 'Download Pipeline Artifact - Linux Build' + TargetPath: '$(Build.BinariesDirectory)/Release' + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: $(Repository) + + - task: CmdLine@2 + inputs: + script: | + set -e -x + mkdir -p $HOME/.onnx + docker run --gpus all --rm \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory)/Release:/build/Release \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ + $(Repository) \ + /bin/bash -c ' + nvidia-smi; \ + /sbin/ldconfig -N -v $(sed "s/:/ /" <<< $LD_LIBRARY_PATH) 2>/dev/null | grep -E "libcudart.so|libcudnn.so|libnvinfer.so"; \ + cat /usr/local/cuda/include/cuda.h | grep -m1 CUDA_VERSION; \ + cat /usr/include/cudnn_version.h | grep CUDNN_MAJOR -m1 -A 2; \ + ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ + /tmp/python3 -m pip install /build/Release/dist/*.whl; \ + /tmp/python3 -u -c "from onnxruntime.capi._pybind_state import (OrtDevice as C_OrtDevice) ; \ + ort_device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0); \ + print(ort_device); print(ort_device.device_type(), C_OrtDevice.cuda()); \ + assert(ort_device.device_type()==1); assert(C_OrtDevice.cuda()==1);" \ + ' + displayName: 'Check GPU' + + - task: CmdLine@2 + inputs: + script: | + set -e -x + mkdir -p $HOME/.onnx + docker run --gpus all --rm \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory)/Release:/build/Release \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ + $(Repository) \ + /bin/bash -c ' + set -ex; \ + cp /onnxruntime_src/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt /tmp/requirements.txt; \ + ln -s /opt/python/cp38-cp38/bin/python3 /tmp/python3; \ + /tmp/python3 -m pip install -r /tmp/requirements.txt; \ + /tmp/python3 -m pip install /build/Release/dist/*.whl; \ + cd /build/Release && xargs -a /build/Release/perms.txt chmod a+x; \ + cd /onnxruntime_src/java && /onnxruntime_src/java/gradlew cmakeCheck -DcmakeBuildDir=/build/Release -DUSE_CUDA=1; \ + cd /tmp; \ + /tmp/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --config Release --test --skip_submodule_sync 
--build_shared_lib --parallel --use_binskim_compliant_compile_flags --build_wheel --enable_onnx_tests \ + --use_cuda --cuda_version=${{parameters.CudaVersion}} --cuda_home=/usr/local/cuda --cudnn_home=/usr/local/cuda \ + --enable_pybind --build_java --ctest_path "" ; \ + ' + displayName: 'Run Tests' + + - template: templates/clean-agent-build-directory-step.yml From 0fcc6fb7601893bd1e2b53baea4436a7a51b7f8d Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sun, 25 Feb 2024 14:04:22 +0800 Subject: [PATCH 145/207] Add Whisper model in CI (#19604) ### Description Add Whisper Conversion and E2E into Big Models pipeline ### Motivation and Context --------- Co-authored-by: Your Name Co-authored-by: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> --- .../tools/transformers/benchmark_helper.py | 4 +- .../transformers/models/whisper/benchmark.py | 3 +- .../models/whisper/requirements.txt | 5 +- .../models/whisper/test/1272-141231-0002.mp3 | Bin 0 -> 92124 bytes .../whisper/test/whisper_ort_output.txt | 1 + .../azure-pipelines/bigmodels-ci-pipeline.yml | 101 +++++++++++++++++- .../docker/Dockerfile.package_ubuntu_2004_gpu | 9 +- 7 files changed, 115 insertions(+), 8 deletions(-) create mode 100644 onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 create mode 100644 onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index c7d93470a729e..c9c815f01e053 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -589,7 +589,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): if max_usage is None: return None - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") + logger.info(f"GPU memory usage: before={memory_before_test} peak={max_usage}") if len(memory_before_test) >= 1 and len(max_usage) >= 1 and len(memory_before_test) == len(max_usage): # When there are multiple GPUs, we will check the one with maximum usage. 
max_used = 0 @@ -620,7 +620,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): monitor.keep_measuring = False max_usage = mem_thread.result() - print(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB") + logger.info(f"CPU memory usage: before={memory_before_test:.1f} MB, peak={max_usage:.1f} MB") return max_usage - memory_before_test diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index e57385aa6db8f..11e596cadc2cb 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -410,7 +410,8 @@ def handle_output(output): actual_output = handle_output(ort_outputs[0][0]) logger.info(f"Generated token length: {len(actual_output)} tokens") transcription = args.processor.batch_decode(ort_outputs[0], skip_special_tokens=True)[0] - logger.info(f"Transcription: {transcription}") + # print to stdout as the output for comparison + print(f"{transcription}") measure_fn(args, generate_fn, ort_inputs) diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index c307a3665f8a0..956922dc83d51 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -8,4 +8,7 @@ librosa optimum onnxruntime-extensions>=0.9.0 protobuf==3.20.2 -numpy==1.23.3 \ No newline at end of file +numpy==1.23.3 +onnx>=1.15.0 +psutil +py3nvml diff --git a/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 b/onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..6d220f5ede6a7c54893b1dda32b7876c31059fcf GIT binary patch literal 92124 zcmce-^;?wh^FF-0G%QH7^wJGWr%E?ScP!lqh@{}sEse{9G)Q*{NQbm^igdTUKndZI z_xJex3(x+tdF(ycTr=mK*UWi8UPku={_kPr#qnOa)e*gHD8x_fxN@%0Z34hf5hj*b72oR)#g z&MPebTvl1r(A3)A(cL%jZFpjGdTw#~$J*xh{?Y0Af7ds+_YeQBr=>2Vp&-c14}+m= z{hujh*vx223;^KzlO2C2YX<-Ci~paWAD{d~AdTs4@$bUGGEhrU%fcZx#YZn(Qjbp# z+bpz%_%8`9uqmFMXfc%TZQ{JN&JHT@?Rj>YE%-S9$np5dYx(#WhuLU$!KHdZL^`in zyZ7C+3^Y_PWm1iq^;?G6W-ByH+mMe;DMxXFqe@^?hqK5s!Nv4mr2g7ZxES5E0 z?NEjUr%XYGKhD))YNq4gr?urAXd82^MyNG4|}JMwWF8?Ce8$IF;>*A9Wv zBO*fem^5K%ByHKu9aDCr$v!t;j=LgvtUE{iet)?e_E=bS2%`Oe0prw0UGy!!{WV;B zyj1RArUxzCTd2f-)hDcZ8DI>^qE~qN`n8ljx37i7Y4b_mEw@6<+fN)MIlGA#cW%h< zZJI;zj_I8E<@(=FPnRDf{{AfYcHfykYFaS)YAdf)toqyCTm`*Wncy$PW&Klu_Ju6S z9HQ+wp*+YhCcy;_9bD3%R)mr<>fgPJNi88u0cSx};5So^i^B{gPxfrE9NC*#_$OT% zXE+f>NnwZW4#l?K)rTf7cYX5(W&ZbEqo@P2=X1CW6}4GTZQ9*`A22K5f59jEMlV#G zcvg2Vr5!R-GT-F(_WUUZ**_LR*A+WH$fP27Ic1t<2!`*Zz|myZ2o9R0)=7QI^U(!T_}CV5q4btAjC zuP@0_U1rYNF5i09lC3{NF6c9AdXX=vK3~y-TAs?+XdwI2IeFZDaFMExP7}Mo;f^1( zzs$o+>x^xj*%E3|4osEbvdpiKNk}s~bax8QkKS8MCi@b|sZ`m>=H?wl<4}FU>$W3S z&@h46;EjJKCrI_m9@D`K97Ts?S=l`PdrjF^U+6jJZfqtCx+7DgnhM%Mv0>tiDx#i( zqiPwW+qv`RtyhwA11;unkN&!n+x^ax2$fL#6j~f6m!RZen4Lu`f3; zZl{&#eIwQM`>ws_efC?PK=&H|Y9szYXqP~>mGdgml`zk1_qfbAJa>^^GZ7AEcTGZi z7FpU_E4Rnf?Sq*71lR8?ncYEO-cX%Q&IIe}wST%nIX2t3U_P`XVzG~^nqRg29QyLX zPq0(FE!oE?rSz@Ql6rl3kTky@)-5}?L{T(h5v?HWsQAZi_beq5e$kHwI^I1^wp@;S zMbwNTar)7$8d6U-AG2?6T3c9u+xZ#Cp(4xuIgTOAqg97w@|6;$M4ECQL*j4&>3><# zK8JN}J~6KQmu1~5@jSC5tSnd(SFR#DzEMzxKS+yOwp^l{l5oTNW-X zS+!uOnZFe`QV|%=*owMdTs+N#*<$>dP*{K=6>uLJ8daXWW_0Q*{_lRUJtu?M=4Sp~)O-w=`*#AW=f^I>F;DS)#-UqoZ4scgs;#wUTf$?K 
[base85-encoded GIT binary patch data for onnxruntime/python/tools/transformers/models/whisper/test/1272-141231-0002.mp3 (literal 92124 bytes) omitted]
zNR@x)LgU*Vq6v!{5ub;E@kR@vW~zm`-$pm9U*{z8#u)^)nk!K;t_1p5kK4TMh**?j ziD=n1^_}_#F#S3YD1y~7q0%@|%>m!WyR){tT7d{NVcg=hVKaR5_vT<(NfZ2rTkr#h zuz$hk%OmQUH0r?3$b1@0l`uxtSzfu^IjtqBS;J(%rQ3!}j)fJUJ}T?WDZ>tNc_Qx* z?D3guf3estKg(A5(k_#!G(}jqkh;2sx_(V<&vg=777O<`eFo#XXp4azA3ly~M`r)s zwS0OsRR4tvNtwDOJM^+g*4(O@O$q3E^K_ZpHhmbs|9F^IBT|jcHRL+kU;{Kxdl0L(>(p={X1RjohI}_z`3?3b`lLUka;3`vBI?rVQ~5O z{8@ln%l)G{-FT-n2u|etzjePNffR-86H@%+?kc>aa%&O=^}9^QT3W6E^oE=Ebbk&+ zz9jEIEZ%_#8DX;RGSzCplNL*c=;=|m1^+?mxw4W-$M$)dDMv4NO6?q?Aef9A=OUE6 zLKGkoBWYA9FP#>MTQJ30#1$0f&*JLF#ePt6RT+E+OS~ost@i=V%0FxXkmeq14?mkG zN2-tEyT(n}!YD!Miv4h>SA za4~+k_Hr?_o|{{T1!177>zgmT{yV1gSkbF|-u6?^yc_5>tH9CCZoqxb7%P=*i^3t$N4Jx8Huj(%lO3u5tNQw}s&f7v8sg+#y2hfgI3ZG)*}g_v*pcj1KRK z-PJ1n`UfSgP&M7?fi%+8Qdhace0FWv=Yw55E7yM9O&IFzqod71_`8DIy#6+@ep}(E zCrxE9+BeHDFZ6i6j(wnqtpy>SO{E}|M14h&-`s48_5-A(pg4q}mOqTv0ww513X36? zZ;`yxCnfsq!zX7dL)_pT9pt{W$9Fmj${CzTgnFF~Jk3EyT$OcnJAR`Y;_^c8w+RM( zoV&6vKb}}2Rm7c&trqHc)5ClbFdQdaCKYFS(bZfnF_gTgc)fx!7fdx)7N)}?r6!NU z3u0R|ggzm?D*&la?vZ$*bTD!fAMf#KsGro z-$Dr&Ko|p11BPXb>5z+4)T^8Jvf(~7S+$l(X=>)2&ry|Kkc8RRUm1|$+>5E*%_Zi| zQK`x_*JH|NCB=D};m%<>kV>V6@9$W~3)(xYEte$>gfmZ~fu>lUVnedhK{!>U_gEss zo;E{2EZfyOv&c7#^#S8sp@$!T@v;(syw0{Vsy2^R)t~{wtI<41**q@!_D%)UTg2#~ zGaG>x&i1+59wR2^C$0b1?k_Ljf?aw(91zdJR6Qj{L{mMWrc^{+gZw0f_oH|W@XKVw zc?3xb#iHp?mWC`>mfLJ+QPx{p0*>#X_F4v574hEr#5$+=d<9aMJE4<k5Jblr*?*#rj zt?T@wt-^u1P9#5Vb&l+^IXtS@wJjlnKXFqzmg};**Sl(;q&RxBm#^V-Jau!O%Qn#E z>jUlZ3Zq%hs9G{;LMTbZ!DR;Ghczt@AAo|Ez@zZzlF~gu5dIE<4I|DL4`KDEgQF@^a#p}WjLW>6s^?!y#LwU|dgdkOjMz40|Lfm_P6 z{OMJs!b)I?%O9bcAf;jo+0odoU>v;R82+~nMie1_;`XUyAVP!4mCi^|bE~L-M2G02 zc{}QpK+65#3dczex*xwfxa`w$cVGl&ms>hy$K&c$N$u7tWw#rq8qmpm1zIAJDotrf1rYr@V6{aY8n0*|NShC4baOF)|2tc<9 z2b5PJ3xbI)u6$zgM~36o+Z2dRWM3bRCD=lFbwIajQ{hYB@nFfu6Dzq{g{Wd%%ca~` zx`-Q#z%!G@xOjX}?3U^m)-N*yQZDd2@w)N#z2$S`2B&F;L1bSH?^6ww!1VXb%RCVTDmz{HRGJ1SB!VWGJzCJf- zQhl~weHbwIMR>tCd!bw~K7~DtSi%&@!u;1Hx778%C4|a`PM5CneJ$oxkY{tI z!q1Rd%bY>(KD`eJ!lAnq{`&0#zhy^6SY9F$_Ih!y_Kbum1U*p67IYKL_y|ZATT@M# z2Wsa{T8`r2&>t^8s}Bwi(Jnhuw~f+!#4b=cFQ2Hhk8o~4HcgMb9plflMatvJ&c7D( z{TJ)L34~IV+@gWk!rL>abgn+ld-=$v!k-tv-ZW98H)Ij3>#Z^;$O6SewJWUOZ3`L?>dj*of!*H$3nA4J5ei-=>!yEDjF z&6D@ueOx$u=-QR58VenI_`_c5Mov1&KJ%^Urb7f$2w?J!FT^I5kup|cdYZ@i7(Qwn zt^~phCDvj<_DvK~G|S1TbF-vdT07n0=@)gFI1> z?+q;@A@{DcyHYahCf^uNj}}?gGp4e3H+BT9e5)QAK!XIM5Hn{e{_g#?okwnY$jjU7 z@27;$ij`S#2kp?@4A1p%cONjhPXQDHSUqwH8qRYC`plj*uH`FY)}B1;N_AJi$bl|_d|Nns{AL+<)O5dAZP z{Qy5(G}&r9tj4mE2`pMNKAHMiPE8^O+zif6EORPLHc5rMaE^OC7c`1aa6+rc>$`TB zOt;2Gi%zrcigojUQ{-wKeE%MFzIOoD{mCtx@~V5vweHR2Fq|Z$K5I^4_xDnSg2>Zf zEP89Dogq9-;vH_Aiq7q;k8ZopsVQnNeMqH$v0ivM+h{&8KEl3vpgU&=_;TuH#z_|& z>OhTSry@Y+MMZ?;a;Pv$426rM6v+1JyLImrZ4Uw;2S+84IyfXNf?YXAFdDoAfbAZoC#>0c7K(cV|T|HkFOU2Vs`QRil6IUz%-C%X5HK@FEdj3g7Esfa<8RuHgtX9z!a zpd(vSs4_xSbMj@=9LB*b_q=@mkamr&nLBwX;0^>+E4ScMILKzt)+s}rNdb<+?KLE) zaKS|5vmxy8^PB@s{pdls%Dp0rZ&$yqRsGKGvv7F)SYd5<`Sr}$_xGEv9lU|;lz@XE zR<7RkOds7WCYj(6RRMI0i2C*;O2>PkZDF;Prv@e@5qN3ij(ShFB}mSt=EYkzw|Rq~ zKZxA+)1L^>Dg+!#?sBH){sVcBX)!Y< z#tUd#uf_0G$Xd7v{P-)f>gvhKE6Gpl3iGkeG=@UYhmoQvWHTC6^>JYGWZ+zK0aJ&a zVP1LC^uAWRexfT*qV`_1TUfZXB*HQRyGPSvb56qjHsM|V^5B9i?Z1Dgo!^?G_U3MX zs#`wdjQlJS<>u^aq=CL5Zx0`+%ar=??YK(HfKmThnIC0J!|4~fFzpcz^ObYOv3HlU z&Un$~1RRM}GwCDkJVTew4eA3IPJ#f5Bbj(Dy|i=>Jfiei2Yg9KXc}PkKKpaM@82&NzgVTq02(8_^Ic&BLoiL z15;5S?t{!|;^(D4C713~2};WE0A5Ga>D$Q(nRkw&rLzv5J(mD%J2evM&m3$2<(Yh# ze|%OEuf@Xc%TBk+pTXFpl+nyV4ALNcM*^e0oKp^$vg%CJ`U?`eWAn`H*<`cLbcOjW zd#q<$rMa`Fva7pY^QjB_u|APzeD;cSmG#fbv=@dcY$_;e!$vvexk!p2rZK&TJ%!{^ zo`$7TZiOKS{_#Cg{Pn9ZxGcCczWe)v%pp+5(nT%fmZ?g!V5h?jtTg#_Aj6mTf)! 
zCw4bvmNgzL4J?*=c6N7$H1XzH(YH?~{m6&*$ED@zEsKmU{7sZFx&76(tmiuuqw(_o zxco7F>&bc3O5<&#FV_MO_sXJ`1zo|Roz6mvKSb!+js$7<)iJr=XrH}BG%i0g@M~#8 zvg=AU%zlFTIY1+dL=LQ8p=Kn@~WlpnZx-RTJcY#VPYjYvPSSq6O)8wlJ|56|Ort>}REgCsinyfSCW$%P3dcnPy} zj4Y$x73S8yEcs{i4xiin9lwpvsuj^HMan8&AtiIR`x-0dlAN(DfzOm`X$WW5Jw>J>lJ<^_5*ZOqPZCq8N^o9X_B;|8>T;(56S(hOd0?@N zQyKRN_6&3e&Rmo#tiWl<&QnsJJ@XY0!k|W^4DL+^iR>kjNqG9ZAV((xaKvqfXLWJn z3j4-?zQK1D?qF-5XJ5|OVW&y#5+vN@5d%r z%XrB8oE*k|@xI+3=p3U9oA(}OYZ|K@zTX@t%!H)JRtY8)4^}7X8k2sG7CXB?>q_L7)HsQPoZBD7EX2L9{b%zr zB<8Q#UEDZ4bdIIV{#1Q?mkSHoShmFOR-D(YiwNW|YY!0=BbmwA6=~dMuDFaeE=&q? z!V1$s6K2 zB~KTTrEEMw;cHvFWsj?y?kg=n@}ab+W8`q3`_XAYU8tBl=i%Yj-;4u$ z6I!qAVRag1QL%!1x%&2+1#kY@ybn3=w&=D`+N_Z>@Eq&(|J9|2+A*QiQx!$Jk7N^)}+D|x(Ji0 zg+@F-be7KlwYQ$R`!MUM>~=eTzoR1O@JDrArxN)Iu{QJh&FRHr{m#2qL9s#IA_$0u z{#YPYkae`oTg@)LV4jPIOepKYuOnUIbZxD;jlGG07s8(#_=Q-W{O226`~qA5{299c zIP+Y*pJpyi-WIi1v+Dl^0k5}Y`P0!2{Jav^1u1y@%nEMqmc4G<{F^oL&9(D1<)3Ef zgWNpr%Qru6TawRycbJYWG6XGfI|8)aJ!4I0Q-elcbokEF$wWrj;Gc`Wyq8YW=bU7T z+)7o|!S7&Mx~W?lY$fUab0Am6Ml$3&Tkm%iK7Bu-kg>tWKALkSQ7le8 zWZ>OAop)wd)}{J%EyPgS{N@;v0*-1Y90KMXq^jE0xOhL zs`W_Y;41eklQun-hmp>UO$=clqcv7gaEFg=bUOw8njCop(YnQz@RvnP7!_nX6iYdkbyvpe6A{69V6|NkcJIYa*e4i4Tm&fWN5 GUjIMKg6h=( literal 0 HcmV?d00001 diff --git a/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt b/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt new file mode 100644 index 0000000000000..e3dbef248d0b2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt @@ -0,0 +1 @@ + the cut on his chest still dripping blood the ache of his overstrained eyes even the soaring arena around him with the thousands of spectators were trivialities not worth thinking about diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml index 65866fc9827a5..43dedbc394c38 100644 --- a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -314,12 +314,111 @@ stages: pushd /workspace/onnxruntime/python/tools/transformers/ ; \ python3 -m pip install --upgrade pip ; \ pushd models/llama ; \ - python3 -m pip install -r requirements-cuda.txt ; \ + python3 -m pip install -r requirements.txt ; \ popd ; \ python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip uninstall -y torch ; \ python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda --input /meta-llama2 --small_gpu ;\ popd ; \ " displayName: 'Run Llama2 to Onnx F16 and parity Test' workingDirectory: $(Build.SourcesDirectory) + +- stage: Whisper_ONNX + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Whisper_ONNX + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: Onnxruntime-Linux-A10-24G + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/ort-artifact/' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: 
templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu + Context: tools/ci_build/github/linux/docker/ + ScriptName: tools/ci_build/get_docker_image.py + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimepackagestest + UpdateDepsTxt: false + + - task: DownloadPackage@1 + # The model data in artifact is downloaded from openai/whisper-large-v3 in huggingface model hub + # In order to save size, removed .git directory and pickled files, and keep the safetensors model files + displayName: 'Download Whisper Model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: 'b583ce7c-1a8f-4099-ae28-5d5f56c478b1' + downloadPath: $(Agent.TempDirectory)/whisper_large_v3 + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/whisper_large_v3:/whisper_large_v3 \ + onnxruntimepackagestest \ + bash -c ' + set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ + python3 -m pip install --upgrade pip ; \ + pushd models/whisper ; \ + python3 -m pip install -r requirements.txt ; \ + popd ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip uninstall -y torch ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + python3 -m models.whisper.convert_to_onnx -m /whisper_large_v3 --output whisperlargev3 --use_external_data_format ; \ + popd ; \ + ' + displayName: 'Convert Whisper Model' + workingDirectory: $(Build.SourcesDirectory) + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory):/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/whisper_large_v3:/whisper_large_v3 \ + onnxruntimepackagestest \ + bash -c ' + set -ex; \ + pushd /workspace/onnxruntime/python/tools/transformers/ ; \ + python3 -m pip install --upgrade pip ; \ + pushd models/whisper ; \ + python3 -m pip install -r requirements.txt ; \ + popd ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip uninstall -y torch ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + ls whisperlargev3; \ + python3 -m models.whisper.benchmark \ + --benchmark-type ort \ + --audio-path models/whisper/test/1272-141231-0002.mp3 \ + --model-name openai/whisper-large-v3 \ + --ort-model-path /workspace/onnxruntime/python/tools/transformers/whisperlargev3/whisper_large_v3_beamsearch.onnx \ + --precision fp32 \ + --device cuda > ort_output.txt ; \ + cat ort_output.txt ; \ + diff ort_output.txt /workspace/onnxruntime/python/tools/transformers/models/whisper/test/whisper_ort_output.txt && exit 0 || exit 1 + popd ; \ + ' + displayName: 'Test Whisper ONNX Model' + workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu index 9b9dc9ecae822..c9038afc0954c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2004_gpu @@ -16,15 +16,18 @@ ENV DEBIAN_FRONTEND=noninteractive ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG}:${LD_LIBRARY_PATH} RUN apt-get update &&\ - apt-get install -y git bash wget + apt-get install -y git bash wget diffutils # Install python3 RUN apt-get install -y 
--no-install-recommends \ python3 \ python3-pip \ python3-dev \ - python3-wheel - + python3-wheel + +# Install ffmpeg, which couldn't be installed in UBI8 +# https://stackoverflow.com/questions/73597789/how-to-install-ffmpeg-on-ubi-docker-images +RUN apt-get install -y --no-install-recommends ffmpeg RUN pip install --upgrade pip From 430a086f22684ad0020819dc3e7712f36fe9f016 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Sun, 25 Feb 2024 08:50:45 -0800 Subject: [PATCH 146/207] fix memory mapping on Windows (#19623) ### Description Windows memory map casts mapped_offset to DWORD directly. It will be truncated if it is larger than 2^32-1. We need to set the high 32 bits via dwFileOffsetHigh in this case. ### Motivation and Context The bug was found from #19450 --- onnxruntime/core/platform/windows/env.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 0eb34cbfbc9eb..983cc6089bb4c 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -459,8 +459,8 @@ Status WindowsEnv::MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, void* const mapped_base = MapViewOfFile(file_mapping_handle.get(), FILE_MAP_READ, - 0, - static_cast<DWORD>(mapped_offset), + static_cast<DWORD>((mapped_offset >> 32) & 0xFFFFFFFF), + static_cast<DWORD>(mapped_offset & 0xFFFFFFFF), mapped_length); GSL_SUPPRESS(r.11) mapped_memory = From a9568935a52b3d51ec802a4ab89ab3852129fc1e Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Mon, 26 Feb 2024 11:35:13 -0800 Subject: [PATCH 147/207] [DML EP] Enable DML Graph Serialization (#19505) ### Description This PR adds a feature to serialize all DML EP partitions into DML currency individually for a given model. This feature can be dynamically turned on by using the DML EP option `ep.dml.enable_graph_serialization`. ### Motivation and Context - Why is this change required? What problem does it solve? Useful when a user wants to capture the DML EP specific partitions into DML currency to mitigate the dependency on the framework.
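For reference, a minimal sketch of turning the new option on from the onnxruntime Python API, assuming the standard session-config mechanism; the model path, provider list, and the value `"1"` are illustrative assumptions rather than part of this change:

```python
import onnxruntime as ort

# Sketch: request DML graph serialization for this session.
# "ep.dml.enable_graph_serialization" is the config key introduced by this PR;
# "model.onnx" and the value "1" are placeholders for illustration.
so = ort.SessionOptions()
so.add_session_config_entry("ep.dml.enable_graph_serialization", "1")
session = ort.InferenceSession("model.onnx", sess_options=so, providers=["DmlExecutionProvider"])
```

With this entry set, each DML EP partition of the model is expected to be serialized out individually when the session is created.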
--- .../inc/IWinmlExecutionProvider.h | 7 +- .../DmlExecutionProvider/src/ApiTraits.cpp | 570 +++++++ .../src/DmlGraphDeserialization.cpp | 554 +++++++ .../src/DmlGraphFusionHelper.cpp | 247 ++- .../src/DmlGraphFusionHelper.h | 19 +- .../src/DmlGraphFusionTransformer.cpp | 41 +- .../src/DmlGraphFusionTransformer.h | 4 +- .../src/DmlGraphSerialization.cpp | 580 ++++++++ .../src/DmlRuntimeFusedGraphKernel.cpp | 30 +- .../src/External/DirectMLHelpers/ApiTraits.h | 453 +++++- .../External/DirectMLHelpers/DirectMLSchema.h | 112 +- .../DirectMLHelpers/DmlGraphDesc_generated.h | 788 ++++++++++ .../DirectMLHelpers/DmlGraphDeserialization.h | 14 + .../DirectMLHelpers/DmlGraphSerialization.h | 8 + .../DirectMLHelpers/DmlSerializedGraphDesc.h | 73 + .../DirectMLHelpers/GeneratedSchemaHelpers.h | 92 +- .../DirectMLHelpers/GeneratedSchemaTypes.h | 32 +- .../OperatorFieldTypes_generated.h | 1318 +++++++++++++++++ .../External/DirectMLHelpers/SchemaHelpers.h | 54 +- .../src/GraphDescBuilder.cpp | 404 ++--- .../src/GraphDescBuilder.h | 21 +- .../src/MLOperatorAuthorImpl.cpp | 30 +- .../src/Operators/DmlOperator.cpp | 4 +- .../src/Operators/DmlOperatorAttention.cpp | 2 +- .../src/Operators/DmlOperatorBiasAdd.cpp | 2 +- .../Operators/DmlOperatorBiasSplitGelu.cpp | 2 +- .../DmlOperatorEmbedLayerNormalization.cpp | 2 +- .../src/Operators/DmlOperatorGroupNorm.cpp | 2 +- .../DmlOperatorLayerNormalization.cpp | 2 +- .../Operators/DmlOperatorQLinearConcat.cpp | 2 +- .../Operators/DmlOperatorQLinearSigmoid.cpp | 2 +- .../src/Operators/DmlOperatorQuickGelu.cpp | 2 +- .../Operators/DmlOperatorRotaryEmbedding.cpp | 2 +- .../DmlOperatorSkipLayerNormalization.cpp | 2 +- .../dml/DmlExecutionProvider/src/Utility.h | 141 ++ .../dml/DmlExecutionProvider/src/precomp.h | 7 + .../MLOperatorAuthorPrivate.h | 11 +- .../dml/dml_session_options_config_keys.h | 1 + onnxruntime/core/session/inference_session.cc | 9 +- .../test/perftest/command_args_parser.cc | 1 + onnxruntime/test/perftest/ort_test_session.cc | 10 + 41 files changed, 5203 insertions(+), 454 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index f29cc3afc3cda..88e3dd487d427 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -80,15 +80,10 @@ namespace Windows::AI::MachineLearning::Adapter }; // This is 
the counterpart to the MLOperatorGraphDesc ABI struct which owns its memory and uses containers. - // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { uint32_t nodeCount = 0; - std::vector> nodesAsOperatorDesc; - - // TODO (jeffbloo): Remove this - std::vector> nodesAsIDMLOperator; - + std::vector> nodes; std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp new file mode 100644 index 0000000000000..bf9800458102b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp @@ -0,0 +1,570 @@ +//--------------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// This file is automatically generated. Please do not edit it directly. +// To modify this file, edit the schema: dml/Tools/DirectMLSchema.json +// And run this script to regenerate: dml/Tools/GenerateSchema.ps1 +// +// #dml-new-operator-location +//--------------------------------------------------------------------------- + +#pragma once + +#include "precomp.h" + +template +T ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. + static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_DATA_TYPE_UNKNOWN", DML_TENSOR_DATA_TYPE_UNKNOWN}, + {"DML_TENSOR_DATA_TYPE_FLOAT32", DML_TENSOR_DATA_TYPE_FLOAT32}, + {"DML_TENSOR_DATA_TYPE_FLOAT16", DML_TENSOR_DATA_TYPE_FLOAT16}, + {"DML_TENSOR_DATA_TYPE_UINT32", DML_TENSOR_DATA_TYPE_UINT32}, + {"DML_TENSOR_DATA_TYPE_UINT16", DML_TENSOR_DATA_TYPE_UINT16}, + {"DML_TENSOR_DATA_TYPE_UINT8", DML_TENSOR_DATA_TYPE_UINT8}, + {"DML_TENSOR_DATA_TYPE_INT32", DML_TENSOR_DATA_TYPE_INT32}, + {"DML_TENSOR_DATA_TYPE_INT16", DML_TENSOR_DATA_TYPE_INT16}, + {"DML_TENSOR_DATA_TYPE_INT8", DML_TENSOR_DATA_TYPE_INT8}, + {"DML_TENSOR_DATA_TYPE_FLOAT64", DML_TENSOR_DATA_TYPE_FLOAT64}, + {"DML_TENSOR_DATA_TYPE_UINT64", DML_TENSOR_DATA_TYPE_UINT64}, + {"DML_TENSOR_DATA_TYPE_INT64", DML_TENSOR_DATA_TYPE_INT64}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_TENSOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_TYPE_INVALID", DML_TENSOR_TYPE_INVALID}, + {"DML_TENSOR_TYPE_BUFFER", DML_TENSOR_TYPE_BUFFER}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_OPERATOR_INVALID", DML_OPERATOR_INVALID}, + {"DML_OPERATOR_ELEMENT_WISE_IDENTITY", DML_OPERATOR_ELEMENT_WISE_IDENTITY}, + {"DML_OPERATOR_ELEMENT_WISE_ABS", DML_OPERATOR_ELEMENT_WISE_ABS}, + {"DML_OPERATOR_ELEMENT_WISE_ACOS", DML_OPERATOR_ELEMENT_WISE_ACOS}, + {"DML_OPERATOR_ELEMENT_WISE_ADD", 
DML_OPERATOR_ELEMENT_WISE_ADD}, + {"DML_OPERATOR_ELEMENT_WISE_ASIN", DML_OPERATOR_ELEMENT_WISE_ASIN}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN", DML_OPERATOR_ELEMENT_WISE_ATAN}, + {"DML_OPERATOR_ELEMENT_WISE_CEIL", DML_OPERATOR_ELEMENT_WISE_CEIL}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP", DML_OPERATOR_ELEMENT_WISE_CLIP}, + {"DML_OPERATOR_ELEMENT_WISE_COS", DML_OPERATOR_ELEMENT_WISE_COS}, + {"DML_OPERATOR_ELEMENT_WISE_DIVIDE", DML_OPERATOR_ELEMENT_WISE_DIVIDE}, + {"DML_OPERATOR_ELEMENT_WISE_EXP", DML_OPERATOR_ELEMENT_WISE_EXP}, + {"DML_OPERATOR_ELEMENT_WISE_FLOOR", DML_OPERATOR_ELEMENT_WISE_FLOOR}, + {"DML_OPERATOR_ELEMENT_WISE_LOG", DML_OPERATOR_ELEMENT_WISE_LOG}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_MAX", DML_OPERATOR_ELEMENT_WISE_MAX}, + {"DML_OPERATOR_ELEMENT_WISE_MEAN", DML_OPERATOR_ELEMENT_WISE_MEAN}, + {"DML_OPERATOR_ELEMENT_WISE_MIN", DML_OPERATOR_ELEMENT_WISE_MIN}, + {"DML_OPERATOR_ELEMENT_WISE_MULTIPLY", DML_OPERATOR_ELEMENT_WISE_MULTIPLY}, + {"DML_OPERATOR_ELEMENT_WISE_POW", DML_OPERATOR_ELEMENT_WISE_POW}, + {"DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW}, + {"DML_OPERATOR_ELEMENT_WISE_RECIP", DML_OPERATOR_ELEMENT_WISE_RECIP}, + {"DML_OPERATOR_ELEMENT_WISE_SIN", DML_OPERATOR_ELEMENT_WISE_SIN}, + {"DML_OPERATOR_ELEMENT_WISE_SQRT", DML_OPERATOR_ELEMENT_WISE_SQRT}, + {"DML_OPERATOR_ELEMENT_WISE_SUBTRACT", DML_OPERATOR_ELEMENT_WISE_SUBTRACT}, + {"DML_OPERATOR_ELEMENT_WISE_TAN", DML_OPERATOR_ELEMENT_WISE_TAN}, + {"DML_OPERATOR_ELEMENT_WISE_THRESHOLD", DML_OPERATOR_ELEMENT_WISE_THRESHOLD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR}, + {"DML_OPERATOR_ACTIVATION_ELU", DML_OPERATOR_ACTIVATION_ELU}, + {"DML_OPERATOR_ACTIVATION_CELU", DML_OPERATOR_ACTIVATION_CELU}, + {"DML_OPERATOR_ACTIVATION_HARDMAX", DML_OPERATOR_ACTIVATION_HARDMAX}, + {"DML_OPERATOR_ACTIVATION_HARDMAX1", DML_OPERATOR_ACTIVATION_HARDMAX1}, + {"DML_OPERATOR_ACTIVATION_HARD_SIGMOID", DML_OPERATOR_ACTIVATION_HARD_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_IDENTITY", DML_OPERATOR_ACTIVATION_IDENTITY}, + {"DML_OPERATOR_ACTIVATION_LEAKY_RELU", DML_OPERATOR_ACTIVATION_LEAKY_RELU}, + {"DML_OPERATOR_ACTIVATION_LINEAR", DML_OPERATOR_ACTIVATION_LINEAR}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU}, + {"DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_RELU", 
DML_OPERATOR_ACTIVATION_RELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_ELU", DML_OPERATOR_ACTIVATION_SCALED_ELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_TANH", DML_OPERATOR_ACTIVATION_SCALED_TANH}, + {"DML_OPERATOR_ACTIVATION_SIGMOID", DML_OPERATOR_ACTIVATION_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX", DML_OPERATOR_ACTIVATION_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX1", DML_OPERATOR_ACTIVATION_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_SOFTPLUS", DML_OPERATOR_ACTIVATION_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_SOFTSIGN", DML_OPERATOR_ACTIVATION_SOFTSIGN}, + {"DML_OPERATOR_ACTIVATION_TANH", DML_OPERATOR_ACTIVATION_TANH}, + {"DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU}, + {"DML_OPERATOR_CONVOLUTION", DML_OPERATOR_CONVOLUTION}, + {"DML_OPERATOR_GEMM", DML_OPERATOR_GEMM}, + {"DML_OPERATOR_REDUCE", DML_OPERATOR_REDUCE}, + {"DML_OPERATOR_AVERAGE_POOLING", DML_OPERATOR_AVERAGE_POOLING}, + {"DML_OPERATOR_AVERAGE_POOLING1", DML_OPERATOR_AVERAGE_POOLING1}, + {"DML_OPERATOR_LP_POOLING", DML_OPERATOR_LP_POOLING}, + {"DML_OPERATOR_LP_POOLING1", DML_OPERATOR_LP_POOLING1}, + {"DML_OPERATOR_MAX_POOLING", DML_OPERATOR_MAX_POOLING}, + {"DML_OPERATOR_ROI_POOLING", DML_OPERATOR_ROI_POOLING}, + {"DML_OPERATOR_SLICE", DML_OPERATOR_SLICE}, + {"DML_OPERATOR_CAST", DML_OPERATOR_CAST}, + {"DML_OPERATOR_SPLIT", DML_OPERATOR_SPLIT}, + {"DML_OPERATOR_JOIN", DML_OPERATOR_JOIN}, + {"DML_OPERATOR_PADDING", DML_OPERATOR_PADDING}, + {"DML_OPERATOR_PADDING1", DML_OPERATOR_PADDING1}, + {"DML_OPERATOR_VALUE_SCALE_2D", DML_OPERATOR_VALUE_SCALE_2D}, + {"DML_OPERATOR_UPSAMPLE_2D", DML_OPERATOR_UPSAMPLE_2D}, + {"DML_OPERATOR_GATHER", DML_OPERATOR_GATHER}, + {"DML_OPERATOR_SPACE_TO_DEPTH", DML_OPERATOR_SPACE_TO_DEPTH}, + {"DML_OPERATOR_DEPTH_TO_SPACE", DML_OPERATOR_DEPTH_TO_SPACE}, + {"DML_OPERATOR_TILE", DML_OPERATOR_TILE}, + {"DML_OPERATOR_TOP_K", DML_OPERATOR_TOP_K}, + {"DML_OPERATOR_BATCH_NORMALIZATION", DML_OPERATOR_BATCH_NORMALIZATION}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION}, + {"DML_OPERATOR_LP_NORMALIZATION", DML_OPERATOR_LP_NORMALIZATION}, + {"DML_OPERATOR_RNN", DML_OPERATOR_RNN}, + {"DML_OPERATOR_LSTM", DML_OPERATOR_LSTM}, + {"DML_OPERATOR_GRU", DML_OPERATOR_GRU}, + {"DML_OPERATOR_ELEMENT_WISE_SIGN", DML_OPERATOR_ELEMENT_WISE_SIGN}, + {"DML_OPERATOR_ELEMENT_WISE_IS_NAN", DML_OPERATOR_ELEMENT_WISE_IS_NAN}, + {"DML_OPERATOR_ELEMENT_WISE_ERF", DML_OPERATOR_ELEMENT_WISE_ERF}, + {"DML_OPERATOR_ELEMENT_WISE_SINH", DML_OPERATOR_ELEMENT_WISE_SINH}, + {"DML_OPERATOR_ELEMENT_WISE_COSH", DML_OPERATOR_ELEMENT_WISE_COSH}, + {"DML_OPERATOR_ELEMENT_WISE_TANH", DML_OPERATOR_ELEMENT_WISE_TANH}, + {"DML_OPERATOR_ELEMENT_WISE_ASINH", DML_OPERATOR_ELEMENT_WISE_ASINH}, + {"DML_OPERATOR_ELEMENT_WISE_ACOSH", DML_OPERATOR_ELEMENT_WISE_ACOSH}, + {"DML_OPERATOR_ELEMENT_WISE_ATANH", DML_OPERATOR_ELEMENT_WISE_ATANH}, + {"DML_OPERATOR_ELEMENT_WISE_IF", DML_OPERATOR_ELEMENT_WISE_IF}, + {"DML_OPERATOR_ELEMENT_WISE_ADD1", DML_OPERATOR_ELEMENT_WISE_ADD1}, + {"DML_OPERATOR_ACTIVATION_SHRINK", DML_OPERATOR_ACTIVATION_SHRINK}, + {"DML_OPERATOR_MAX_POOLING1", DML_OPERATOR_MAX_POOLING1}, + {"DML_OPERATOR_MAX_UNPOOLING", DML_OPERATOR_MAX_UNPOOLING}, + {"DML_OPERATOR_DIAGONAL_MATRIX", DML_OPERATOR_DIAGONAL_MATRIX}, + {"DML_OPERATOR_SCATTER", DML_OPERATOR_SCATTER}, + {"DML_OPERATOR_ONE_HOT", 
DML_OPERATOR_ONE_HOT}, + {"DML_OPERATOR_RESAMPLE", DML_OPERATOR_RESAMPLE}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT}, + {"DML_OPERATOR_ELEMENT_WISE_ROUND", DML_OPERATOR_ELEMENT_WISE_ROUND}, + {"DML_OPERATOR_ELEMENT_WISE_IS_INFINITY", DML_OPERATOR_ELEMENT_WISE_IS_INFINITY}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE", DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR", DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR}, + {"DML_OPERATOR_FILL_VALUE_SEQUENCE", DML_OPERATOR_FILL_VALUE_SEQUENCE}, + {"DML_OPERATOR_FILL_VALUE_CONSTANT", DML_OPERATOR_FILL_VALUE_CONSTANT}, + {"DML_OPERATOR_CUMULATIVE_SUMMATION", DML_OPERATOR_CUMULATIVE_SUMMATION}, + {"DML_OPERATOR_REVERSE_SUBSEQUENCES", DML_OPERATOR_REVERSE_SUBSEQUENCES}, + {"DML_OPERATOR_GATHER_ELEMENTS", DML_OPERATOR_GATHER_ELEMENTS}, + {"DML_OPERATOR_GATHER_ND", DML_OPERATOR_GATHER_ND}, + {"DML_OPERATOR_SCATTER_ND", DML_OPERATOR_SCATTER_ND}, + {"DML_OPERATOR_MAX_POOLING2", DML_OPERATOR_MAX_POOLING2}, + {"DML_OPERATOR_SLICE1", DML_OPERATOR_SLICE1}, + {"DML_OPERATOR_TOP_K1", DML_OPERATOR_TOP_K1}, + {"DML_OPERATOR_DEPTH_TO_SPACE1", DML_OPERATOR_DEPTH_TO_SPACE1}, + {"DML_OPERATOR_SPACE_TO_DEPTH1", DML_OPERATOR_SPACE_TO_DEPTH1}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1}, + {"DML_OPERATOR_RESAMPLE1", DML_OPERATOR_RESAMPLE1}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY}, + {"DML_OPERATOR_CONVOLUTION_INTEGER", DML_OPERATOR_CONVOLUTION_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_AND", DML_OPERATOR_ELEMENT_WISE_BIT_AND}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_OR", DML_OPERATOR_ELEMENT_WISE_BIT_OR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_XOR", DML_OPERATOR_ELEMENT_WISE_BIT_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_NOT", DML_OPERATOR_ELEMENT_WISE_BIT_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_COUNT", DML_OPERATOR_ELEMENT_WISE_BIT_COUNT}, + {"DML_OPERATOR_ACTIVATION_RELU_GRAD", DML_OPERATOR_ACTIVATION_RELU_GRAD}, + {"DML_OPERATOR_AVERAGE_POOLING_GRAD", DML_OPERATOR_AVERAGE_POOLING_GRAD}, + {"DML_OPERATOR_MAX_POOLING_GRAD", DML_OPERATOR_MAX_POOLING_GRAD}, + {"DML_OPERATOR_RANDOM_GENERATOR", DML_OPERATOR_RANDOM_GENERATOR}, + {"DML_OPERATOR_NONZERO_COORDINATES", DML_OPERATOR_NONZERO_COORDINATES}, + {"DML_OPERATOR_RESAMPLE_GRAD", DML_OPERATOR_RESAMPLE_GRAD}, + {"DML_OPERATOR_SLICE_GRAD", DML_OPERATOR_SLICE_GRAD}, + {"DML_OPERATOR_ADAM_OPTIMIZER", DML_OPERATOR_ADAM_OPTIMIZER}, + {"DML_OPERATOR_ARGMIN", DML_OPERATOR_ARGMIN}, + {"DML_OPERATOR_ARGMAX", DML_OPERATOR_ARGMAX}, + {"DML_OPERATOR_ROI_ALIGN", DML_OPERATOR_ROI_ALIGN}, + {"DML_OPERATOR_GATHER_ND1", DML_OPERATOR_GATHER_ND1}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN_YX", DML_OPERATOR_ELEMENT_WISE_ATAN_YX}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE", DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD}, + {"DML_OPERATOR_CUMULATIVE_PRODUCT", DML_OPERATOR_CUMULATIVE_PRODUCT}, + {"DML_OPERATOR_BATCH_NORMALIZATION_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_GRAD}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD", 
DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD}, + {"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ROI_ALIGN1", DML_OPERATOR_ROI_ALIGN1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP1", DML_OPERATOR_ELEMENT_WISE_CLIP1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1}, + {"DML_OPERATOR_ELEMENT_WISE_NEGATE", DML_OPERATOR_ELEMENT_WISE_NEGATE}, + {"DML_OPERATOR_ACTIVATION_GELU", DML_OPERATOR_ACTIVATION_GELU}, + {"DML_OPERATOR_ACTIVATION_SWISH", DML_OPERATOR_ACTIVATION_SWISH}, + {"DML_OPERATOR_ACTIVATION_HARD_SWISH", DML_OPERATOR_ACTIVATION_HARD_SWISH}, + {"DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2}, + {"DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1}, + {"DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1}, + {"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION}, + {"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_BINDING_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_BINDING_TYPE_NONE", DML_BINDING_TYPE_NONE}, + {"DML_BINDING_TYPE_BUFFER", DML_BINDING_TYPE_BUFFER}, + {"DML_BINDING_TYPE_BUFFER_ARRAY", DML_BINDING_TYPE_BUFFER_ARRAY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_REDUCE_FUNCTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_REDUCE_FUNCTION_ARGMAX", DML_REDUCE_FUNCTION_ARGMAX}, + {"DML_REDUCE_FUNCTION_ARGMIN", DML_REDUCE_FUNCTION_ARGMIN}, + {"DML_REDUCE_FUNCTION_AVERAGE", DML_REDUCE_FUNCTION_AVERAGE}, + {"DML_REDUCE_FUNCTION_L1", DML_REDUCE_FUNCTION_L1}, + {"DML_REDUCE_FUNCTION_L2", DML_REDUCE_FUNCTION_L2}, + {"DML_REDUCE_FUNCTION_LOG_SUM", DML_REDUCE_FUNCTION_LOG_SUM}, + {"DML_REDUCE_FUNCTION_LOG_SUM_EXP", DML_REDUCE_FUNCTION_LOG_SUM_EXP}, + {"DML_REDUCE_FUNCTION_MAX", DML_REDUCE_FUNCTION_MAX}, + {"DML_REDUCE_FUNCTION_MIN", DML_REDUCE_FUNCTION_MIN}, + {"DML_REDUCE_FUNCTION_MULTIPLY", DML_REDUCE_FUNCTION_MULTIPLY}, + {"DML_REDUCE_FUNCTION_SUM", DML_REDUCE_FUNCTION_SUM}, + {"DML_REDUCE_FUNCTION_SUM_SQUARE", DML_REDUCE_FUNCTION_SUM_SQUARE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_MATRIX_TRANSFORM ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MATRIX_TRANSFORM_NONE", DML_MATRIX_TRANSFORM_NONE}, + {"DML_MATRIX_TRANSFORM_TRANSPOSE", DML_MATRIX_TRANSFORM_TRANSPOSE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_MODE_CONVOLUTION", 
DML_CONVOLUTION_MODE_CONVOLUTION}, + {"DML_CONVOLUTION_MODE_CROSS_CORRELATION", DML_CONVOLUTION_MODE_CROSS_CORRELATION}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_DIRECTION_FORWARD", DML_CONVOLUTION_DIRECTION_FORWARD}, + {"DML_CONVOLUTION_DIRECTION_BACKWARD", DML_CONVOLUTION_DIRECTION_BACKWARD}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_PADDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_PADDING_MODE_CONSTANT", DML_PADDING_MODE_CONSTANT}, + {"DML_PADDING_MODE_EDGE", DML_PADDING_MODE_EDGE}, + {"DML_PADDING_MODE_REFLECTION", DML_PADDING_MODE_REFLECTION}, + {"DML_PADDING_MODE_SYMMETRIC", DML_PADDING_MODE_SYMMETRIC}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_INTERPOLATION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"DML_INTERPOLATION_MODE_LINEAR", DML_INTERPOLATION_MODE_LINEAR}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RECURRENT_NETWORK_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RECURRENT_NETWORK_DIRECTION_FORWARD", DML_RECURRENT_NETWORK_DIRECTION_FORWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BACKWARD", DML_RECURRENT_NETWORK_DIRECTION_BACKWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL", DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT", DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT}, + {"DML_FEATURE_FEATURE_LEVELS", DML_FEATURE_FEATURE_LEVELS}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_LEVEL_1_0", DML_FEATURE_LEVEL_1_0}, + {"DML_FEATURE_LEVEL_2_0", DML_FEATURE_LEVEL_2_0}, + {"DML_FEATURE_LEVEL_2_1", DML_FEATURE_LEVEL_2_1}, + {"DML_FEATURE_LEVEL_3_0", DML_FEATURE_LEVEL_3_0}, + {"DML_FEATURE_LEVEL_3_1", DML_FEATURE_LEVEL_3_1}, + {"DML_FEATURE_LEVEL_4_0", DML_FEATURE_LEVEL_4_0}, + {"DML_FEATURE_LEVEL_4_1", DML_FEATURE_LEVEL_4_1}, + {"DML_FEATURE_LEVEL_5_0", DML_FEATURE_LEVEL_5_0}, + {"DML_FEATURE_LEVEL_5_1", DML_FEATURE_LEVEL_5_1}, + {"DML_FEATURE_LEVEL_5_2", DML_FEATURE_LEVEL_5_2}, + {"DML_FEATURE_LEVEL_6_0", 
DML_FEATURE_LEVEL_6_0}, + {"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1}, + {"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_IS_INFINITY_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_IS_INFINITY_MODE_EITHER", DML_IS_INFINITY_MODE_EITHER}, + {"DML_IS_INFINITY_MODE_POSITIVE", DML_IS_INFINITY_MODE_POSITIVE}, + {"DML_IS_INFINITY_MODE_NEGATIVE", DML_IS_INFINITY_MODE_NEGATIVE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_DEPTH_SPACE_ORDER ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW}, + {"DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_AXIS_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_AXIS_DIRECTION_INCREASING", DML_AXIS_DIRECTION_INCREASING}, + {"DML_AXIS_DIRECTION_DECREASING", DML_AXIS_DIRECTION_DECREASING}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_ROUNDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN", DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN}, + {"DML_ROUNDING_MODE_TOWARD_ZERO", DML_ROUNDING_MODE_TOWARD_ZERO}, + {"DML_ROUNDING_MODE_TOWARD_INFINITY", DML_ROUNDING_MODE_TOWARD_INFINITY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RANDOM_GENERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10", DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_MULTIHEAD_ATTENTION_MASK_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); 
+ } + return static_cast(*index); +} + diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp new file mode 100644 index 0000000000000..7d8ed17e7d925 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp @@ -0,0 +1,554 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc); + +OperatorFieldVariant CreateActivation( + const dml::ir::operatorFieldTypes::Activation* activationDesc) +{ + DML_OPERATOR_TYPE activationOperatorType = ApiTraits::StringifyHelpers::FromString(activationDesc->type()->c_str()); + const DML_OPERATOR_SCHEMA& activationSchema = SchemaHelpers::GetSchema(activationOperatorType); + std::vector activationOperatorFields(activationSchema.FieldCount); + uint32_t attributeIndex = 0; + + for (uint32_t fieldIndex = 0; fieldIndex < activationSchema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &activationSchema.Fields[fieldIndex]; + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + field = OperatorFieldTypes::TensorDesc(); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + field = OperatorFieldTypes::TensorDescArray(); + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= activationDesc->attributes()->size() ? + nullptr : + activationDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + activationOperatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&activationSchema, std::move(activationOperatorFields)); +} + +OperatorFieldVariant CreateActivations( + const dml::ir::operatorFieldTypes::ActivationArray* activationDescs) +{ + std::vector activations; + for (uint32_t index = 0; index < static_cast(activationDescs->data()->size()); index++) + { + OperatorFieldVariant activation = CreateActivation(activationDescs->data()->Get(index)); + activations.push_back(std::get(activation).value()); + } + return activations; +} + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc) +{ + switch (schemaField->Type) + { + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC: + { + return attributeDesc != nullptr && attributeDesc->val_as_Activation() != nullptr ? + CreateActivation(attributeDesc->val_as_Activation()) : + OperatorFieldTypes::FusedActivationOperatorDesc(); + } + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY: + { + return attributeDesc != nullptr && attributeDesc->val_as_ActivationArray() != nullptr ? 
+ CreateActivations(attributeDesc->val_as_ActivationArray()) : + OperatorFieldTypes::FusedActivationOperatorDescArray(); + } + case DML_SCHEMA_FIELD_TYPE_UINT: + { + OperatorFieldTypes::UInt data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT64: + { + OperatorFieldTypes::UInt64 data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt64()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT: + { + OperatorFieldTypes::Int data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Int32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT: + { + OperatorFieldTypes::Float data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Float32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: + { + OperatorFieldTypes::UIntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_UIntArray()->data()->begin(), attributeDesc->val_as_UIntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT_ARRAY: + { + OperatorFieldTypes::IntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_IntArray()->data()->begin(), attributeDesc->val_as_IntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: + { + OperatorFieldTypes::FloatArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_FloatArray()->data()->begin(), attributeDesc->val_as_FloatArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS: + { + OperatorFieldTypes::ScaleBias scaleBias; + const dml::ir::operatorFieldTypes::ScaleBias* scaleBiasAttribute = attributeDesc->val_as_ScaleBias(); + if (scaleBiasAttribute != nullptr) + { + scaleBias = {scaleBiasAttribute->scale(), scaleBiasAttribute->bias()}; + } + return scaleBias; + } + case DML_SCHEMA_FIELD_TYPE_SIZE_2D: + { + OperatorFieldTypes::Size2D size2d = {}; + if (attributeDesc != nullptr) + { + size2d.Height = attributeDesc->val_as_Size2D()->height(); + size2d.Width = attributeDesc->val_as_Size2D()->width(); + } + return size2d; + } + case DML_SCHEMA_FIELD_TYPE_SCALAR_UNION: + { + DML_SCALAR_UNION scalarUnion; + if (attributeDesc != nullptr) + { + const dml::ir::operatorFieldTypes::ByteArray* byteArr = attributeDesc->val_as_ScalarUnionData()->data_as_ByteArray(); + std::copy(byteArr->data()->begin(), byteArr->data()->end(), scalarUnion.Bytes); + } + return scalarUnion; + } + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + OperatorFieldTypes::Bool data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Bool()->data(); + } + return data; + } + default: + { + throw std::invalid_argument("Invalid attribute type."); + } + } +} + +OperatorFieldTypes::TensorDesc CreateBufferTensorDesc( + const dml::ir::DmlBufferTensorDesc* tensorDesc, + const bool isConstantTensor = false) +{ + DmlBufferTensorDesc bufferTensorDesc = {}; + bufferTensorDesc.dataType = ApiTraits::StringifyHelpers::FromString(tensorDesc->dataType()->c_str()); + if (isConstantTensor) + { + bufferTensorDesc.flags = DML_TENSOR_FLAG_OWNED_BY_DML; + } + bufferTensorDesc.sizes.assign(tensorDesc->sizes()->begin(), tensorDesc->sizes()->end()); + if (flatbuffers::IsFieldPresent(tensorDesc, dml::ir::DmlBufferTensorDesc::VT_STRIDES)) + { + bufferTensorDesc.strides.emplace(tensorDesc->strides()->begin(), tensorDesc->strides()->end()); + } + bufferTensorDesc.totalTensorSizeInBytes = 
tensorDesc->totalTensorSizeInBytes(); + return bufferTensorDesc; +} + +AbstractOperatorDesc CreateAbstractOperatorDesc( + uint32_t nodeIndex, + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeInputNames, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeOutputNames, + const std::unordered_set& constantInputs) +{ + DML_OPERATOR_TYPE type = ApiTraits::StringifyHelpers::FromString(flatbufferOperatorNodeDesc->type()->c_str()); + if (type == DML_OPERATOR_INVALID) + { + throw std::invalid_argument("Graph operator node at index:" + std::to_string(nodeIndex) + + " either has empty or invalid operator type."); + } + const DML_OPERATOR_SCHEMA& schema = SchemaHelpers::GetSchema(type); + std::vector operatorFields(schema.FieldCount); + + auto inputNameItr = nodeInputNames->begin(); + uint32_t inputTensorDescIndex = 0; + + uint32_t outputTensorDescIndex = 0; + auto outputNameItr = nodeOutputNames->begin(); + + uint32_t attributeIndex = 0; + + + for (uint32_t fieldIndex = 0; fieldIndex < schema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &schema.Fields[fieldIndex]; + + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + { + if (inputNameItr == nodeInputNames->end()) + { + throw std::invalid_argument("Missing input names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + if (inputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc, isConstantTensor); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (inputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->inputs()->size())) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc, isConstantTensor).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (outputNameItr == nodeOutputNames->end()) + { + throw std::invalid_argument("Missing output names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* outputName = *outputNameItr; + outputNameItr++; + + if 
(outputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (outputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->outputs()->size())) + { + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + if (flatbufferOperatorNodeDesc->attributes()->size() <= attributeIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(attributeIndex + 1) + + "attributes for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= flatbufferOperatorNodeDesc->attributes()->size() ? + nullptr : + flatbufferOperatorNodeDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + operatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&schema, std::move(operatorFields)); +} + +std::unordered_map ConvertToEdgeNameToIndexMap( + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* list) +{ + std::unordered_map nameToIndexMap; + for (uint32_t index = 0; index < list->size(); index++) + { + const flatbuffers::String* name = list->GetAsString(index); + if (name->size() == 0) + { + continue; + } + nameToIndexMap[name->string_view()] = index; + } + return nameToIndexMap; // NRVO will automatically move it. 
no need to use std::move +} + +template void PopulateEdges( + const uint32_t nodeIndex, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* edgeNames, + const std::unordered_map& edgeNameToIndexMap, + /*out*/ std::vector& edges, + /*out*/ std::vector& intermediateEdges, + /*out*/ std::unordered_map& edgeToOutgoingNodeIndexMap) +{ + for (flatbuffers::uoffset_t edgeIndex = 0; edgeIndex < edgeNames->size(); edgeIndex++) + { + const flatbuffers::String* edgeName = edgeNames->Get(edgeIndex); + if (edgeName->size() == 0) + { + // This must be optional input/output + continue; + } + // edge can be graphInput or graphOutput + if (edgeNameToIndexMap.find(edgeName->string_view()) != edgeNameToIndexMap.end()) + { + EdgeType edge = {}; + edge.Name = edgeName->str(); + + if constexpr (std::is_same_v) + { + edge.GraphInputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.ToNodeIndex = nodeIndex; + edge.ToNodeInputIndex = edgeIndex; + } + else if constexpr (std::is_same_v) + { + edge.GraphOutputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.FromNodeIndex = nodeIndex; + edge.FromNodeOutputIndex = edgeIndex; + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + + edges.push_back(edge); + } + // edge is intermediate edge + else + { + if constexpr (std::is_same_v) + { + if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end()) + { + throw std::range_error("Neither there is any graph input with name " + edgeName->str() + + "nor there is any node which has " + edgeName->str() + " as one of the output."); + } + auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()]; + DmlIntermediateSerializedGraphEdge intermediateEdge = {}; + intermediateEdge.Name = edgeName->str(); + intermediateEdge.FromNodeIndex = intermediateEdgeNodeIndex.nodeIndex; + intermediateEdge.FromNodeOutputIndex = intermediateEdgeNodeIndex.nodeOutputIndex; + intermediateEdge.ToNodeIndex = nodeIndex; + intermediateEdge.ToNodeInputIndex = edgeIndex; + intermediateEdges.push_back(std::move(intermediateEdge)); + } + else if constexpr (std::is_same_v) + { + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + } + } +} + +/* +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. 
+* attribute +* - will have null entry +*/ +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData) +{ + if (flatbufferGraphDescBlob == nullptr) + { + throw std::invalid_argument("Given pointer to flatbuffer blob is null"); + } + const dml::ir::DmlGraphDesc* flatbufferGraphDesc = dml::ir::GetDmlGraphDesc(flatbufferGraphDescBlob); + + std::unordered_map graphInputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphInputNames()); + std::unordered_map graphOutputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphOutputNames()); + + std::unordered_map edgeToOutgoingNodeIndexMap; + std::unordered_set constantInputs; + + std::vector nodes(flatbufferGraphDesc->nodes()->size()); + std::vector inputEdges; + std::vector outputEdges; + std::vector intermediateEdges; + + for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++) + { + const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex); + + PopulateEdges( + nodeIndex, + flatbufferNode->inputNames(), + graphInputEdgeToIndexMap, + inputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + PopulateEdges( + nodeIndex, + flatbufferNode->outputNames(), + graphOutputEdgeToIndexMap, + outputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + + DmlSerializedGraphNode node = {}; + if (flatbufferNode->name()->size() == 0) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + " doesn't have any name"); + } + node.Name = flatbufferNode->name()->c_str(); + + if (flatbufferNode->desc_type() == dml::ir::NodeDesc_ConstantNodeDesc) + { + const dml::ir::ConstantNodeDesc* flatbufferConstantNode = flatbufferNode->desc_as_ConstantNodeDesc(); + if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantName) + { + if (flatbufferConstantNode->data_as_ConstantName()->name()->size() == 0) + { + throw std::invalid_argument("Constant node at index:" + std::to_string(nodeIndex) + + " doesn't have constant data name."); + } + + ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()}; + node.Desc = constantNode; + // output of this node will part of constantInputs list + for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++) + { + constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str()); + } + } + else if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData) + { + + uint32_t rawDataSize = flatbufferConstantNode->data_as_ConstantRawData()->data()->size(); + rawData.push_back(std::make_unique(rawDataSize)); + std::transform( + flatbufferConstantNode->data_as_ConstantRawData()->data()->begin(), + flatbufferConstantNode->data_as_ConstantRawData()->data()->end(), + rawData.back().get(), + [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {}; + constantData.dataSize = rawDataSize; + constantData.data = rawData.back().get(); + node.Desc = constantData; + } + + + } + else if (flatbufferNode->desc_type() == dml::ir::NodeDesc::NodeDesc_OperatorNodeDesc) + { + // convert dml::ir::OperatorNodeDesc to AbstractOperatorDesc + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc = flatbufferNode->desc_as_OperatorNodeDesc(); + node.Desc = CreateAbstractOperatorDesc( + nodeIndex, + flatbufferOperatorNodeDesc, + flatbufferNode->inputNames(), + flatbufferNode->outputNames(), + 
constantInputs); + } + + nodes[nodeIndex] = node; + } + + DmlSerializedGraphDesc graphDesc; + graphDesc.InputCount = flatbufferGraphDesc->graphInputNames()->size(); + graphDesc.OutputCount = flatbufferGraphDesc->graphOutputNames()->size(); + graphDesc.InputEdges = std::move(inputEdges); + graphDesc.IntermediateEdges = std::move(intermediateEdges); + graphDesc.OutputEdges = std::move(outputEdges); + graphDesc.Nodes = std::move(nodes); + return graphDesc; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 642d9aa03eeef..202b762d99e01 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -135,8 +135,10 @@ namespace DmlGraphFusionHelper void ProcessInputData( const ExecutionProviderImpl* providerImpl, + const bool graphSerializationEnabled, const std::vector& isInputsUploadedByDmlEP, - const std::vector& inputEdges, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, const gsl::span subGraphInputArgNames, const std::unordered_map>& initializerNameToInitializerMap, onnxruntime::Graph& graph, @@ -162,8 +164,17 @@ namespace DmlGraphFusionHelper // Walk through each graph edge and mark used inputs inputsUsed.assign(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : inputEdges) { - inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex->begin(); it != serializedGraphInputIndexToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex->begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + + std::wstring modelName; + if (graphSerializationEnabled) + { + modelName = GetModelName(graph.ModelPath()); } for (uint32_t i = 0; i < initInputBindings.size(); i++) @@ -209,6 +220,10 @@ namespace DmlGraphFusionHelper // Tensor sizes in DML must be a multiple of 4 bytes large. 
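            // For example, a 10-byte tensor would be padded up to 12 bytes while a 16-byte tensor is
            // left unchanged; this assumes AlignToPow2 is the usual round-up-to-a-multiple helper,
            // roughly (size + alignment - 1) & ~(alignment - 1) for a power-of-two alignment.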
tensorByteSize = AlignToPow2(tensorByteSize, 4); + if(graphSerializationEnabled) + { + WriteToFile(modelName, ConvertToWString(iter->first) + L".bin", reinterpret_cast(tensorPtr), tensorByteSize); + } if (inputRawData) { @@ -287,55 +302,158 @@ namespace DmlGraphFusionHelper return initializerPartitionMap; } + inline uint32_t GetConstantNodeGraphInputIndex( + const std::string& constantName, + const std::unordered_map* serializedGraphConstantNameToMainGraphInputIndex, + uint32_t& graphMaxInputIndex, + std::unordered_map& localConstantNameToIndexMap) + { + if (serializedGraphConstantNameToMainGraphInputIndex == nullptr) + { + if (localConstantNameToIndexMap.find(constantName) == localConstantNameToIndexMap.end()) + { + localConstantNameToIndexMap[constantName] = ++graphMaxInputIndex; + } + return localConstantNameToIndexMap[constantName]; + } + else + { + graphMaxInputIndex = std::max(graphMaxInputIndex, serializedGraphConstantNameToMainGraphInputIndex->at(constantName)); + return serializedGraphConstantNameToMainGraphInputIndex->at(constantName); + } + } + + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, - _Inout_ std::vector& dmlConstantGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges) { - for (size_t i = 0; i < graphDesc.nodes.size(); ++i) + std::unordered_map oldNodeIndexToNewNodeIndexMap; + for (uint32_t index = 0; index < static_cast(graphDesc.Nodes.size()); index++) { - auto& nodeInfo = graphDesc.nodes[i]; - - if (std::holds_alternative>(nodeInfo.nodeDef)) + const DmlSerializedGraphNode& node = graphDesc.Nodes[index]; + if (std::holds_alternative(node.Desc)) { - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(std::get(node.Desc), &allocator); + ComPtr op; + ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); + dmlOperators.push_back(op); + DML_OPERATOR_GRAPH_NODE_DESC* dmlOperatorGraphNode = allocator.template Allocate(); + dmlOperatorGraphNode->Name = node.Name.data(); + dmlOperatorGraphNode->Operator = op.Get(); + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, dmlOperatorGraphNode}); } else { - auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); - dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ - nodeDefinitionData.data(), - nodeDefinitionData.size(), - nodeInfo.name.data() - }; - - // TODO: Change as new header is ingested - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + auto& constantNodeVariant = std::get(node.Desc); + if (std::holds_alternative(constantNodeVariant)) + { + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + + auto& constantData = std::get(constantNodeVariant); + + 
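                    // Inline ConstantData is turned into a DML_GRAPH_NODE_TYPE_CONSTANT node whose
                    // Data/DataSize fields point directly at the serialized bytes. ConstantName
                    // constants are not materialized as graph nodes here; the intermediate-edge loop
                    // further down instead rewires their consumers to graph input edges via
                    // GetConstantNodeGraphInputIndex.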
DML_CONSTANT_DATA_GRAPH_NODE_DESC* constantNode = allocator.template Allocate(); + constantNode->Name = node.Name.data(); + constantNode->DataSize = constantData.dataSize; + constantNode->Data = constantData.data; + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_CONSTANT, constantNode}); + } } } - for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) + uint32_t graphMaxInputIndex = 0; + + for (size_t i = 0; i < graphDesc.InputEdges.size(); ++i) { - dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + // 1. If serializedGraphInputIndexToMainGraphInputIndex is not null: + // then use the corresponding main graph input index, because the caller will use corresponding + // main graph input index for extracting the actual input tensor from the main graph and + // the caller does not own the creation of dml bindings directly. + // Use Case: When the caller is ORT (DML EP) or DmlEngine. + // + // 2. If serializedGraphInputIndexToMainGraphInputIndex is null: + // then assign the sequential graph input index, because it owns the creation of dml bindings + // directly. + edge->GraphInputIndex = serializedGraphInputIndexToSubgraphInputIndex == nullptr ? + graphDesc.InputEdges[i].GraphInputIndex : + serializedGraphInputIndexToSubgraphInputIndex->at(graphDesc.InputEdges[i].GraphInputIndex); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.InputEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.InputEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.InputEdges[i].Name.data(); + + graphMaxInputIndex = std::max(graphMaxInputIndex, edge->GraphInputIndex); + dmlInputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, edge}); } - for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) + for (size_t i = 0; i < graphDesc.OutputEdges.size(); ++i) { - dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; + DML_OUTPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphOutputIndex = graphDesc.OutputEdges[i].GraphOutputIndex; + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.OutputEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.OutputEdges[i].FromNodeOutputIndex; + edge->Name = graphDesc.OutputEdges[i].Name.data(); + + dmlOutputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, edge}); } - for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) + std::unordered_map localConstantNameToIndexMap; + for (uint32_t i = 0; i < static_cast(graphDesc.IntermediateEdges.size()); ++i) { - dmlIntermediateEdges[i] = - DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; + DmlSerializedGraphNodeDescVariant descVariant = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Desc; + bool isConstantEdge = std::holds_alternative(descVariant); + if (isConstantEdge) + { + auto& constantNodeVariant = std::get(descVariant); + if (std::holds_alternative(constantNodeVariant)) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = 
graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } + else + { + const std::string& constantName = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Name; + + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphInputIndex = GetConstantNodeGraphInputIndex( + constantName, + serializedGraphLargeConstantNameToSubgraphInputIndex, + graphMaxInputIndex, + localConstantNameToIndexMap); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + + dmlInputEdges.push_back({DML_GRAPH_EDGE_TYPE_INPUT, edge}); + } + } + else + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } } dmlGraphDesc.InputCount = inputCount; @@ -400,27 +518,34 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl) + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator - DML_GRAPH_DESC dmlGraphDesc = {}; - std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); - std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - std::vector dmlGraphNodes(graphDesc.nodes.size()); - std::vector dmlInputEdges(graphDesc.inputEdges.size()); - std::vector dmlOutputEdges(graphDesc.outputEdges.size()); - std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); + StackAllocator<1024> allocator; + DML_GRAPH_DESC dmlGraphDesc = {}; + std::vector> dmlOperators; + std::vector dmlGraphNodes; + std::vector dmlInputEdges; + std::vector dmlOutputEdges; + std::vector dmlIntermediateEdges; ConvertGraphDesc( graphDesc, - dmlGraphDesc, fusedNodeInputCount, fusedNodeOutputCount, - dmlOperatorGraphNodes, - dmlConstantGraphNodes, + device.Get(), + allocator, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + dmlGraphDesc, + dmlOperators, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, @@ -438,8 +563,6 @@ namespace DmlGraphFusionHelper executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; } - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); ComPtr device1; ORT_THROW_IF_FAILED(device.As(&device1)); @@ -460,6 +583,7 @@ namespace 
DmlGraphFusionHelper } void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -467,8 +591,43 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator) + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { + if (graphSerializationEnabled) + { + + const std::wstring modelName = GetModelName(graph.ModelPath()); + auto buffer = SerializeDmlGraph(graphDesc); + + const std::wstring partitionName = + L"Partition_" + + std::to_wstring(partitionIndex) + + L".bin"; + WriteToFile(modelName, partitionName, buffer.data(), buffer.size()); + + std::vector> rawData; + DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData); + GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {}; + deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount; + deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges); + deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges); + deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes); + deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount; + deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges); + deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList; + deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes; + + compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + deserializedDmlGraphDesc, + indexedSubGraph, + providerImpl, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex); + } + auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); @@ -482,8 +641,10 @@ namespace DmlGraphFusionHelper std::vector inputsUsed; ProcessInputData( providerImpl, + graphSerializationEnabled, isInputsUploadedByDmlEP, - graphDesc.inputEdges, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, indexedSubGraph.GetMetaDef()->inputs, initializerNameToInitializerMap, graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index f8f6162aaa1e0..f1e9654021196 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -45,12 +45,17 @@ namespace DmlGraphFusionHelper gsl::span> partitions ); + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* 
serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -69,9 +74,12 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl); + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex); void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -79,7 +87,10 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex = nullptr, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex = nullptr); void RegisterDynamicKernel( onnxruntime::Graph& graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 679738b639ec9..35a2c451a49a5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -24,15 +24,20 @@ namespace Dml std::vector isInputsUploadedByDmlEP; GraphDescBuilder::GraphDesc graphDesc; std::unordered_map> isInitializerTransferable; + std::vector> smallConstantData; // Need to keep it alive for maintaining lifetime + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; }; } DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ) :onnxruntime::GraphTransformer(name), - m_providerImpl(static_cast(provider)->GetImpl()) + m_providerImpl(static_cast(provider)->GetImpl()), + graphSerializationEnabled(graphSerializationEnabled) { } @@ -227,23 +232,39 @@ namespace Dml ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. 
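            // A minimal sketch of the lookup chain described above, assuming the usual template
            // arguments for these containers (angle brackets are stripped in this rendering) and a
            // hypothetical serialized graph input index 'serializedIdx':
            //
            //     uint32_t subgraphInputArgIdx = serializedGraphInputIndexToSubgraphInputIndex.at(serializedIdx);
            //     const std::string& inputArgName = indexedSubGraph.GetMetaDef()->inputs[subgraphInputArgIdx];
            //     const auto& constantTensorEntry = initializerNameToInitializerMap.at(inputArgName);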
+ std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, partitionNodePropsMap, - device.Get(), m_providerImpl, modelPath, subgraphNodes, subgraphInputs, - subgraphOutputs); + subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, indexedSubGraph, - m_providerImpl); + m_providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); if (!compiledPartition) { @@ -264,6 +285,9 @@ namespace Dml compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP); compiledPartitionInfo->graphDesc = std::move(graphDesc); compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable); + compiledPartitionInfo->smallConstantData = std::move(smallConstantData); + compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex = std::move(serializedGraphInputIndexToSubgraphInputIndex); + compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex = std::move(serializedGraphLargeConstantNameToSubgraphInputIndex); compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo); } } @@ -271,12 +295,14 @@ namespace Dml } while (!additionalSplittingNodes.empty()); + uint32_t partitionIndex = 0; for (auto&& compiledPartitionInfo : compiledPartitionInfos) { // Null compiled operators were not DML partitions if (compiledPartitionInfo) { DmlGraphFusionHelper::FusePartitionAndRegisterKernel( + partitionIndex++, graph, m_providerImpl->GetKernelRegistry().get(), compiledPartitionInfo->isInitializerTransferable, @@ -284,7 +310,10 @@ namespace Dml compiledPartitionInfo->indexedSubGraph, std::move(compiledPartitionInfo->isInputsUploadedByDmlEP), compiledPartitionInfo->graphDesc, - compiledPartitionInfo->compiledOperator); + compiledPartitionInfo->compiledOperator, + graphSerializationEnabled, + &compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex, + &compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index 19dab0c89943c..b370f3ef9043c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -16,7 +16,8 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer public: DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ); public: @@ -38,5 +39,6 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer private: const ExecutionProviderImpl* m_providerImpl = nullptr; + const bool graphSerializationEnabled = false; }; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp new file mode 100644 index 0000000000000..5355964e8db74 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp @@ -0,0 +1,580 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +template +T* ReadAs(uint8_t* base, size_t byteOffset) +{ + return reinterpret_cast(base + byteOffset); +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs); + +flatbuffers::Offset serializeActivation( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& activationOperatorDesc) +{ + std::vector> attributeDescs; + SerializeAttributeDescs(builder, activationOperatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::operatorFieldTypes::CreateActivationDirect( + builder, + activationOperatorDesc.schema->OperatorName, + &attributeDescs); + return offset; +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs) +{ + for (const OperatorField& field : operatorDesc.fields) + { + if (field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_INPUT_TENSOR || + field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR) + { + continue; + } + + flatbuffers::Offset offset; + + if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDesc& fusedActivation = field.AsFusedActivationOperatorDesc(); + if (!fusedActivation.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation); + } + else + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation, + serializeActivation(builder, fusedActivation.value()).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDescArray& fusedActivations = + field.AsFusedActivationOperatorDescArray(); + if (!fusedActivations.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray); + } + else + { + std::vector> fbActivations; + + for (AbstractOperatorDesc activationOpDesc : fusedActivations.value()) + { + flatbuffers::Offset fbActivation = + serializeActivation(builder, activationOpDesc); + fbActivations.push_back(fbActivation); + } + + flatbuffers::Offset activationOffset = + dml::ir::operatorFieldTypes::CreateActivationArrayDirect(builder, &fbActivations); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray, + activationOffset.Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt32(field.AsUInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, 
+ dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt64(field.AsUInt64())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Int32(field.AsInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Float32(field.AsFloat())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray, + dml::ir::operatorFieldTypes::CreateUIntArray(builder, builder.CreateVector(field.AsUIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray, + dml::ir::operatorFieldTypes::CreateIntArray(builder, builder.CreateVector(field.AsIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray, + dml::ir::operatorFieldTypes::CreateFloatArray(builder, builder.CreateVector(field.AsFloatArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::ScaleBias& scaleBias = field.AsScaleBias(); + if (!scaleBias.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias); + } + else + { + dml::ir::operatorFieldTypes::ScaleBias fbScaleBias(scaleBias.value().Scale, scaleBias.value().Bias); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias, + builder.CreateStruct(fbScaleBias).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const DML_SIZE_2D size2d = field.AsSize2D(); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D, + builder.CreateStruct(dml::ir::operatorFieldTypes::Size2D(size2d.Width, size2d.Height)).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + OperatorFieldTypes::ScalarUnion scalarUnion = field.AsScalarUnion(); + dml::ir::operatorFieldTypes::ByteArray byteArr; + for (uint32_t index = 0; index < static_cast(sizeof(scalarUnion.Bytes)); index++) + { + byteArr.mutable_data()->Mutate(index, scalarUnion.Bytes[index]); + } + + flatbuffers::Offset scalarUnionOffset = + dml::ir::operatorFieldTypes::CreateScalarUnionData( + builder, + dml::ir::operatorFieldTypes::ScalarVariant_ByteArray, + builder.CreateStruct(byteArr).Union()); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData, + 
scalarUnionOffset.Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool, + builder.CreateStruct(dml::ir::operatorFieldTypes::Bool(field.AsBool())).Union()); + } + else + { + continue; + } + + attributeDescs.push_back(offset); + } +} + +flatbuffers::Offset SerializeDmlTensorDesc( + flatbuffers::FlatBufferBuilder& builder, + const DmlBufferTensorDesc* tensorDesc) +{ + const std::vector *strides = nullptr; + if (tensorDesc->strides.has_value()) + { + strides = &tensorDesc->strides.value(); + } + + flatbuffers::Offset offset = dml::ir::CreateDmlBufferTensorDescDirect( + builder, + ApiTraits::StringifyHelpers::ToString(tensorDesc->dataType), + &tensorDesc->sizes, + strides, + tensorDesc->totalTensorSizeInBytes); + return offset; +} + +flatbuffers::Offset SerializeOperatorNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc) +{ + const DML_OPERATOR_SCHEMA* operatorSchema = operatorDesc.schema; + + std::vector> inputTensorDescs; + std::vector> outputTensorDescs; + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetInputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + inputTensorDescs.push_back(serializedDmlTensorDesc); + } + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetOutputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + outputTensorDescs.push_back(serializedDmlTensorDesc); + } + + std::vector> attributeDescs; + SerializeAttributeDescs(builder, operatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::CreateOperatorNodeDesc( + builder, + builder.CreateString(operatorSchema->OperatorName), + builder.CreateVector(inputTensorDescs), + builder.CreateVector(outputTensorDescs), + builder.CreateVector(attributeDescs)); + return offset.Union(); +} + +flatbuffers::Offset SerializeConstantNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + uint32_t nodeIndex, + const DmlSerializedGraphNodeConstantVariant& constantNodeDesc) +{ + flatbuffers::Offset offset; + + if (std::holds_alternative(constantNodeDesc)) + { + auto& constantName = std::get(constantNodeDesc); + if (constantName.name.empty()) + { + throw std::invalid_argument("Graph constant node at index:" + std::to_string(nodeIndex) + + " doesn't have the constant data name."); + } + + flatbuffers::Offset constantNameOffset = dml::ir::CreateConstantName( + builder, + builder.CreateString(constantName.name)); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantName, + constantNameOffset.Union()); + } + else + { + auto& constantData = std::get(constantNodeDesc); + std::vector rawBytes; + std::transform(constantData.data, constantData.data + constantData.dataSize, + std::back_inserter(rawBytes), [](std::byte b) {return static_cast(b); }); + flatbuffers::Offset constantDataOffset = dml::ir::CreateConstantRawDataDirect( + builder, + &rawBytes); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantRawData, + constantDataOffset.Union()); + } + + return offset.Union(); +} + +flatbuffers::Offset SerializeNode( + flatbuffers::FlatBufferBuilder& builder, + const uint32_t nodeIndex, + const 
DmlSerializedGraphNode& graphNode, + const std::vector>& nodeInputNames, + const std::vector>& nodeOutputNames) +{ + if (graphNode.Name.empty()) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + + " does not have any name."); + } + + flatbuffers::Offset offset; + if (std::holds_alternative(graphNode.Desc)) + { + auto& operatorNode = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_OperatorNodeDesc, + SerializeOperatorNodeDesc(builder, operatorNode), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + else + { + auto& constantNodeVariant = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_ConstantNodeDesc, + SerializeConstantNodeDesc(builder, nodeIndex, constantNodeVariant), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + return offset; +} + +/* +* validates input/output edges and throws exception if an edge +* does not have a name or if an edge has more than 1 names. +*/ +template +std::unordered_map> ConvertToEdgeIndexToNameMap( + const std::vector& edges, + flatbuffers::FlatBufferBuilder& builder) +{ + std::unordered_map> edgeIndexToNameMap; + for (auto& edge : edges) + { + uint32_t index; + if constexpr (std::is_same_v) + { + index = edge.GraphInputIndex; + } + else if constexpr (std::is_same_v) + { + index = edge.GraphOutputIndex; + } + + if (edge.Name.empty()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " does not have name."); + } + + if (edgeIndexToNameMap.find(index) != edgeIndexToNameMap.end()) + { + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeIndexToNameMap[index].o); + if (edge.Name != edgeName->str()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " has more than 1 names."); + } + } + + edgeIndexToNameMap[index] = builder.CreateString(edge.Name); + } + return edgeIndexToNameMap; // NRVO will automatically move it. no need to use std::move +} + +void PopulateNonConstantNodeInputOutputCount( + const std::vector& nodes, + /*out*/ std::vector& nodeInputCounts, + /*out*/ std::vector& nodeOutputCounts) +{ + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(nodes.size()); nodeIndex++) + { + auto& node = nodes[nodeIndex]; + if (std::holds_alternative(node.Desc)) + { + auto& operatorNode = std::get(node.Desc); + nodeInputCounts[nodeIndex] = std::max( + nodeInputCounts[nodeIndex], + static_cast(operatorNode.GetInputTensors().size())); + + nodeOutputCounts[nodeIndex] = std::max( + nodeOutputCounts[nodeIndex], + static_cast(operatorNode.GetOutputTensors().size())); + } + } +} + +void PopulateConstantNodeInputOutputCount( + const std::vector& edges, + /*out*/std::vector& maxInputIndexForNodes, + /*out*/std::vector& maxOutputIndexForNodes) +{ + for (auto& edge : edges) + { + maxInputIndexForNodes[edge.ToNodeIndex] = std::max(maxInputIndexForNodes[edge.ToNodeIndex], edge.ToNodeInputIndex + 1); + maxOutputIndexForNodes[edge.FromNodeIndex] = std::max(maxOutputIndexForNodes[edge.FromNodeIndex], edge.FromNodeOutputIndex + 1); + } +} + +/* +* validates intermediate edge and throws exception if an edge +* does not have a name or if an edge has more than 1 names. 
+*/ +void PopulateNodeInputOutputNames( + flatbuffers::FlatBufferBuilder& builder, + const DmlSerializedGraphDesc& graphDesc, + const std::unordered_map>& graphInputIndexToNameMap, + const std::unordered_map>& graphOutputIndexToNameMap, + /*out*/std::vector>>& nodeToInputNames, + /*out*/std::vector>>& nodeToOutputNames) +{ + for (auto& edge : graphDesc.InputEdges) + { + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = graphInputIndexToNameMap.at(edge.GraphInputIndex); + } + + for (auto& edge : graphDesc.OutputEdges) + { + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = graphOutputIndexToNameMap.at(edge.GraphOutputIndex); + } + + std::unordered_map>> intermediateEdgeNames; + for (uint32_t edgeIndex = 0; edgeIndex < static_cast(graphDesc.IntermediateEdges.size()); edgeIndex++) + { + auto& edge = graphDesc.IntermediateEdges[edgeIndex]; + if (edge.Name.empty()) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " doesn't have name."); + } + + if (intermediateEdgeNames.find(edge.FromNodeIndex) != intermediateEdgeNames.end() && + intermediateEdgeNames[edge.FromNodeIndex].find(edge.FromNodeOutputIndex) != intermediateEdgeNames[edge.FromNodeIndex].end()) + { + flatbuffers::Offset edgeNameOffset = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeNameOffset.o); + + if (edgeName->str() != edge.Name) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " has more than 1 names."); + } + } + else + { + intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = builder.CreateString(edge.Name.c_str()); + } + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + } +} + + +/* +* - If an edge is connected to multiple nodes, then there will be multiple instances +* of input or intermediate edges, all with the same name. +* - The input will be validated incrementally throughout the execution +* of the method. +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. +* attribute +* - will have null entry +*/ +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc) +{ + + flatbuffers::FlatBufferBuilder builder(1024); + if (graphDesc.Nodes.empty()) + { + return builder.Release(); + } + + // create input/output edge index to name map + std::unordered_map> graphInputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.InputEdges, builder); + std::unordered_map> graphOutputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.OutputEdges, builder); + + /* + * - Calculate number of input/output for each operator to allocate + * appropriate amount of memory for each node to store input/output names. + * - Non-constant node's input/output count can be determined by the + * AbstractOperatorDesc. + * - Constant node will only have outgoing edges and those outgoing edges + * will be intermediate edges. 
+ */ + std::vector nodeInputCounts(graphDesc.Nodes.size(), 0); + std::vector nodeOutputCounts(graphDesc.Nodes.size(), 0); + PopulateNonConstantNodeInputOutputCount(graphDesc.Nodes, nodeInputCounts, nodeOutputCounts); + PopulateConstantNodeInputOutputCount(graphDesc.IntermediateEdges, nodeInputCounts, nodeOutputCounts); + + // populate node input/output names. + std::vector>> nodeToInputNames(graphDesc.Nodes.size()); + std::vector>> nodeToOutputNames(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodeToInputNames[nodeIndex].assign(nodeInputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + nodeToOutputNames[nodeIndex].assign(nodeOutputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + } + PopulateNodeInputOutputNames(builder, graphDesc, graphInputIndexToNameMap, graphOutputIndexToNameMap, nodeToInputNames, nodeToOutputNames); + + // Create flatbuffer node objects + std::vector> nodes(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodes[nodeIndex] = SerializeNode( + builder, + nodeIndex, + graphDesc.Nodes[nodeIndex], + nodeToInputNames[nodeIndex], + nodeToOutputNames[nodeIndex]); + } + + // Convert to std::vector to create the object. + std::vector> graphInputNames(graphDesc.InputCount, builder.CreateString(nullptr, 0)); + std::vector> graphOutputNames(graphDesc.OutputCount, builder.CreateString(nullptr, 0)); + for (const auto& [key, value] : graphInputIndexToNameMap) + { + graphInputNames[key] = value; + } + for (const auto& [key, value] : graphOutputIndexToNameMap) + { + graphOutputNames[key] = value; + } + + flatbuffers::Offset dmlGraphDescOffset = dml::ir::CreateDmlGraphDescDirect( + builder, + &nodes, + &graphInputNames, + &graphOutputNames); + builder.Finish(dmlGraphDescOffset); + return builder.Release(); +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 5c7b7bff1e370..0f0d445a95bae 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -180,32 +180,50 @@ namespace Dml // Convert partitionONNXGraph into DML EP GraphDesc ComPtr device; ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. 
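            // The runtime-fused kernel below repeats the pattern introduced in
            // DmlGraphFusionTransformer: BuildGraphDesc now also emits these two index maps plus the
            // small-constant buffers, the maps (rather than graphDesc.inputEdges) drive which fused
            // node inputs are marked as used, and both maps are forwarded to TryCreateCompiledOperator.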
+ std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), m_isInitializerTransferable, m_partitionNodePropsMap, - device.Get(), providerImpl, m_modelPath, m_subgraphNodePointers, m_subgraphInputs, - m_subgraphOutputs); + m_subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); m_outputShapes = graphDesc.outputShapes; // Walk through each graph edge and mark used inputs m_inputsUsed.resize(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) - { - m_inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; } // Compile the operator m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, *m_indexedSubGraph, - providerImpl); + providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index a5415ba85f3d3..e1e7eacfbd85d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,8 +24,8 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 161; - static constexpr size_t ActivationFunctionCount = 24; + static constexpr auto ValueCount = 168; + static constexpr size_t ActivationFunctionCount = 26; }; template <> @@ -62,7 +62,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 4; + static constexpr auto ValueCount = 5; }; template <> @@ -86,7 +86,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 8; + static constexpr auto ValueCount = 13; }; template <> @@ -119,6 +119,12 @@ struct EnumTraits static constexpr auto ValueCount = 1; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 5; +}; + template constexpr auto EnumValueCount = EnumTraits::ValueCount; @@ -495,12 +501,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; -}; - template <> struct OperatorDescTraits { @@ -1029,6 +1029,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DIAGONAL_MATRIX1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; +}; + +template <> 
+struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1174,9 +1192,15 @@ struct OperatorDescTraits }; template <> -struct OperatorDescTraits +struct OperatorDescTraits { - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SWISH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SWISH; }; template @@ -1502,12 +1526,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> -{ - using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; -}; - template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2036,6 +2054,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX1> using DescType = DML_DIAGONAL_MATRIX1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +{ + using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT> +{ + using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -2181,14 +2217,20 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU> }; template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SWISH> { - using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; + using DescType = DML_ACTIVATION_SWISH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH> +{ + using DescType = DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC; }; // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as // the first argument. -// +// // For example: // Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { // using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs @@ -2485,6 +2527,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... 
args return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MULTIHEAD_ATTENTION: return std::invoke(std::forward(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return std::invoke(std::forward(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2533,13 +2579,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); - -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: - return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); -#pragma warning(pop) - + case DML_OPERATOR_ACTIVATION_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SWISH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC{}, std::forward(args)...); default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); @@ -2547,7 +2590,55 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args } #pragma warning(pop) +namespace StringifyHelpers +{ +template +inline gsl::czstring ToString(T value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. 
+ static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_DATA_TYPE value) +{ + switch (value) + { + case DML_TENSOR_DATA_TYPE_UNKNOWN: return "DML_TENSOR_DATA_TYPE_UNKNOWN"; + case DML_TENSOR_DATA_TYPE_FLOAT32: return "DML_TENSOR_DATA_TYPE_FLOAT32"; + case DML_TENSOR_DATA_TYPE_FLOAT16: return "DML_TENSOR_DATA_TYPE_FLOAT16"; + case DML_TENSOR_DATA_TYPE_UINT32: return "DML_TENSOR_DATA_TYPE_UINT32"; + case DML_TENSOR_DATA_TYPE_UINT16: return "DML_TENSOR_DATA_TYPE_UINT16"; + case DML_TENSOR_DATA_TYPE_UINT8: return "DML_TENSOR_DATA_TYPE_UINT8"; + case DML_TENSOR_DATA_TYPE_INT32: return "DML_TENSOR_DATA_TYPE_INT32"; + case DML_TENSOR_DATA_TYPE_INT16: return "DML_TENSOR_DATA_TYPE_INT16"; + case DML_TENSOR_DATA_TYPE_INT8: return "DML_TENSOR_DATA_TYPE_INT8"; + case DML_TENSOR_DATA_TYPE_FLOAT64: return "DML_TENSOR_DATA_TYPE_FLOAT64"; + case DML_TENSOR_DATA_TYPE_UINT64: return "DML_TENSOR_DATA_TYPE_UINT64"; + case DML_TENSOR_DATA_TYPE_INT64: return "DML_TENSOR_DATA_TYPE_INT64"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_TYPE value) +{ + switch (value) + { + case DML_TENSOR_TYPE_INVALID: return "DML_TENSOR_TYPE_INVALID"; + case DML_TENSOR_TYPE_BUFFER: return "DML_TENSOR_TYPE_BUFFER"; + default: + assert(false); + return ""; + } +} +template <> inline gsl::czstring ToString(DML_OPERATOR_TYPE value) { switch (value) @@ -2561,9 +2652,6 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; - case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; @@ -2587,24 +2675,41 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; - case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; - case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR"; case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR"; + case DML_OPERATOR_ACTIVATION_ELU: return "DML_OPERATOR_ACTIVATION_ELU"; + case DML_OPERATOR_ACTIVATION_CELU: return "DML_OPERATOR_ACTIVATION_CELU"; + case DML_OPERATOR_ACTIVATION_HARDMAX: return "DML_OPERATOR_ACTIVATION_HARDMAX"; + case DML_OPERATOR_ACTIVATION_HARDMAX1: return "DML_OPERATOR_ACTIVATION_HARDMAX1"; + case 
DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return "DML_OPERATOR_ACTIVATION_HARD_SIGMOID"; + case DML_OPERATOR_ACTIVATION_IDENTITY: return "DML_OPERATOR_ACTIVATION_IDENTITY"; + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return "DML_OPERATOR_ACTIVATION_LEAKY_RELU"; + case DML_OPERATOR_ACTIVATION_LINEAR: return "DML_OPERATOR_ACTIVATION_LINEAR"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU"; + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_RELU: return "DML_OPERATOR_ACTIVATION_RELU"; + case DML_OPERATOR_ACTIVATION_SCALED_ELU: return "DML_OPERATOR_ACTIVATION_SCALED_ELU"; + case DML_OPERATOR_ACTIVATION_SCALED_TANH: return "DML_OPERATOR_ACTIVATION_SCALED_TANH"; + case DML_OPERATOR_ACTIVATION_SIGMOID: return "DML_OPERATOR_ACTIVATION_SIGMOID"; + case DML_OPERATOR_ACTIVATION_SOFTMAX: return "DML_OPERATOR_ACTIVATION_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_SOFTSIGN: return "DML_OPERATOR_ACTIVATION_SOFTSIGN"; + case DML_OPERATOR_ACTIVATION_TANH: return "DML_OPERATOR_ACTIVATION_TANH"; + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU"; case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; - case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; - case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; - case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE"; case DML_OPERATOR_CAST: return "DML_OPERATOR_CAST"; @@ -2620,18 +2725,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; - case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; case DML_OPERATOR_GRU: return 
"DML_OPERATOR_GRU"; case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; - case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; @@ -2641,6 +2743,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH"; case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF"; case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1"; + case DML_OPERATOR_ACTIVATION_SHRINK: return "DML_OPERATOR_ACTIVATION_SHRINK"; + case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING"; case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX"; case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; @@ -2652,10 +2756,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; - case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; - case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; @@ -2684,20 +2787,278 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD"; case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD"; case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; + case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; + case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; - case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; - case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; + case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; case 
DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD"; - case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; + case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; + case DML_OPERATOR_ACTIVATION_GELU: return "DML_OPERATOR_ACTIVATION_GELU"; + case DML_OPERATOR_ACTIVATION_SWISH: return "DML_OPERATOR_ACTIVATION_SWISH"; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return "DML_OPERATOR_ACTIVATION_HARD_SWISH"; case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2"; case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1"; case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1"; case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION"; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING"; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_BINDING_TYPE value) +{ + switch (value) + { + case DML_BINDING_TYPE_NONE: return "DML_BINDING_TYPE_NONE"; + case DML_BINDING_TYPE_BUFFER: return "DML_BINDING_TYPE_BUFFER"; + case DML_BINDING_TYPE_BUFFER_ARRAY: return "DML_BINDING_TYPE_BUFFER_ARRAY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_REDUCE_FUNCTION value) +{ + switch (value) + { + case DML_REDUCE_FUNCTION_ARGMAX: return "DML_REDUCE_FUNCTION_ARGMAX"; + case DML_REDUCE_FUNCTION_ARGMIN: return "DML_REDUCE_FUNCTION_ARGMIN"; + case DML_REDUCE_FUNCTION_AVERAGE: return "DML_REDUCE_FUNCTION_AVERAGE"; + case DML_REDUCE_FUNCTION_L1: return "DML_REDUCE_FUNCTION_L1"; + case DML_REDUCE_FUNCTION_L2: return "DML_REDUCE_FUNCTION_L2"; + case DML_REDUCE_FUNCTION_LOG_SUM: return "DML_REDUCE_FUNCTION_LOG_SUM"; + case DML_REDUCE_FUNCTION_LOG_SUM_EXP: return "DML_REDUCE_FUNCTION_LOG_SUM_EXP"; + case DML_REDUCE_FUNCTION_MAX: return "DML_REDUCE_FUNCTION_MAX"; + case DML_REDUCE_FUNCTION_MIN: return "DML_REDUCE_FUNCTION_MIN"; + case DML_REDUCE_FUNCTION_MULTIPLY: return "DML_REDUCE_FUNCTION_MULTIPLY"; + case DML_REDUCE_FUNCTION_SUM: return "DML_REDUCE_FUNCTION_SUM"; + case DML_REDUCE_FUNCTION_SUM_SQUARE: return "DML_REDUCE_FUNCTION_SUM_SQUARE"; default: assert(false); return ""; } } + +template <> +inline gsl::czstring ToString(DML_MATRIX_TRANSFORM value) +{ + switch (value) + { + case DML_MATRIX_TRANSFORM_NONE: return "DML_MATRIX_TRANSFORM_NONE"; + case DML_MATRIX_TRANSFORM_TRANSPOSE: return "DML_MATRIX_TRANSFORM_TRANSPOSE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_MODE value) +{ + switch (value) + { + case DML_CONVOLUTION_MODE_CONVOLUTION: return "DML_CONVOLUTION_MODE_CONVOLUTION"; + case DML_CONVOLUTION_MODE_CROSS_CORRELATION: return "DML_CONVOLUTION_MODE_CROSS_CORRELATION"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_DIRECTION value) +{ + switch 
(value) + { + case DML_CONVOLUTION_DIRECTION_FORWARD: return "DML_CONVOLUTION_DIRECTION_FORWARD"; + case DML_CONVOLUTION_DIRECTION_BACKWARD: return "DML_CONVOLUTION_DIRECTION_BACKWARD"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_PADDING_MODE value) +{ + switch (value) + { + case DML_PADDING_MODE_CONSTANT: return "DML_PADDING_MODE_CONSTANT"; + case DML_PADDING_MODE_EDGE: return "DML_PADDING_MODE_EDGE"; + case DML_PADDING_MODE_REFLECTION: return "DML_PADDING_MODE_REFLECTION"; + case DML_PADDING_MODE_SYMMETRIC: return "DML_PADDING_MODE_SYMMETRIC"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_INTERPOLATION_MODE value) +{ + switch (value) + { + case DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR: return "DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR"; + case DML_INTERPOLATION_MODE_LINEAR: return "DML_INTERPOLATION_MODE_LINEAR"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RECURRENT_NETWORK_DIRECTION value) +{ + switch (value) + { + case DML_RECURRENT_NETWORK_DIRECTION_FORWARD: return "DML_RECURRENT_NETWORK_DIRECTION_FORWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BACKWARD: return "DML_RECURRENT_NETWORK_DIRECTION_BACKWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL: return "DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE value) +{ + switch (value) + { + case DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT: return "DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT"; + case DML_FEATURE_FEATURE_LEVELS: return "DML_FEATURE_FEATURE_LEVELS"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE_LEVEL value) +{ + switch (value) + { + case DML_FEATURE_LEVEL_1_0: return "DML_FEATURE_LEVEL_1_0"; + case DML_FEATURE_LEVEL_2_0: return "DML_FEATURE_LEVEL_2_0"; + case DML_FEATURE_LEVEL_2_1: return "DML_FEATURE_LEVEL_2_1"; + case DML_FEATURE_LEVEL_3_0: return "DML_FEATURE_LEVEL_3_0"; + case DML_FEATURE_LEVEL_3_1: return "DML_FEATURE_LEVEL_3_1"; + case DML_FEATURE_LEVEL_4_0: return "DML_FEATURE_LEVEL_4_0"; + case DML_FEATURE_LEVEL_4_1: return "DML_FEATURE_LEVEL_4_1"; + case DML_FEATURE_LEVEL_5_0: return "DML_FEATURE_LEVEL_5_0"; + case DML_FEATURE_LEVEL_5_1: return "DML_FEATURE_LEVEL_5_1"; + case DML_FEATURE_LEVEL_5_2: return "DML_FEATURE_LEVEL_5_2"; + case DML_FEATURE_LEVEL_6_0: return "DML_FEATURE_LEVEL_6_0"; + case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1"; + case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_IS_INFINITY_MODE value) +{ + switch (value) + { + case DML_IS_INFINITY_MODE_EITHER: return "DML_IS_INFINITY_MODE_EITHER"; + case DML_IS_INFINITY_MODE_POSITIVE: return "DML_IS_INFINITY_MODE_POSITIVE"; + case DML_IS_INFINITY_MODE_NEGATIVE: return "DML_IS_INFINITY_MODE_NEGATIVE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_DEPTH_SPACE_ORDER value) +{ + switch (value) + { + case DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW: return "DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW"; + case DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH: return "DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_AXIS_DIRECTION value) +{ + switch (value) + { + case DML_AXIS_DIRECTION_INCREASING: return 
"DML_AXIS_DIRECTION_INCREASING"; + case DML_AXIS_DIRECTION_DECREASING: return "DML_AXIS_DIRECTION_DECREASING"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_ROUNDING_MODE value) +{ + switch (value) + { + case DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN: return "DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN"; + case DML_ROUNDING_MODE_TOWARD_ZERO: return "DML_ROUNDING_MODE_TOWARD_ZERO"; + case DML_ROUNDING_MODE_TOWARD_INFINITY: return "DML_ROUNDING_MODE_TOWARD_INFINITY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RANDOM_GENERATOR_TYPE value) +{ + switch (value) + { + case DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10: return "DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_MULTIHEAD_ATTENTION_MASK_TYPE value) +{ + switch (value) + { + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN"; + default: + assert(false); + return ""; + } +} + + +template +T FromString(std::string_view value); + +} } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 2a82c12872a72..5fe6603c2a0bf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -618,7 +618,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -633,7 +633,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD 
{ DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -869,31 +869,6 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; - -constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", - static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 13, - DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1146,7 +1121,7 @@ constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, }; constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA { @@ -2312,7 +2287,7 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA { DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ +constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, @@ -2323,7 +2298,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2342,7 +2317,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2350,7 +2325,7 @@ constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ +constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false }, @@ -2359,7 +2334,7 @@ constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "DiagonalFillEnd", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA { "DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2396,6 +2371,48 @@ constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA { DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", 
false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -2732,6 +2749,35 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_GELU_OPERATOR_SCHEMA { DML_ACTIVATION_GELU_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SigmoidInputScale", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SWISH", + DML_OPERATOR_ACTIVATION_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { 
DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARD_SWISH", + DML_OPERATOR_ACTIVATION_HARD_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_RNN_ZERO_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h new file mode 100644 index 0000000000000..72059b9a3f911 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h @@ -0,0 +1,788 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ +#define FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ + +#include "flatbuffers/flatbuffers.h" + +#include "OperatorFieldTypes_generated.h" + +namespace dml { +namespace ir { + +struct ConstantRawData; +struct ConstantRawDataBuilder; + +struct ConstantName; +struct ConstantNameBuilder; + +struct ConstantNodeDesc; +struct ConstantNodeDescBuilder; + +struct DmlBufferTensorDesc; +struct DmlBufferTensorDescBuilder; + +struct OperatorNodeDesc; +struct OperatorNodeDescBuilder; + +struct DmlGraphNode; +struct DmlGraphNodeBuilder; + +struct DmlGraphDesc; +struct DmlGraphDescBuilder; + +enum ConstantNodeDescDetail { + ConstantNodeDescDetail_NONE = 0, + ConstantNodeDescDetail_ConstantName = 1, + ConstantNodeDescDetail_ConstantRawData = 2, + ConstantNodeDescDetail_MIN = ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_MAX = ConstantNodeDescDetail_ConstantRawData +}; + +inline const ConstantNodeDescDetail (&EnumValuesConstantNodeDescDetail())[3] { + static const ConstantNodeDescDetail values[] = { + ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_ConstantName, + ConstantNodeDescDetail_ConstantRawData + }; + return values; +} + +inline const char * const *EnumNamesConstantNodeDescDetail() { + static const char * const names[4] = { + "NONE", + "ConstantName", + "ConstantRawData", + nullptr + }; + return names; +} + +inline const char *EnumNameConstantNodeDescDetail(ConstantNodeDescDetail e) { + if (flatbuffers::IsOutRange(e, ConstantNodeDescDetail_NONE, ConstantNodeDescDetail_ConstantRawData)) return ""; + const size_t index = static_cast(e); + return EnumNamesConstantNodeDescDetail()[index]; +} + +template struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_NONE; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantName; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantRawData; +}; + +bool VerifyConstantNodeDescDetail(flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type); +bool VerifyConstantNodeDescDetailVector(flatbuffers::Verifier &verifier, const 
flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum NodeDesc { + NodeDesc_NONE = 0, + NodeDesc_OperatorNodeDesc = 1, + NodeDesc_ConstantNodeDesc = 2, + NodeDesc_MIN = NodeDesc_NONE, + NodeDesc_MAX = NodeDesc_ConstantNodeDesc +}; + +inline const NodeDesc (&EnumValuesNodeDesc())[3] { + static const NodeDesc values[] = { + NodeDesc_NONE, + NodeDesc_OperatorNodeDesc, + NodeDesc_ConstantNodeDesc + }; + return values; +} + +inline const char * const *EnumNamesNodeDesc() { + static const char * const names[4] = { + "NONE", + "OperatorNodeDesc", + "ConstantNodeDesc", + nullptr + }; + return names; +} + +inline const char *EnumNameNodeDesc(NodeDesc e) { + if (flatbuffers::IsOutRange(e, NodeDesc_NONE, NodeDesc_ConstantNodeDesc)) return ""; + const size_t index = static_cast(e); + return EnumNamesNodeDesc()[index]; +} + +template struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_NONE; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_OperatorNodeDesc; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_ConstantNodeDesc; +}; + +bool VerifyNodeDesc(flatbuffers::Verifier &verifier, const void *obj, NodeDesc type); +bool VerifyNodeDescVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +struct ConstantRawData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantRawDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct ConstantRawDataBuilder { + typedef ConstantRawData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(ConstantRawData::VT_DATA, data); + } + explicit ConstantRawDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantRawDataBuilder &operator=(const ConstantRawDataBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantRawData( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + ConstantRawDataBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConstantRawDataDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::CreateConstantRawData( + _fbb, + data__); +} + +struct ConstantName FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantNameBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + verifier.EndTable(); + } +}; + +struct ConstantNameBuilder { + typedef ConstantName Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(ConstantName::VT_NAME, name); + } + explicit ConstantNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantNameBuilder &operator=(const ConstantNameBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantName( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0) { + ConstantNameBuilder builder_(_fbb); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConstantNameDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return dml::ir::CreateConstantName( + _fbb, + name__); +} + +struct ConstantNodeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::ConstantNodeDescDetail data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::ConstantName *data_as_ConstantName() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantName ? static_cast(data()) : nullptr; + } + const dml::ir::ConstantRawData *data_as_ConstantRawData() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData ? 
static_cast(data()) : nullptr; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE) && + VerifyOffset(verifier, VT_DATA) && + VerifyConstantNodeDescDetail(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::ConstantName *ConstantNodeDesc::data_as() const { + return data_as_ConstantName(); +} + +template<> inline const dml::ir::ConstantRawData *ConstantNodeDesc::data_as() const { + return data_as_ConstantRawData(); +} + +struct ConstantNodeDescBuilder { + typedef ConstantNodeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::ConstantNodeDescDetail data_type) { + fbb_.AddElement(ConstantNodeDesc::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(ConstantNodeDesc::VT_DATA, data); + } + explicit ConstantNodeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantNodeDescBuilder &operator=(const ConstantNodeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantNodeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::ConstantNodeDescDetail data_type = dml::ir::ConstantNodeDescDetail_NONE, + flatbuffers::Offset data = 0) { + ConstantNodeDescBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +struct DmlBufferTensorDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlBufferTensorDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATATYPE = 4, + VT_SIZES = 6, + VT_STRIDES = 8, + VT_TOTALTENSORSIZEINBYTES = 10 + }; + const flatbuffers::String *dataType() const { + return GetPointer(VT_DATATYPE); + } + const flatbuffers::Vector *sizes() const { + return GetPointer *>(VT_SIZES); + } + const flatbuffers::Vector *strides() const { + return GetPointer *>(VT_STRIDES); + } + uint64_t totalTensorSizeInBytes() const { + return GetField(VT_TOTALTENSORSIZEINBYTES, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATATYPE) && + verifier.VerifyString(dataType()) && + VerifyOffset(verifier, VT_SIZES) && + verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_TOTALTENSORSIZEINBYTES) && + verifier.EndTable(); + } +}; + +struct DmlBufferTensorDescBuilder { + typedef DmlBufferTensorDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_dataType(flatbuffers::Offset dataType) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_DATATYPE, dataType); + } + void add_sizes(flatbuffers::Offset> sizes) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_SIZES, sizes); + } + void add_strides(flatbuffers::Offset> strides) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_STRIDES, strides); + } + void add_totalTensorSizeInBytes(uint64_t totalTensorSizeInBytes) { + fbb_.AddElement(DmlBufferTensorDesc::VT_TOTALTENSORSIZEINBYTES, totalTensorSizeInBytes, 0); + } + explicit DmlBufferTensorDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlBufferTensorDescBuilder &operator=(const DmlBufferTensorDescBuilder &); + flatbuffers::Offset 
Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlBufferTensorDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset dataType = 0, + flatbuffers::Offset> sizes = 0, + flatbuffers::Offset> strides = 0, + uint64_t totalTensorSizeInBytes = 0) { + DmlBufferTensorDescBuilder builder_(_fbb); + builder_.add_totalTensorSizeInBytes(totalTensorSizeInBytes); + builder_.add_strides(strides); + builder_.add_sizes(sizes); + builder_.add_dataType(dataType); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlBufferTensorDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *dataType = nullptr, + const std::vector *sizes = nullptr, + const std::vector *strides = nullptr, + uint64_t totalTensorSizeInBytes = 0) { + auto dataType__ = dataType ? _fbb.CreateString(dataType) : 0; + auto sizes__ = sizes ? _fbb.CreateVector(*sizes) : 0; + auto strides__ = strides ? _fbb.CreateVector(*strides) : 0; + return dml::ir::CreateDmlBufferTensorDesc( + _fbb, + dataType__, + sizes__, + strides__, + totalTensorSizeInBytes); +} + +struct OperatorNodeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_ATTRIBUTES = 10 + }; + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + const flatbuffers::Vector> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const flatbuffers::Vector> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + const flatbuffers::Vector> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct OperatorNodeDescBuilder { + typedef OperatorNodeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(OperatorNodeDesc::VT_TYPE, type); + } + void add_inputs(flatbuffers::Offset>> inputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset>> outputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_OUTPUTS, outputs); + } + void add_attributes(flatbuffers::Offset>> attributes) { + fbb_.AddOffset(OperatorNodeDesc::VT_ATTRIBUTES, attributes); + } + explicit OperatorNodeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorNodeDescBuilder &operator=(const OperatorNodeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperatorNodeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type = 0, + flatbuffers::Offset>> inputs = 0, + flatbuffers::Offset>> outputs = 0, + flatbuffers::Offset>> attributes = 0) { + OperatorNodeDescBuilder builder_(_fbb); + 
builder_.add_attributes(attributes); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorNodeDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector> *inputs = nullptr, + const std::vector> *outputs = nullptr, + const std::vector> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + auto attributes__ = attributes ? _fbb.CreateVector>(*attributes) : 0; + return dml::ir::CreateOperatorNodeDesc( + _fbb, + type__, + inputs__, + outputs__, + attributes__); +} + +struct DmlGraphNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlGraphNodeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DESC_TYPE = 4, + VT_DESC = 6, + VT_NAME = 8, + VT_INPUTNAMES = 10, + VT_OUTPUTNAMES = 12 + }; + dml::ir::NodeDesc desc_type() const { + return static_cast(GetField(VT_DESC_TYPE, 0)); + } + const void *desc() const { + return GetPointer(VT_DESC); + } + template const T *desc_as() const; + const dml::ir::OperatorNodeDesc *desc_as_OperatorNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_OperatorNodeDesc ? static_cast(desc()) : nullptr; + } + const dml::ir::ConstantNodeDesc *desc_as_ConstantNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_ConstantNodeDesc ? static_cast(desc()) : nullptr; + } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector> *inputNames() const { + return GetPointer> *>(VT_INPUTNAMES); + } + const flatbuffers::Vector> *outputNames() const { + return GetPointer> *>(VT_OUTPUTNAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DESC_TYPE) && + VerifyOffset(verifier, VT_DESC) && + VerifyNodeDesc(verifier, desc(), desc_type()) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_INPUTNAMES) && + verifier.VerifyVector(inputNames()) && + verifier.VerifyVectorOfStrings(inputNames()) && + VerifyOffset(verifier, VT_OUTPUTNAMES) && + verifier.VerifyVector(outputNames()) && + verifier.VerifyVectorOfStrings(outputNames()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::OperatorNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_OperatorNodeDesc(); +} + +template<> inline const dml::ir::ConstantNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_ConstantNodeDesc(); +} + +struct DmlGraphNodeBuilder { + typedef DmlGraphNode Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_desc_type(dml::ir::NodeDesc desc_type) { + fbb_.AddElement(DmlGraphNode::VT_DESC_TYPE, static_cast(desc_type), 0); + } + void add_desc(flatbuffers::Offset desc) { + fbb_.AddOffset(DmlGraphNode::VT_DESC, desc); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(DmlGraphNode::VT_NAME, name); + } + void add_inputNames(flatbuffers::Offset>> inputNames) { + fbb_.AddOffset(DmlGraphNode::VT_INPUTNAMES, inputNames); + } + void add_outputNames(flatbuffers::Offset>> outputNames) { + fbb_.AddOffset(DmlGraphNode::VT_OUTPUTNAMES, outputNames); + } + explicit DmlGraphNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlGraphNodeBuilder 
&operator=(const DmlGraphNodeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlGraphNode( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + flatbuffers::Offset desc = 0, + flatbuffers::Offset name = 0, + flatbuffers::Offset>> inputNames = 0, + flatbuffers::Offset>> outputNames = 0) { + DmlGraphNodeBuilder builder_(_fbb); + builder_.add_outputNames(outputNames); + builder_.add_inputNames(inputNames); + builder_.add_name(name); + builder_.add_desc(desc); + builder_.add_desc_type(desc_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlGraphNodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + flatbuffers::Offset desc = 0, + const char *name = nullptr, + const std::vector> *inputNames = nullptr, + const std::vector> *outputNames = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto inputNames__ = inputNames ? _fbb.CreateVector>(*inputNames) : 0; + auto outputNames__ = outputNames ? _fbb.CreateVector>(*outputNames) : 0; + return dml::ir::CreateDmlGraphNode( + _fbb, + desc_type, + desc, + name__, + inputNames__, + outputNames__); +} + +struct DmlGraphDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlGraphDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NODES = 4, + VT_GRAPHINPUTNAMES = 6, + VT_GRAPHOUTPUTNAMES = 8 + }; + const flatbuffers::Vector> *nodes() const { + return GetPointer> *>(VT_NODES); + } + const flatbuffers::Vector> *graphInputNames() const { + return GetPointer> *>(VT_GRAPHINPUTNAMES); + } + const flatbuffers::Vector> *graphOutputNames() const { + return GetPointer> *>(VT_GRAPHOUTPUTNAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NODES) && + verifier.VerifyVector(nodes()) && + verifier.VerifyVectorOfTables(nodes()) && + VerifyOffset(verifier, VT_GRAPHINPUTNAMES) && + verifier.VerifyVector(graphInputNames()) && + verifier.VerifyVectorOfStrings(graphInputNames()) && + VerifyOffset(verifier, VT_GRAPHOUTPUTNAMES) && + verifier.VerifyVector(graphOutputNames()) && + verifier.VerifyVectorOfStrings(graphOutputNames()) && + verifier.EndTable(); + } +}; + +struct DmlGraphDescBuilder { + typedef DmlGraphDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_nodes(flatbuffers::Offset>> nodes) { + fbb_.AddOffset(DmlGraphDesc::VT_NODES, nodes); + } + void add_graphInputNames(flatbuffers::Offset>> graphInputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHINPUTNAMES, graphInputNames); + } + void add_graphOutputNames(flatbuffers::Offset>> graphOutputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHOUTPUTNAMES, graphOutputNames); + } + explicit DmlGraphDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlGraphDescBuilder &operator=(const DmlGraphDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlGraphDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> nodes = 0, + flatbuffers::Offset>> graphInputNames = 0, + flatbuffers::Offset>> graphOutputNames = 0) { + DmlGraphDescBuilder builder_(_fbb); + 
builder_.add_graphOutputNames(graphOutputNames); + builder_.add_graphInputNames(graphInputNames); + builder_.add_nodes(nodes); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlGraphDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *nodes = nullptr, + const std::vector> *graphInputNames = nullptr, + const std::vector> *graphOutputNames = nullptr) { + auto nodes__ = nodes ? _fbb.CreateVector>(*nodes) : 0; + auto graphInputNames__ = graphInputNames ? _fbb.CreateVector>(*graphInputNames) : 0; + auto graphOutputNames__ = graphOutputNames ? _fbb.CreateVector>(*graphOutputNames) : 0; + return dml::ir::CreateDmlGraphDesc( + _fbb, + nodes__, + graphInputNames__, + graphOutputNames__); +} + +inline bool VerifyConstantNodeDescDetail(flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type) { + switch (type) { + case ConstantNodeDescDetail_NONE: { + return true; + } + case ConstantNodeDescDetail_ConstantName: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case ConstantNodeDescDetail_ConstantRawData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyConstantNodeDescDetailVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyConstantNodeDescDetail( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyNodeDesc(flatbuffers::Verifier &verifier, const void *obj, NodeDesc type) { + switch (type) { + case NodeDesc_NONE: { + return true; + } + case NodeDesc_OperatorNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case NodeDesc_ConstantNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyNodeDescVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyNodeDesc( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline const dml::ir::DmlGraphDesc *GetDmlGraphDesc(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const dml::ir::DmlGraphDesc *GetSizePrefixedDmlGraphDesc(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline bool VerifyDmlGraphDescBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(nullptr); +} + +inline bool VerifySizePrefixedDmlGraphDescBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(nullptr); +} + +inline void FinishDmlGraphDescBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedDmlGraphDescBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root); +} + +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h new file mode 100644 index 0000000000000..9decf0dce1bb2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlSerializedGraphDesc.h" + +struct NodeIndex +{ + uint32_t nodeIndex; + uint32_t nodeOutputIndex; +}; + +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData); \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h new file mode 100644 index 0000000000000..d8d069da906b7 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlGraphDesc_generated.h" + +struct DmlSerializedGraphDesc; + +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h new file mode 100644 index 0000000000000..51c3d6c81244b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h @@ -0,0 +1,73 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// +//----------------------------------------------------------------------------- + +#pragma once + +struct ConstantName +{ + std::string name; +}; + +struct ConstantData +{ + std::byte* data; + uint64_t dataSize; +}; + +using DmlSerializedGraphNodeConstantVariant = std::variant< + ConstantName, + ConstantData +>; + +using DmlSerializedGraphNodeDescVariant = std::variant< + AbstractOperatorDesc, + DmlSerializedGraphNodeConstantVariant +>; + +struct DmlSerializedGraphNode +{ + DmlSerializedGraphNodeDescVariant Desc; + std::string Name; +}; + +struct DmlInputSerializedGraphEdge +{ + uint32_t GraphInputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlOutputSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t GraphOutputIndex; + std::string Name; +}; + +struct DmlIntermediateSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlSerializedGraphDesc +{ + uint32_t InputCount; + uint32_t OutputCount; + // Nodes must be stored in topological order for deserialization to work: + // when an intermediate edge is created during deserialization, the node that + // the edge comes from must already have been visited before the node that + // the edge feeds into. + std::vector Nodes; + std::vector InputEdges; + std::vector OutputEdges; + std::vector IntermediateEdges; +}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 99218c135f058..4be41ad3924a2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,7 +425,6 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } - inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) { return { @@ -502,24 +501,6 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } -inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], 
ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), - }; -} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -1488,6 +1469,37 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast(desc.MaskType))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + 
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1680,6 +1692,23 @@ inline std::vector GetFields(const DML_ACTIVATION_GELU_OPERATOR_D OperatorField(&DML_ACTIVATION_GELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.SigmoidInputScale))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) { switch (operatorType) @@ -1826,6 +1855,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA; case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -1850,6 +1881,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_GELU: return 
DML_ACTIVATION_GELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SWISH: return DML_ACTIVATION_SWISH_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA; default: ORT_THROW_HR(E_INVALIDARG); @@ -2431,6 +2464,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, @@ -2527,13 +2568,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + case DML_OPERATOR_ACTIVATION_SWISH: return AbstractOperatorDesc( - &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); -#pragma warning(pop) + &DML_ACTIVATION_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); default: ORT_THROW_HR(E_INVALIDARG); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h index 25f0dd26c6067..a94bb67b68d36 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h @@ -15,32 +15,34 @@ using ApiAttributeVariant = std::variant< const FLOAT*, const DML_SCALE_BIAS*, DML_SIZE_2D, - DML_SCALAR_UNION + DML_SCALAR_UNION, + BOOL >; namespace OperatorFieldTypes { using TensorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC using TensorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY - using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC - using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY + using FusedActivationOperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC + using FusedActivationOperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT using UInt64 = uint64_t; // DML_SCHEMA_FIELD_TYPE_UINT64 using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT - using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY - using IntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY - using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY + using UIntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY + using IntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY + using FloatArray = std::vector; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY using 
ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D using ScalarUnion = DML_SCALAR_UNION; // DML_SCHEMA_FIELD_TYPE_SCALAR_UNION + using Bool = bool; // DML_SCHEMA_FIELD_TYPE_BOOL } using OperatorFieldVariant = std::variant< OperatorFieldTypes::TensorDesc, OperatorFieldTypes::TensorDescArray, - OperatorFieldTypes::OperatorDesc, - OperatorFieldTypes::OperatorDescArray, + OperatorFieldTypes::FusedActivationOperatorDesc, + OperatorFieldTypes::FusedActivationOperatorDescArray, OperatorFieldTypes::UInt, OperatorFieldTypes::UInt64, OperatorFieldTypes::Int, @@ -50,7 +52,8 @@ using OperatorFieldVariant = std::variant< OperatorFieldTypes::FloatArray, OperatorFieldTypes::ScaleBias, OperatorFieldTypes::Size2D, - OperatorFieldTypes::ScalarUnion + OperatorFieldTypes::ScalarUnion, + OperatorFieldTypes::Bool >; class OperatorField @@ -80,11 +83,11 @@ class OperatorField const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get(m_data); } OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() { return std::get(m_data); } const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } @@ -116,6 +119,9 @@ class OperatorField const OperatorFieldTypes::ScalarUnion& AsScalarUnion() const { return std::get(m_data); } OperatorFieldTypes::ScalarUnion& AsScalarUnion() { return std::get(m_data); } + const OperatorFieldTypes::Bool& AsBool() const { return std::get(m_data); } + OperatorFieldTypes::Bool& AsBool() { return std::get(m_data); } + private: const DML_SCHEMA_FIELD* m_schema; OperatorFieldVariant m_data; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h new file mode 100644 index 0000000000000..167a913bb0132 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h @@ -0,0 +1,1318 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ +#define FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace dml { +namespace ir { +namespace operatorFieldTypes { + +struct AttributeDesc; +struct AttributeDescBuilder; + +struct Activation; +struct ActivationBuilder; + +struct ActivationArray; +struct ActivationArrayBuilder; + +struct UInt8; + +struct UInt16; + +struct UInt32; 
+ +struct UInt64; + +struct Int8; + +struct Int16; + +struct Int32; + +struct Int64; + +struct Float32; + +struct Float64; + +struct UIntArray; +struct UIntArrayBuilder; + +struct IntArray; +struct IntArrayBuilder; + +struct FloatArray; +struct FloatArrayBuilder; + +struct ScaleBias; + +struct Size2D; + +struct ByteArray; + +struct ScalarUnionData; +struct ScalarUnionDataBuilder; + +struct Bool; + +enum AttributeFieldVariant { + AttributeFieldVariant_NONE = 0, + AttributeFieldVariant_Activation = 1, + AttributeFieldVariant_ActivationArray = 2, + AttributeFieldVariant_UInt32 = 3, + AttributeFieldVariant_UInt64 = 4, + AttributeFieldVariant_Int32 = 5, + AttributeFieldVariant_Float32 = 6, + AttributeFieldVariant_UIntArray = 7, + AttributeFieldVariant_IntArray = 8, + AttributeFieldVariant_FloatArray = 9, + AttributeFieldVariant_ScaleBias = 10, + AttributeFieldVariant_Size2D = 11, + AttributeFieldVariant_ScalarUnionData = 12, + AttributeFieldVariant_Bool = 13, + AttributeFieldVariant_MIN = AttributeFieldVariant_NONE, + AttributeFieldVariant_MAX = AttributeFieldVariant_Bool +}; + +inline const AttributeFieldVariant (&EnumValuesAttributeFieldVariant())[14] { + static const AttributeFieldVariant values[] = { + AttributeFieldVariant_NONE, + AttributeFieldVariant_Activation, + AttributeFieldVariant_ActivationArray, + AttributeFieldVariant_UInt32, + AttributeFieldVariant_UInt64, + AttributeFieldVariant_Int32, + AttributeFieldVariant_Float32, + AttributeFieldVariant_UIntArray, + AttributeFieldVariant_IntArray, + AttributeFieldVariant_FloatArray, + AttributeFieldVariant_ScaleBias, + AttributeFieldVariant_Size2D, + AttributeFieldVariant_ScalarUnionData, + AttributeFieldVariant_Bool + }; + return values; +} + +inline const char * const *EnumNamesAttributeFieldVariant() { + static const char * const names[15] = { + "NONE", + "Activation", + "ActivationArray", + "UInt32", + "UInt64", + "Int32", + "Float32", + "UIntArray", + "IntArray", + "FloatArray", + "ScaleBias", + "Size2D", + "ScalarUnionData", + "Bool", + nullptr + }; + return names; +} + +inline const char *EnumNameAttributeFieldVariant(AttributeFieldVariant e) { + if (flatbuffers::IsOutRange(e, AttributeFieldVariant_NONE, AttributeFieldVariant_Bool)) return ""; + const size_t index = static_cast(e); + return EnumNamesAttributeFieldVariant()[index]; +} + +template struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_NONE; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Activation; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ActivationArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt64; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Int32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Float32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UIntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = 
AttributeFieldVariant_IntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_FloatArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScaleBias; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Size2D; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScalarUnionData; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Bool; +}; + +bool VerifyAttributeFieldVariant(flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type); +bool VerifyAttributeFieldVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum ScalarVariant { + ScalarVariant_NONE = 0, + ScalarVariant_ByteArray = 1, + ScalarVariant_Int8 = 2, + ScalarVariant_UInt8 = 3, + ScalarVariant_Int16 = 4, + ScalarVariant_UInt16 = 5, + ScalarVariant_Int32 = 6, + ScalarVariant_UInt32 = 7, + ScalarVariant_Int64 = 8, + ScalarVariant_UInt64 = 9, + ScalarVariant_Float32 = 10, + ScalarVariant_Float64 = 11, + ScalarVariant_MIN = ScalarVariant_NONE, + ScalarVariant_MAX = ScalarVariant_Float64 +}; + +inline const ScalarVariant (&EnumValuesScalarVariant())[12] { + static const ScalarVariant values[] = { + ScalarVariant_NONE, + ScalarVariant_ByteArray, + ScalarVariant_Int8, + ScalarVariant_UInt8, + ScalarVariant_Int16, + ScalarVariant_UInt16, + ScalarVariant_Int32, + ScalarVariant_UInt32, + ScalarVariant_Int64, + ScalarVariant_UInt64, + ScalarVariant_Float32, + ScalarVariant_Float64 + }; + return values; +} + +inline const char * const *EnumNamesScalarVariant() { + static const char * const names[13] = { + "NONE", + "ByteArray", + "Int8", + "UInt8", + "Int16", + "UInt16", + "Int32", + "UInt32", + "Int64", + "UInt64", + "Float32", + "Float64", + nullptr + }; + return names; +} + +inline const char *EnumNameScalarVariant(ScalarVariant e) { + if (flatbuffers::IsOutRange(e, ScalarVariant_NONE, ScalarVariant_Float64)) return ""; + const size_t index = static_cast(e); + return EnumNamesScalarVariant()[index]; +} + +template struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_NONE; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_ByteArray; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = 
ScalarVariant_Float32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Float64; +}; + +bool VerifyScalarVariant(flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type); +bool VerifyScalarVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) UInt8 FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + UInt8() { + memset(static_cast(this), 0, sizeof(UInt8)); + } + UInt8(uint8_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint8_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint8_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) UInt16 FLATBUFFERS_FINAL_CLASS { + private: + uint16_t data_; + + public: + UInt16() { + memset(static_cast(this), 0, sizeof(UInt16)); + } + UInt16(uint16_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint16_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint16_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) UInt32 FLATBUFFERS_FINAL_CLASS { + private: + uint32_t data_; + + public: + UInt32() { + memset(static_cast(this), 0, sizeof(UInt32)); + } + UInt32(uint32_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint32_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint32_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) UInt64 FLATBUFFERS_FINAL_CLASS { + private: + uint64_t data_; + + public: + UInt64() { + memset(static_cast(this), 0, sizeof(UInt64)); + } + UInt64(uint64_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint64_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint64_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Int8 FLATBUFFERS_FINAL_CLASS { + private: + int8_t data_; + + public: + Int8() { + memset(static_cast(this), 0, sizeof(Int8)); + } + Int8(int8_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int8_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int8_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) Int16 FLATBUFFERS_FINAL_CLASS { + private: + int16_t data_; + + public: + Int16() { + memset(static_cast(this), 0, sizeof(Int16)); + } + Int16(int16_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int16_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int16_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Int32 FLATBUFFERS_FINAL_CLASS { + private: + int32_t data_; + + public: + Int32() { + memset(static_cast(this), 0, sizeof(Int32)); + } + Int32(int32_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int32_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int32_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Int64 
FLATBUFFERS_FINAL_CLASS { + private: + int64_t data_; + + public: + Int64() { + memset(static_cast(this), 0, sizeof(Int64)); + } + Int64(int64_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int64_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int64_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Float32 FLATBUFFERS_FINAL_CLASS { + private: + float data_; + + public: + Float32() { + memset(static_cast(this), 0, sizeof(Float32)); + } + Float32(float _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + float data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(float _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Float64 FLATBUFFERS_FINAL_CLASS { + private: + double data_; + + public: + Float64() { + memset(static_cast(this), 0, sizeof(Float64)); + } + Float64(double _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + double data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(double _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) ScaleBias FLATBUFFERS_FINAL_CLASS { + private: + float scale_; + float bias_; + + public: + ScaleBias() { + memset(static_cast(this), 0, sizeof(ScaleBias)); + } + ScaleBias(float _scale, float _bias) + : scale_(flatbuffers::EndianScalar(_scale)), + bias_(flatbuffers::EndianScalar(_bias)) { + } + float scale() const { + return flatbuffers::EndianScalar(scale_); + } + void mutate_scale(float _scale) { + flatbuffers::WriteScalar(&scale_, _scale); + } + float bias() const { + return flatbuffers::EndianScalar(bias_); + } + void mutate_bias(float _bias) { + flatbuffers::WriteScalar(&bias_, _bias); + } +}; +FLATBUFFERS_STRUCT_END(ScaleBias, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Size2D FLATBUFFERS_FINAL_CLASS { + private: + uint32_t width_; + uint32_t height_; + + public: + Size2D() { + memset(static_cast(this), 0, sizeof(Size2D)); + } + Size2D(uint32_t _width, uint32_t _height) + : width_(flatbuffers::EndianScalar(_width)), + height_(flatbuffers::EndianScalar(_height)) { + } + uint32_t width() const { + return flatbuffers::EndianScalar(width_); + } + void mutate_width(uint32_t _width) { + flatbuffers::WriteScalar(&width_, _width); + } + uint32_t height() const { + return flatbuffers::EndianScalar(height_); + } + void mutate_height(uint32_t _height) { + flatbuffers::WriteScalar(&height_, _height); + } +}; +FLATBUFFERS_STRUCT_END(Size2D, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) ByteArray FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_[8]; + + public: + ByteArray() { + memset(static_cast(this), 0, sizeof(ByteArray)); + } + const flatbuffers::Array *data() const { + return reinterpret_cast *>(data_); + } + flatbuffers::Array *mutable_data() { + return reinterpret_cast *>(data_); + } +}; +FLATBUFFERS_STRUCT_END(ByteArray, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Bool FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + Bool() { + memset(static_cast(this), 0, sizeof(Bool)); + } + Bool(bool _data) + : data_(flatbuffers::EndianScalar(static_cast(_data))) { + } + bool data() const { + return flatbuffers::EndianScalar(data_) != 0; + } + void mutate_data(bool _data) { + flatbuffers::WriteScalar(&data_, static_cast(_data)); + } +}; +FLATBUFFERS_STRUCT_END(Bool, 
1); + +struct AttributeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AttributeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_VAL_TYPE = 6, + VT_VAL = 8 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + flatbuffers::String *mutable_name() { + return GetPointer(VT_NAME); + } + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type() const { + return static_cast(GetField(VT_VAL_TYPE, 0)); + } + const void *val() const { + return GetPointer(VT_VAL); + } + template const T *val_as() const; + const dml::ir::operatorFieldTypes::Activation *val_as_Activation() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ActivationArray *val_as_ActivationArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *val_as_UInt32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *val_as_UInt64() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *val_as_Int32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *val_as_Float32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UIntArray *val_as_UIntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::IntArray *val_as_IntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::FloatArray *val_as_FloatArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScaleBias *val_as_ScaleBias() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Size2D *val_as_Size2D() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScalarUnionData *val_as_ScalarUnionData() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Bool *val_as_Bool() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool ? 
static_cast(val()) : nullptr; + } + void *mutable_val() { + return GetPointer(VT_VAL); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_VAL_TYPE) && + VerifyOffset(verifier, VT_VAL) && + VerifyAttributeFieldVariant(verifier, val(), val_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::Activation *AttributeDesc::val_as() const { + return val_as_Activation(); +} + +template<> inline const dml::ir::operatorFieldTypes::ActivationArray *AttributeDesc::val_as() const { + return val_as_ActivationArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *AttributeDesc::val_as() const { + return val_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *AttributeDesc::val_as() const { + return val_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *AttributeDesc::val_as() const { + return val_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *AttributeDesc::val_as() const { + return val_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UIntArray *AttributeDesc::val_as() const { + return val_as_UIntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::IntArray *AttributeDesc::val_as() const { + return val_as_IntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::FloatArray *AttributeDesc::val_as() const { + return val_as_FloatArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScaleBias *AttributeDesc::val_as() const { + return val_as_ScaleBias(); +} + +template<> inline const dml::ir::operatorFieldTypes::Size2D *AttributeDesc::val_as() const { + return val_as_Size2D(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScalarUnionData *AttributeDesc::val_as() const { + return val_as_ScalarUnionData(); +} + +template<> inline const dml::ir::operatorFieldTypes::Bool *AttributeDesc::val_as() const { + return val_as_Bool(); +} + +struct AttributeDescBuilder { + typedef AttributeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(AttributeDesc::VT_NAME, name); + } + void add_val_type(dml::ir::operatorFieldTypes::AttributeFieldVariant val_type) { + fbb_.AddElement(AttributeDesc::VT_VAL_TYPE, static_cast(val_type), 0); + } + void add_val(flatbuffers::Offset val) { + fbb_.AddOffset(AttributeDesc::VT_VAL, val); + } + explicit AttributeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AttributeDescBuilder &operator=(const AttributeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAttributeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + flatbuffers::Offset val = 0) { + AttributeDescBuilder builder_(_fbb); + builder_.add_val(val); + builder_.add_name(name); + builder_.add_val_type(val_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateAttributeDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = 
dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + flatbuffers::Offset val = 0) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return dml::ir::operatorFieldTypes::CreateAttributeDesc( + _fbb, + name__, + val_type, + val); +} + +struct Activation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ActivationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_ATTRIBUTES = 6 + }; + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + flatbuffers::String *mutable_type() { + return GetPointer(VT_TYPE); + } + const flatbuffers::Vector> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + flatbuffers::Vector> *mutable_attributes() { + return GetPointer> *>(VT_ATTRIBUTES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct ActivationBuilder { + typedef Activation Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(Activation::VT_TYPE, type); + } + void add_attributes(flatbuffers::Offset>> attributes) { + fbb_.AddOffset(Activation::VT_ATTRIBUTES, attributes); + } + explicit ActivationBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ActivationBuilder &operator=(const ActivationBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateActivation( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type = 0, + flatbuffers::Offset>> attributes = 0) { + ActivationBuilder builder_(_fbb); + builder_.add_attributes(attributes); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateActivationDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto attributes__ = attributes ? 
_fbb.CreateVector>(*attributes) : 0; + return dml::ir::operatorFieldTypes::CreateActivation( + _fbb, + type__, + attributes__); +} + +struct ActivationArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ActivationArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector> *data() const { + return GetPointer> *>(VT_DATA); + } + flatbuffers::Vector> *mutable_data() { + return GetPointer> *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.VerifyVectorOfTables(data()) && + verifier.EndTable(); + } +}; + +struct ActivationArrayBuilder { + typedef ActivationArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset>> data) { + fbb_.AddOffset(ActivationArray::VT_DATA, data); + } + explicit ActivationArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ActivationArrayBuilder &operator=(const ActivationArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateActivationArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> data = 0) { + ActivationArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateActivationArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *data = nullptr) { + auto data__ = data ? _fbb.CreateVector>(*data) : 0; + return dml::ir::operatorFieldTypes::CreateActivationArray( + _fbb, + data__); +} + +struct UIntArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UIntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct UIntArrayBuilder { + typedef UIntArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(UIntArray::VT_DATA, data); + } + explicit UIntArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UIntArrayBuilder &operator=(const UIntArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUIntArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + UIntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateUIntArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateUIntArray( + _fbb, + data__); +} + +struct IntArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef IntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct IntArrayBuilder { + typedef IntArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(IntArray::VT_DATA, data); + } + explicit IntArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + IntArrayBuilder &operator=(const IntArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateIntArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + IntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateIntArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateIntArray( + _fbb, + data__); +} + +struct FloatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FloatArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct FloatArrayBuilder { + typedef FloatArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(FloatArray::VT_DATA, data); + } + explicit FloatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FloatArrayBuilder &operator=(const FloatArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFloatArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + FloatArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateFloatArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? 
_fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateFloatArray( + _fbb, + data__); +} + +struct ScalarUnionData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ScalarUnionDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::operatorFieldTypes::ScalarVariant data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::operatorFieldTypes::ByteArray *data_as_ByteArray() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_ByteArray ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int8 *data_as_Int8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt8 *data_as_UInt8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int16 *data_as_Int16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt16 *data_as_UInt16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *data_as_Int32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *data_as_UInt32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int64 *data_as_Int64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *data_as_UInt64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *data_as_Float32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float64 *data_as_Float64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float64 ? 
static_cast(data()) : nullptr; + } + void *mutable_data() { + return GetPointer(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE) && + VerifyOffset(verifier, VT_DATA) && + VerifyScalarVariant(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::ByteArray *ScalarUnionData::data_as() const { + return data_as_ByteArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int8 *ScalarUnionData::data_as() const { + return data_as_Int8(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt8 *ScalarUnionData::data_as() const { + return data_as_UInt8(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int16 *ScalarUnionData::data_as() const { + return data_as_Int16(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt16 *ScalarUnionData::data_as() const { + return data_as_UInt16(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *ScalarUnionData::data_as() const { + return data_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *ScalarUnionData::data_as() const { + return data_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int64 *ScalarUnionData::data_as() const { + return data_as_Int64(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *ScalarUnionData::data_as() const { + return data_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *ScalarUnionData::data_as() const { + return data_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float64 *ScalarUnionData::data_as() const { + return data_as_Float64(); +} + +struct ScalarUnionDataBuilder { + typedef ScalarUnionData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::operatorFieldTypes::ScalarVariant data_type) { + fbb_.AddElement(ScalarUnionData::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(ScalarUnionData::VT_DATA, data); + } + explicit ScalarUnionDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ScalarUnionDataBuilder &operator=(const ScalarUnionDataBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateScalarUnionData( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::operatorFieldTypes::ScalarVariant data_type = dml::ir::operatorFieldTypes::ScalarVariant_NONE, + flatbuffers::Offset data = 0) { + ScalarUnionDataBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +inline bool VerifyAttributeFieldVariant(flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type) { + switch (type) { + case AttributeFieldVariant_NONE: { + return true; + } + case AttributeFieldVariant_Activation: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ActivationArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_UInt32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_UInt64: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Int32: { + return 
verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Float32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_UIntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_IntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_FloatArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ScaleBias: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Size2D: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_ScalarUnionData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_Bool: { + return verifier.Verify(static_cast(obj), 0); + } + default: return true; + } +} + +inline bool VerifyAttributeFieldVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyAttributeFieldVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyScalarVariant(flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type) { + switch (type) { + case ScalarVariant_NONE: { + return true; + } + case ScalarVariant_ByteArray: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int8: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt8: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int16: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt16: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int64: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt64: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Float32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Float64: { + return verifier.Verify(static_cast(obj), 0); + } + default: return true; + } +} + +inline bool VerifyScalarVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyScalarVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +} // namespace operatorFieldTypes +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h index 5285481485184..1bc694dfe90c2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h @@ -26,14 +26,14 @@ namespace 
SchemaHelpers return field; } - inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) + inline OperatorFieldTypes::FusedActivationOperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) { - return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; + return value ? OperatorFieldTypes::FusedActivationOperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; } - inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) + inline OperatorFieldTypes::FusedActivationOperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) { - OperatorFieldTypes::OperatorDescArray field; + OperatorFieldTypes::FusedActivationOperatorDescArray field; if (values && count != 0) { field.emplace(count); @@ -65,13 +65,17 @@ namespace SchemaHelpers return value; } + inline OperatorFieldTypes::Bool ToOperatorFieldType(bool value) + { + return value; + } + inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count) { OperatorFieldTypes::UIntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -81,8 +85,7 @@ namespace SchemaHelpers OperatorFieldTypes::IntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -92,8 +95,7 @@ namespace SchemaHelpers OperatorFieldTypes::FloatArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -237,7 +239,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* desc = nullptr; - const auto& value = field.AsOperatorDesc(); + const auto& value = field.AsFusedActivationOperatorDesc(); if (value) { desc = allocator->template Allocate(); @@ -251,7 +253,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* descs = nullptr; - const auto& values = field.AsOperatorDescArray(); + const auto& values = field.AsFusedActivationOperatorDescArray(); if (values) { descs = allocator->template Allocate(values->size()); @@ -288,16 +290,20 @@ namespace SchemaHelpers dst->Write(value); } break; + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + // OperatorFieldTypes::Bool is a 'bool' (1 byte) but written as 'BOOL' in op descs (4 bytes). 
+ BOOL value = static_cast(field.AsBool()); + dst->Write(value); + } break; + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: { uint32_t* arrayPtr = nullptr; const auto& values = field.AsUIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -307,11 +313,8 @@ namespace SchemaHelpers int32_t* arrayPtr = nullptr; const auto& values = field.AsIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -321,11 +324,8 @@ namespace SchemaHelpers float* arrayPtr = nullptr; const auto& values = field.AsFloatArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 2456b396de3f6..e6f008af5c23f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -33,10 +33,10 @@ namespace Dml::GraphDescBuilder #pragma warning(pop) static void RemoveUnconnectedNodes( - std::vector& graphNodes, - std::vector& graphInputEdges, - std::vector& graphIntermediateEdges, - std::vector& graphOutputEdges) + std::vector& graphNodes, + std::vector& graphInputEdges, + std::vector& graphIntermediateEdges, + std::vector& graphOutputEdges) { enum class NodeState { @@ -52,7 +52,7 @@ namespace Dml::GraphDescBuilder }; std::vector nodesData(graphNodes.size()); - for (const DML_INTERMEDIATE_GRAPH_EDGE_DESC& intermediateEdge : graphIntermediateEdges) + for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphIntermediateEdges) { nodesData[intermediateEdge.ToNodeIndex].predecessorIndices.push_back(intermediateEdge.FromNodeIndex); } @@ -60,7 +60,7 @@ namespace Dml::GraphDescBuilder std::stack nodeIndicesToVisit; // Start from the outputs of the graph and traverse upwards - for (const DML_OUTPUT_GRAPH_EDGE_DESC& outputEdge : graphOutputEdges) + for (const DmlOutputSerializedGraphEdge& outputEdge : graphOutputEdges) { nodeIndicesToVisit.push(outputEdge.FromNodeIndex); } @@ -143,17 +143,44 @@ namespace Dml::GraphDescBuilder } } + + uint32_t SetAndGetDmlGraphNodeIndex( + const uint32_t operatorDmlGraphNodeIndex, + const std::string& nodeNamePrefix, + AbstractOperatorDesc& operatorDesc, + /*in_out*/std::unordered_map& operatorDmlGraphToDmlGraphNodeIndexMap, + /*in_out*/std::vector& dmlGraphNodes) + { + auto iter = operatorDmlGraphToDmlGraphNodeIndexMap.find(operatorDmlGraphNodeIndex); + if (iter != operatorDmlGraphToDmlGraphNodeIndexMap.end()) + { + return iter->second; + } + operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex] = static_cast(dmlGraphNodes.size()); + dmlGraphNodes.push_back({operatorDesc, nodeNamePrefix + std::to_string(operatorDmlGraphNodeIndex)}); + return operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex]; + } + 
+ // Terminology: + // Subgraph: partitioned ONNX graph from the original (main) ONNX graph + // DmlGraph: a graph in DML currency converted from subgraph. + // operatorDmlGraph: a graph in DML currency for a given node or operator + // Main Points to note: + // - GraphDesc will always has sequential indices for input and intermediate edges. + // - 1 onnx node can be converted to one or more dml nodes. GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs) + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData) { struct NodeAndIndex { @@ -161,19 +188,34 @@ namespace Dml::GraphDescBuilder uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node) }; - // Map from Lotus node argument names to the new node and index where it will be produced - std::unordered_map nameToNodeAndIndexMap; - std::unordered_map nodeOutputShapes; - // Map from Lotus node argument names to input indices of the fused kernel node. - std::unordered_map nameToDmlFusedNodeInputIndex; + // Map from ORT subgraph input names to indices + std::unordered_map subgraphInputNameToIndexMap; + + // - Map from ORT node's output names to DmlGraph . + // - Once a given ORT node (or operator) will be transformed into a operatorDmlGraph, + // then ORT node's output names will become output edges for the operatorDmlGraph. + // - This map will be populated for those output edges. + std::unordered_map dmlGraphNodeOutputNameToNodeAndIndexMap; + + // This map will be used to re-index an subGraphInputIndex to sequential input index + // for DmlGraph + std::unordered_map subGraphInputIndexToDmlGraphInputIndex; + + // Iterate through each node and create a corresponding node in the new graph + // We can iterate the nodes in any order because the edge connectivity will take care of the topological order + std::unordered_map> inferredOutputShapes; + + std::vector dmlGraphNodes; + std::vector dmlGraphInputEdges; + std::vector dmlGraphIntermediateEdges; + std::vector dmlGraphOutputEdges; for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; - - if (!graphInput) + const onnxruntime::NodeArg* subgraphInput = subgraphInputs[inputIndex]; + if (!subgraphInput) { // This is a workaround for when node inputs get manipulated by transformers outside of our control, // which then causes them to have a different name. If that happens we can't figure out how to @@ -181,45 +223,21 @@ namespace Dml::GraphDescBuilder // just bail early. ORT_THROW_HR(E_UNEXPECTED); } - - nameToDmlFusedNodeInputIndex.emplace(graphInput->Name(), gsl::narrow_cast(inputIndex)); - } - - StackAllocator<1024> allocator; // Used for converting abstract operator descs into DML_OPERATOR_DESC - - std::vector graphNodes; - std::vector graphInputEdges; - std::vector graphIntermediateEdges; - std::vector graphOutputEdges; - - // Avoid using separate command lists for small graphs. 
This value can be reduced by tuning the - // flushing behavior of DmlCommandRecorder. Its current behavior is to assume that graphs contain - // enough GPU work to be worth flushing immediately. - const uint32_t minNodeCountToReuseCommandList = 5; - bool reuseCommandList = false; - - if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) - { - reuseCommandList = true; + subgraphInputNameToIndexMap.emplace(subgraphInput->Name(), gsl::narrow_cast(inputIndex)); } auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; - auto iter = isInitializerTransferable.find(argName); if (iter != isInitializerTransferable.end()) { // Using const_cast here is simpler than making surrounding code const correct. tensorWrapper = wil::MakeOrThrow(const_cast(iter->second.first), modelPath); } - return tensorWrapper; }; - // Iterate through each node and create a corresponding node in the new graph - // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - std::unordered_map> inferredOutputShapes; for (const onnxruntime::Node* subgraphNode : subgraphNodes) { @@ -277,195 +295,206 @@ namespace Dml::GraphDescBuilder } EdgeShapes outputShapes; - DmlGraphNodeCreateInfo graphNodeCreateInfo; + DmlGraphNodeCreateInfo operatorDmlGraphCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, &inputShapesOverrides, /*out*/ &outputShapes, - /*out*/ &graphNodeCreateInfo + /*out*/ &operatorDmlGraphCreateInfo ); ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); for (int i = 0; i < node.OutputDefs().size(); ++i) { inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); - } - - // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. - std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; - uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); - const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; - size_t firstOpDescGraphNodeIndex = graphNodes.size(); - - if (isNodeAsOpDesc) + } + + // Algorithm: + // 1. Create constant nodes by iterating through operatorDmlGraph's input edges and keep a map of it, + // because there would be an intermediate edge from the constantNode and source of the intermediate edge + // should come before the destination. + // 2. Again iterate through operatorDmlGraph's input edges to create mainGraph's input and intermediate edges. + // 3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges. + // 4. Iterate through operatorDmlGraph's output edges to populate outputEdgeNameToDmlGraphNodeAndIndex + // 5. While performing step 2, 3, and 4, insert operatorDmlGraphNode to the mainDmlGraphNode list. + + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) { - // Can't populate graphNodes vector at this point, because operatorDesc may get modified later. 
- for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; + if (arg->Exists()) { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - } + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end() && + iter->second < isConstGpuGraphInputCount && + isConstGpuGraphInput[iter->second]) + { + DmlSerializedGraphNode constantNode = {}; + constantNode.Name = arg->Name(); + + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; + + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + { + constantInput = constantCpuGraphInputGetter(arg->Name()); + } - graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); - } - else - { - for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) - { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); - graphNodes.push_back(std::move(nodeInfo)); + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to + // the required alignment. + assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + minimumConstantSize); + + smallConstantData.push_back(std::make_unique(tensorData.size())); + std::transform(tensorData.begin(), tensorData.end(), smallConstantData.back().get(), [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {smallConstantData.back().get(), tensorData.size()}; + constantNode.Desc = constantData; + } + else + { + ConstantName constantFileName = {GetSanitizedFileName(arg->Name())}; + constantNode.Desc = constantFileName; + } + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {static_cast(dmlGraphNodes.size()), 0}; + dmlGraphNodes.push_back(constantNode); + } } } - // map operatorGraphInputEdge as either mainGraphInputEdge or mainGraphIntermediateEdge - for (auto& operatorGraphInputEdge : graphNodeCreateInfo.inputEdges) - { - // operatorGraphInputEdge.GraphInputIndex will be the ONNX input index. - const onnxruntime::NodeArg* arg = node.InputDefs()[operatorGraphInputEdge.GraphInputIndex]; + // Create a map between operatorGraphNodeIndex to dmlGraphNodeIndex. 
+ std::unordered_map operatorDmlGraphToDmlGraphNodeIndexMap; + // map operatorDmlGraphInputEdge as either mainDmlGraphInputEdge or mainDmlGraphIntermediateEdge + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) + { + // operatorDmlGraphInputEdge.GraphInputIndex will be the ONNX input index. + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; if (arg->Exists()) { - auto iter = nameToDmlFusedNodeInputIndex.find(arg->Name()); - uint32_t mainGraphNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphInputEdge.ToNodeIndex]; - - if (iter != nameToDmlFusedNodeInputIndex.end()) + uint32_t dmlGraphNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorDmlGraphInputEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end()) { - // This is a graph input - - const uint32_t dmlFusedNodeInputIndex = iter->second; - - // If this is a constant input, set the appropriate flags on the desc - if (isNodeAsOpDesc && - dmlFusedNodeInputIndex < isConstGpuGraphInputCount && - isConstGpuGraphInput[dmlFusedNodeInputIndex]) + const uint32_t subgraphInputIndex = iter->second; + + // Either this edge will be + // a constant input, then it will be an intermediate edge and + // set the OWNED_BY_DML flag if it is large constant + // or, + // a non-constant input, then it will be a mainDmlGraphInputEdge. + if (subgraphInputIndex < isConstGpuGraphInputCount && + isConstGpuGraphInput[subgraphInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently - // only used for small inputs. - uint32_t c_maxConstNodeDataSize = 8; - - - auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - ComPtr constantInput; - - if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) - { - constantInput = constantCpuGraphInputGetter(arg->Name()); - } - - if (constantInput) - { - // The tensor description's size should be no larger than the constant input unless it was rounded to - // the required alignment. 
- assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); - size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); - auto data = static_cast(constantInput->GetData()); - std::vector tensorData(data, data + minimumConstantSize); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(tensorData); - graphNodes.push_back(std::move(nodeInfo)); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = static_cast(graphNodes.size() - 1); - edge.FromNodeOutputIndex = 0; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); - } - else + const auto& constantNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); + auto& constantNodeVariant = std::get(dmlGraphNodes[constantNodeAndIndex.nodeIndex].Desc); + if (std::holds_alternative(constantNodeVariant)) { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); - + auto& mainDmlGraphNode = dmlGraphNodes[dmlGraphNodeIndex]; + AbstractOperatorDesc& abstractOperatorDesc = std::get(mainDmlGraphNode.Desc); + std::vector toNodeInputTensorDescs = abstractOperatorDesc.GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + serializedGraphLargeConstantNameToSubgraphInputIndex[arg->Name()] = subgraphInputIndex; } + + DmlIntermediateSerializedGraphEdge edge = {}; + edge.FromNodeIndex = constantNodeAndIndex.nodeIndex; + edge.FromNodeOutputIndex = constantNodeAndIndex.targetIndex; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name() + "-nodeIdx:" + std::to_string(edge.FromNodeIndex) + "-outputIdx:" + std::to_string(edge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } else { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); + DmlInputSerializedGraphEdge edge = {}; + if (subGraphInputIndexToDmlGraphInputIndex.find(subgraphInputIndex) == subGraphInputIndexToDmlGraphInputIndex.end()) + { + subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex] = static_cast(subGraphInputIndexToDmlGraphInputIndex.size()); + } + + edge.GraphInputIndex = subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex]; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; // ?? 
might need to point inputIndex + edge.Name = arg->Name(); + + serializedGraphInputIndexToSubgraphInputIndex[edge.GraphInputIndex] = subgraphInputIndex; + dmlGraphInputEdges.push_back(edge); } } else { - const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name()); + const auto& inputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + DmlIntermediateSerializedGraphEdge edge = {}; edge.FromNodeIndex = inputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = inputNodeAndIndex.targetIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name(); + dmlGraphIntermediateEdges.push_back(edge); } } } // map operatorGraphIntermediateEdges as mainGraphIntermediateEdge - for (auto& operatorGraphIntermediateEdge : graphNodeCreateInfo.intermediateEdges) + for (auto& operatorGraphIntermediateEdge : operatorDmlGraphCreateInfo.intermediateEdges) { - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.FromNodeIndex]; + DmlIntermediateSerializedGraphEdge edge = {}; + uint32_t shiftedFromNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + uint32_t shiftedToNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + edge.FromNodeIndex = shiftedFromNodeIndex; edge.FromNodeOutputIndex = operatorGraphIntermediateEdge.FromNodeOutputIndex; - edge.ToNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.ToNodeIndex]; + edge.ToNodeIndex = shiftedToNodeIndex; edge.ToNodeInputIndex = operatorGraphIntermediateEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } - + // populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges - for (auto& operatorGraphOutputEdge : graphNodeCreateInfo.outputEdges) + for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges) { const onnxruntime::NodeArg* arg = node.OutputDefs()[operatorGraphOutputEdge.GraphOutputIndex]; if (arg->Exists()) { - nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex { - operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], - operatorGraphOutputEdge.FromNodeOutputIndex - }; - + uint32_t shiftedNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphOutputEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphOutputEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex}; nodeOutputShapes[arg->Name()] = outputShapes; } } - - if (isNodeAsOpDesc) - { - for (size_t i = 0; i < 
graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) - { - auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; - - DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) - dmlDesc.Type = (DML_OPERATOR_TYPE) 169; - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) - dmlDesc.Type = (DML_OPERATOR_TYPE) 170; - - ComPtr op; - ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); - allocator.Reset(); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(op); - nodeInfo.name = node.Name(); - graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); - } - } } EdgeShapes graphOutputShapes(subgraphOutputs.size()); @@ -476,24 +505,27 @@ namespace Dml::GraphDescBuilder const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); - const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); + const auto& outputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(graphOutput->Name()); - DML_OUTPUT_GRAPH_EDGE_DESC edge = {}; + DmlOutputSerializedGraphEdge edge = {}; edge.FromNodeIndex = outputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); - graphOutputEdges.push_back(edge); + edge.Name = graphOutput->Name(); + dmlGraphOutputEdges.push_back(edge); graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } - RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); + RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges); GraphDesc graphDesc{}; - graphDesc.nodes = std::move(graphNodes); - graphDesc.inputEdges = std::move(graphInputEdges); - graphDesc.outputEdges = std::move(graphOutputEdges); - graphDesc.intermediateEdges = std::move(graphIntermediateEdges); - graphDesc.reuseCommandList = reuseCommandList; + graphDesc.InputCount = static_cast(dmlGraphInputEdges.size()); + graphDesc.OutputCount = static_cast(subgraphOutputs.size()); + graphDesc.Nodes = std::move(dmlGraphNodes); + graphDesc.InputEdges = std::move(dmlGraphInputEdges); + graphDesc.OutputEdges = std::move(dmlGraphOutputEdges); + graphDesc.IntermediateEdges = std::move(dmlGraphIntermediateEdges); + graphDesc.reuseCommandList = (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()); graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index c95e89b45541b..4055984b40405 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -22,22 +22,15 @@ namespace Dml namespace GraphDescBuilder { + constexpr uint32_t minNodeCountToReuseCommandList = 5; + constexpr uint32_t c_maxConstNodeDataSize = 8; + // Gets a unique name for the node which survives recreation and graph manipulations between the point // that graph partitioning occurs and kernel 
creation happens const std::string& GetUniqueNodeName(const onnxruntime::Node& node); - struct NodeInfo - { - std::variant, std::vector> nodeDef; - std::string name; - }; - - struct GraphDesc + struct GraphDesc : DmlSerializedGraphDesc { - std::vector nodes; - std::vector inputEdges; - std::vector outputEdges; - std::vector intermediateEdges; bool reuseCommandList; Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; @@ -47,11 +40,13 @@ namespace Dml const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs); + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index d524780de71b8..f29fbc7a1a65b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1508,31 +1508,17 @@ namespace Windows::AI::MachineLearning::Adapter ORT_TRY { assert(operatorGraphDesc != nullptr); - // Either nodesAsOpDesc or nodesIDMLOperator can be present. - assert(operatorGraphDesc->nodeCount == 0 || (!operatorGraphDesc->nodesAsOpDesc ^ !operatorGraphDesc->nodesAsIDMLOperator)); + assert(operatorGraphDesc->nodeCount == 0 || operatorGraphDesc->nodes); - if (operatorGraphDesc->nodesAsOpDesc) + m_graphNodeCreateInfo->nodes = std::vector>(); + for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) { - m_graphNodeCreateInfo->nodesAsOperatorDesc = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsOpDesc[nodeIndex]; - assert(node != nullptr); - AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); - m_graphNodeCreateInfo->nodesAsOperatorDesc.push_back(std::make_unique(std::move(abstractDesc))); - } - } - else - { - m_graphNodeCreateInfo->nodesAsIDMLOperator = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsIDMLOperator[nodeIndex]; - assert(node != nullptr); - m_graphNodeCreateInfo->nodesAsIDMLOperator.push_back(node); - } + auto* node = operatorGraphDesc->nodes[nodeIndex]; + assert(node != nullptr); + AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); + m_graphNodeCreateInfo->nodes.push_back(std::make_unique(std::move(abstractDesc))); } - + // There can be operators (or kernels) which don't require any input. 
assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr); m_graphNodeCreateInfo->inputEdges.insert( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index c3bb1a52210f5..287f1e5b6dfe7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -53,7 +53,7 @@ namespace Dml MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 1; const DML_OPERATOR_DESC* opDescs{&operatorDesc}; - operatorGraphDesc.nodesAsOpDesc = &opDescs; + operatorGraphDesc.nodes = &opDescs; std::vector inputEdges; for (uint32_t inputIndex = 0; inputIndex < m_kernelInputIndices.size(); inputIndex++) @@ -796,7 +796,7 @@ namespace Dml for (size_t i = 0; i < graphDesc.NodeCount; ++i) { // Create the operator. - ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i]))); + ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodes[i], IID_PPV_ARGS(&dmlOperators[i]))); dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()}; dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index c8ca6806e75f7..73c2d57e984af 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -531,7 +531,7 @@ class DmlOperatorAttention : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp index 1c851c94c4ddc..5aceebbdabfe3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp @@ -103,7 +103,7 @@ class DmlOperatorBiasAdd : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp index 501ce14f1fc08..1e10214ffd463 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp +++ 
b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp @@ -137,7 +137,7 @@ class DmlOperatorBiasSplitGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp index 6a8333cd72561..3c9458658c4d0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp @@ -484,7 +484,7 @@ class DmlOperatorEmbedLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp index fed0e4645ffd8..8b275fc550f3e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp @@ -287,7 +287,7 @@ class DmlOperatorGroupNorm : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 5c64059f7caa9..80e6fefc2fb80 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -247,7 +247,7 @@ class DmlOperatorLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp index c97b03dc36b62..8727610ff3112 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -166,7 +166,7 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = static_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index 35f926d62c92a..f658e7c7da323 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -113,7 +113,7 @@ class DmlOperatorQLinearSigmoid : public DmlOperator MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 3; std::vector opDescs{&opDesc1, &opDesc2, &opDesc3}; - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); // set input edges std::pair nodeToNodeInputIndex[5] {{0, 0}, {0, 1}, {0, 2}, {2, 1}, {2, 2}}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp index 3683ab7b0b0b3..e62b7d707ba78 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp @@ -123,7 +123,7 @@ class DmlOperatorQuickGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 44004b5d77f70..0f15ebf342b3a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -441,7 +441,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp index 4dafd78f21ea8..094c45a0e38e5 100644 --- 
a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp @@ -198,7 +198,7 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h new file mode 100644 index 0000000000000..02166f992449e --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include + + +namespace Dml +{ + static inline std::wstring ConvertToWString(std::string_view str) + { + std::wstring_convert,wchar_t> g_converterToUtf16; + return g_converterToUtf16.from_bytes(str.data()); + } + + static inline std::wstring GetModelName(const onnxruntime::Path& modelPath) + { + if (modelPath.GetComponents().empty()) + { + return L""; + } + + const onnxruntime::PathString& pathString = modelPath.GetComponents().back(); + size_t dotPosition = pathString.find_last_of('.'); + if (dotPosition == std::string::npos) + { + return L""; + } + + return pathString.substr(0, dotPosition); + } + + static inline std::wstring GetSanitizedFileName(std::wstring_view name) + { + std::wstring newName(name); + for (wchar_t& c : newName) + { + switch (c) + { + case '\\': + case '/': + case '\"': + case '|': + case '<': + case '>': + case ':': + case '?': + case '*': + c = '_'; + break; + } + } + return newName; + } + + static inline std::string GetSanitizedFileName(std::string_view name) + { + std::string newName(name); + for (char& c : newName) + { + switch (c) + { + case '\\': + case '/': + case '\"': + case '|': + case '<': + case '>': + case ':': + case '?': + case '*': + c = '_'; + break; + } + } + return newName; + } + + static inline void WriteToFile(std::wstring_view directoryName, std::wstring_view fileName, std::uint8_t* data, size_t dataSize) + { + std::wstring sanitizedFileName = GetSanitizedFileName(fileName); + std::filesystem::create_directory(directoryName); + std::wstring fullSanitizedFileName = std::wstring(directoryName) + + (directoryName.empty() ? L"" : L"/") + + sanitizedFileName; + std::ofstream file(fullSanitizedFileName, std::ios::binary); + if (!file.is_open()) + { + std::wstring_convert,wchar_t> g_converterToUtf16; + std::stringstream errorMessage; + errorMessage << "File named: " << g_converterToUtf16.to_bytes(fileName.data()) << " could not be opened\n"; + throw std::ios::failure(errorMessage.str()); + } + file.write(reinterpret_cast(data), dataSize); + } + +} + +namespace StringUtil +{ + struct NameAndIndex + { + const char* name; // Null terminated. + uint32_t index; + }; + + struct WideNameAndIndex + { + const wchar_t* name; // Null terminated. 
+ uint32_t index; + }; + + inline std::optional MapToIndex(std::string_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (strncmp(nameAndIndex.name, mode.data(), mode.size()) == 0) + { + return nameAndIndex.index; + } + } + + return {}; + } + + inline std::optional MapToIndex(std::wstring_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (wcsncmp(nameAndIndex.name, mode.data(), mode.size()) == 0) + { + return nameAndIndex.index; + } + } + + return {}; + } +} \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h index 83737d2ba4848..332bf86685e8a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include @@ -37,6 +39,7 @@ #include #include "External/D3DX12/d3dx12.h" #endif +#include "flatbuffers/flatbuffers.h" #include "GraphicsUnknownHelper.h" @@ -53,6 +56,9 @@ #include "External/DirectMLHelpers/SchemaHelpers.h" #include "External/DirectMLHelpers/GeneratedSchemaHelpers.h" #include "External/DirectMLHelpers/DirectMLX.h" +#include "External/DirectMLHelpers/DmlSerializedGraphDesc.h" +#include "External/DirectMLHelpers/DmlGraphSerialization.h" +#include "External/DirectMLHelpers/DmlGraphDeserialization.h" using Microsoft::WRL::ComPtr; @@ -67,3 +73,4 @@ using Microsoft::WRL::ComPtr; #include "TensorDesc.h" #include "DescriptorPool.h" #include "IExecutionProvider.h" +#include "Utility.h" \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 3bec8d3864cba..ac3a3eb1268b8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -10,18 +10,11 @@ struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; -// Either nodesAsOpDesc or nodesAsIDMLOperator is present. -// 1) Operator kernels which implement operators using only a single DML operator will pass a DML_OPERATOR_DESC. -// These kernels pass DML_OPERATOR_DESC, because while building Dml graph (inside FusedGraphKernel.cpp) we can change the -// the flag of constant inputs to DML_TENSOR_FLAG_OWNED_BY_DML. -// 2) Operator kernels which implement operators using DMLX graph, they will pass IDMLOperator and won't be able -// to use DML_TENSOR_FLAG_OWNED_BY_DML. struct MLOperatorGraphDesc { uint32_t nodeCount; - _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodesAsOpDesc; - _Field_size_opt_(nodeCount) IDMLOperator** nodesAsIDMLOperator; - + _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodes; + uint32_t inputEdgeCount; _Field_size_(inputEdgeCount) const DML_INPUT_GRAPH_EDGE_DESC* inputEdges; diff --git a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h index d11fa7516e713..5b5f371f51616 100644 --- a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h +++ b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h @@ -21,3 +21,4 @@ // "1": disabled (disallowed). Graph fusion will never be used. 
// The default value is "0" static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml.disable_graph_fusion"; +static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization"; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index efd7db4ea7629..5fd66c459d382 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1725,10 +1725,17 @@ common::Status InferenceSession::Initialize() { // graph optimization level and is generally always applied. bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() && session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableDmlGraphFusion, "0") == "0"; + std::string dml_graph_serialization_enabled_config_val = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphSerialization, "0"); + std::transform(dml_graph_serialization_enabled_config_val.begin(), + dml_graph_serialization_enabled_config_val.end(), + dml_graph_serialization_enabled_config_val.begin(), + [](char ch) { return std::tolower(ch); }); + bool dml_graph_serialization_enabled = dml_graph_serialization_enabled_config_val == "true"; if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - dmlExecutionProvider); + dmlExecutionProvider, + dml_graph_serialization_enabled); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 3874901f86387..7d4111e3b9c39 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -68,6 +68,7 @@ namespace perftest { "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" + "\t [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 87506c7240578..1934314b8ce43 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -18,6 +18,7 @@ #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" +#include "core/providers/dml/dml_session_options_config_keys.h" #endif #ifdef _WIN32 @@ -542,6 +543,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. 
" "Select from 'true' or 'false' \n"); } + } else if (key == "enable_graph_serialization") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + session_options.AddConfigEntry(kOrtSessionOptionsConfigEnableGraphSerialization, value.data()); + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'enable_graph_serialization'. " + "Select from 'true' or 'false' \n"); + } } } session_options.AppendExecutionProvider("DML", dml_options); From 8bd943be39301639e3f50f524f8fd71c7f2b2a34 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 27 Feb 2024 09:31:32 +1000 Subject: [PATCH 148/207] Retry flaky XCode iOS UI tests if we get a known error (#19639) ### Description Xcode UI tests seem to be flaky: https://github.com/orgs/community/discussions/68807 Add a couple of retries if we get a "Timed out while loading Accessibility." error which is transient. ### Motivation and Context --- .../github/apple/test_apple_packages.py | 61 ++++++++++++++----- 1 file changed, 45 insertions(+), 16 deletions(-) diff --git a/tools/ci_build/github/apple/test_apple_packages.py b/tools/ci_build/github/apple/test_apple_packages.py index cd360a63a3a0f..3c0df994ffd3d 100644 --- a/tools/ci_build/github/apple/test_apple_packages.py +++ b/tools/ci_build/github/apple/test_apple_packages.py @@ -130,22 +130,51 @@ def _test_apple_packages(args): simulator_device_info = json.loads(simulator_device_info) - subprocess.run( - [ - "xcrun", - "xcodebuild", - "test", - "-workspace", - "./apple_package_test.xcworkspace", - "-scheme", - "ios_package_test", - "-destination", - f"platform=iOS Simulator,id={simulator_device_info['device_udid']}", - ], - shell=False, - check=True, - cwd=target_proj_path, - ) + # Xcode UI tests seem to be flaky: https://github.com/orgs/community/discussions/68807 + # Add a couple of retries if we get this error: + # ios_package_testUITests-Runner Failed to initialize for UI testing: + # Error Domain=com.apple.dt.XCTest.XCTFuture Code=1000 "Timed out while loading Accessibility." + attempts = 0 + cmd = [ + "xcrun", + "xcodebuild", + "test", + "-workspace", + "./apple_package_test.xcworkspace", + "-scheme", + "ios_package_test", + "-destination", + f"platform=iOS Simulator,id={simulator_device_info['device_udid']}", + ] + + while True: + attempts += 1 + completed_process = subprocess.run( + cmd, + shell=False, + capture_output=True, + check=False, + text=True, + cwd=target_proj_path, + ) + + # print so it's in CI output + print(completed_process.stdout) + + if completed_process.returncode != 0: + print(f"Running ios_package_test failed. Return code was {completed_process.returncode}") + print("xcrun xcodebuild test stderr:") + print(completed_process.stderr) + print("---") + + if "Timed out while loading Accessibility" in completed_process.stderr and attempts < 3: + continue + + raise subprocess.CalledProcessError( + completed_process.returncode, " ".join(cmd), completed_process.stdout, completed_process.stderr + ) + + break if PackageVariant[args.variant] != PackageVariant.Mobile and not args.skip_macos_test: subprocess.run( From 18c8fab1ae03e68a906fe42698ac322e9e49e218 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Mon, 26 Feb 2024 15:58:09 -0800 Subject: [PATCH 149/207] Fix a bug in build.py (#19652) ### Description Fix a bug in build.py that accidentally disabled C# tests for most builds when "--build_nuget" is specified. ### Motivation and Context The bug was introduced in PR #8892 . 
--- tools/ci_build/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 5b715bb29e5a1..74c473d34f548 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2592,7 +2592,7 @@ def main(): raise BuildError("Using --get-api-doc requires a single build config") # Disabling unit tests for GPU on nuget creation - if args.use_openvino != "CPU_FP32" and args.build_nuget: + if args.use_openvino and args.use_openvino != "CPU_FP32" and args.build_nuget: args.test = False # GDK builds don't support testing From 8a71b657654d63437267014b324bf124a80de347 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Tue, 27 Feb 2024 11:35:27 +1000 Subject: [PATCH 150/207] Remove skipping of Reshape from NNAPI EP (#19618) ### Description A number of Qualcomm Snapdragon chipsets do not produce correct output if we skip the Reshape, which ironically was a performance optimization for Snapdragon chips. Perf testing showed that Squeeze also seems to execute on CPU so there's no benefit to using that as an alternative where possible e.g. Global*Pool -> Reshape to 2D -> Gemm could be potentially be replaced with Global*Pool -> Squeeze dims 2 and 3 -> Gemm if that offered better performance. ### Motivation and Context #19518 --- .../builders/op_builder_helpers.cc | 30 ++++++++++++++----- .../builders/op_builder_helpers.h | 3 -- 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc index a066c64dac67d..466865f23f49a 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.cc @@ -965,6 +965,18 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, return Status::OK(); } +// NOTE: Skipping Reshape results in invalid output on some SnapDragon chipsets. Whilst the NNAPI spec says the input +// to FullyConnnected can be > 2D, those chipsets don't handle this correctly. +// +// CanSkipReshape could potentially be re-enabled in the future if we no longer want to support those old chipsets. +// However, the Reshape of newer chipsets may not run on CPU so there may not be a performance issue to try and avoid, +// so CanSkipReshape could be redundant anyway. +// +// Known bad chipsets: Qualcomm Snapdragon 850, 855, 865, 870. +// +// See https://github.com/microsoft/onnxruntime/issues/19518 + +/* // We can skip the Reshape if all the output edges satisfies both the following conditions // 1. The output of the reshape/flatten is not an output of the graph // 2. The output of the reshape/flatten is the input 0 of one or more GEMM/Matmul operators, @@ -977,7 +989,7 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, // between NNAPI CPU impl and Hardware Accelerator impl and will speed up the execution // If we are going to skip the reshape, we will still add correct shape and operand type for the output in // onnxruntime::nnapi::Model. -bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, +static bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, size_t input_rank, size_t output_rank) { // Since we know this is a Reshape NodeUnit, so we can safely assume there is only 1 output // and the node_unit has only one output node. 
@@ -1039,33 +1051,37 @@ bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit << node_unit.Name() << "] with output, " << output_name; return true; } +*/ Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, const std::string& input, const std::vector& shape) { auto& shaper(model_builder.GetShaper()); - const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); const auto& output = node_unit.Outputs()[0].node_arg.Name(); const auto input_shape = shaper[input]; const auto output_shape = shaper[output]; - const auto input_rank = input_shape.size(); - const auto output_rank = output_shape.size(); // For reshape, the output type should be the same as the input type except the shape is different auto output_operand_type = operand_types.at(input); output_operand_type.SetDimensions(output_shape); + /* See CanSkipReshape definition above for explanation of why this is disabled. // Since Reshape is not running using hardware in NNAPI for some CPU (e.g. Qualcomm SD for now) // We will try to see if we the skip the Reshape to prevent context switching between // NNAPI CPU impl and NNAPI hardware accelerator impl if (CanSkipReshape(model_builder, node_unit, input_rank, output_rank)) { - // Since reshape can be skipped, only register the dimension and type, with same index and new name + const auto& operand_indices(model_builder.GetOperandIndices()); + const auto input_rank = input_shape.size(); + const auto output_rank = output_shape.size(); + // Since reshape can be skipped, only register the dimension and type, with same index and new name. + // This essentially redirects the downstream operator builders to the input of the skipped Reshape node, + // but with the output shape of the Reshape node. model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); - } else { - // We still need to perform a reshape here + } else */ + { std::string shape_name = model_builder.GetUniqueName(node_unit.Name() + input + "newshape"); ORT_RETURN_IF_ERROR(op_builder_helpers::AddNnapiReshape(model_builder, input, shape_name, shape, output)); } diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h index 7ccf4c1ef7555..61a16ceff752f 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_builder_helpers.h @@ -181,9 +181,6 @@ Status AddMinMaxOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, Status AddReshapeOperator(ModelBuilder& model_builder, const NodeUnit& node_unit, const std::string& input, const std::vector& shape); -bool CanSkipReshape(const ModelBuilder& model_builder, const NodeUnit& node_unit, - size_t input_rank, size_t output_rank); - Status GetAxesForSqueezeAndUnSqueeze(ModelBuilder& model_builder, const NodeUnit& node_unit, std::vector& axes); From 6f566562cedff9996e55dbf623b1f0141733d52c Mon Sep 17 00:00:00 2001 From: kailums <109063327+kailums@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:31:03 +0800 Subject: [PATCH 151/207] support user_compute_stream for rocm ep (#19619) ### Description According to the pr #19229 supporting cuda EP use external compute stream, we add support for rocm EP. 
And when we testing this feature with torch, we found torch use stream 0 for the default stream, and `torch.cuda.current_stream()` returns `0` for current stream, but ort treat `0` or `nullptr` as invalid, and reset has_user_compute_stream to false. Will remove has_user_compute_stream option in the future. ### Motivation and Context The motivation for this pr is that we want to use torch.cuda.graph to capture ort running kernel, which requires torch and ort are running in the same stream, so we use this API to set ort's working stream. --- .../rocm/rocm_execution_provider_info.cc | 20 +++++++++++++++++++ .../test/python/onnxruntime_test_python.py | 10 ++++++++++ 2 files changed, 30 insertions(+) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index b557f92287f2b..3cb826437a54f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -13,6 +13,8 @@ namespace onnxruntime { namespace rocm { namespace provider_option_names { constexpr const char* kDeviceId = "device_id"; +constexpr const char* kHasUserComputeStream = "has_user_compute_stream"; +constexpr const char* kUserComputeStream = "user_compute_stream"; constexpr const char* kMemLimit = "gpu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kMiopenConvExhaustiveSearch = "miopen_conv_exhaustive_search"; @@ -38,6 +40,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P void* alloc = nullptr; void* free = nullptr; void* empty_cache = nullptr; + void* user_compute_stream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -52,6 +55,15 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); return Status::OK(); }) + .AddAssignmentToReference(rocm::provider_option_names::kHasUserComputeStream, info.has_user_compute_stream) + .AddValueParser( + rocm::provider_option_names::kUserComputeStream, + [&user_compute_stream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + user_compute_stream = reinterpret_cast(address); + return Status::OK(); + }) .AddValueParser( rocm::provider_option_names::kGpuExternalAlloc, [&alloc](const std::string& value_str) -> Status { @@ -108,12 +120,18 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P ROCMExecutionProviderExternalAllocatorInfo alloc_info{alloc, free, empty_cache}; info.external_allocator_info = alloc_info; + + info.user_compute_stream = user_compute_stream; + info.has_user_compute_stream = (user_compute_stream != nullptr); + return info; } ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecutionProviderInfo& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kGpuExternalAlloc, 
MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.alloc))}, {rocm::provider_option_names::kGpuExternalFree, MakeStringWithClassicLocale(reinterpret_cast(info.external_allocator_info.free))}, @@ -135,6 +153,8 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const OrtROCMProviderOptions& info) { const ProviderOptions options{ {rocm::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {rocm::provider_option_names::kHasUserComputeStream, MakeStringWithClassicLocale(info.has_user_compute_stream)}, + {rocm::provider_option_names::kUserComputeStream, MakeStringWithClassicLocale(reinterpret_cast(info.user_compute_stream))}, {rocm::provider_option_names::kMemLimit, MakeStringWithClassicLocale(info.gpu_mem_limit)}, {rocm::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, static_cast(info.arena_extend_strategy))}, {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 91b6c71e735a8..ab56f3fa0f37f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -559,6 +559,16 @@ def test_get_and_set_option_with_values(option_name, option_values): test_get_and_set_option_with_values("enable_hip_graph", ["1", "0"]) + # test for user_compute_stream + option = options["ROCMExecutionProvider"] + option["user_compute_stream"] = "1" + sess.set_providers(["ROCMExecutionProvider"], [option]) + new_options = sess.get_provider_options() + new_option = new_options["ROCMExecutionProvider"] + self.assertEqual(new_option["user_compute_stream"], "1") + # set user_compute_stream will set has_user_compute_stream to 1 too + self.assertEqual(new_option["has_user_compute_stream"], "1") + run_rocm_options_test() def test_invalid_set_providers(self): From 5bb58a10e739f8720e9867d19c4313081b12d948 Mon Sep 17 00:00:00 2001 From: Rachel Guo <35738743+YUNQIUGUO@users.noreply.github.com> Date: Mon, 26 Feb 2024 20:00:14 -0800 Subject: [PATCH 152/207] Enable the most verbose logging level in detox E2E React Native CI (#19659) ### Description The RN CI has intermittent failure error with "app seems to idle". enable the most verbose logging level (and can add steps to dump device.log from the detox folder/artifacts if necessary) to at least get more information. 
### Motivation and Context --------- Co-authored-by: rachguo --- .../github/azure-pipelines/templates/react-native-ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 47cd72f412c67..1b7962059e301 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -279,7 +279,7 @@ stages: - script: | JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/android-test-results.xml \ - detox test --record-logs all --configuration android.emu.release + detox test --record-logs all --configuration android.emu.release --loglevel trace workingDirectory: '$(Build.SourcesDirectory)/js/react_native/e2e' displayName: Run React Native Detox Android e2e Tests @@ -329,7 +329,7 @@ stages: - script: | JEST_JUNIT_OUTPUT_FILE=$(Build.SourcesDirectory)/js/react_native/e2e/ios-test-results.xml \ - detox test --record-logs all --configuration ios.sim.release + detox test --record-logs all --configuration ios.sim.release --loglevel trace workingDirectory: '$(Build.SourcesDirectory)/js/react_native/e2e' displayName: Run React Native Detox iOS e2e Tests From 9e19684944adfda4a414fc91a67259894fce2898 Mon Sep 17 00:00:00 2001 From: duanshengliu <44742794+duanshengliu@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:56:32 +0800 Subject: [PATCH 153/207] Fix the TypeError issue in quantize.py (#19459) ### Description Fix related bug as described in https://github.com/microsoft/onnxruntime/issues/19430 --- onnxruntime/python/tools/quantization/quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 1bd2ef42151d0..05d3ac248c92c 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -479,7 +479,7 @@ def inc_dataloader(): del dataloader model = sq.transform(extra_options.get("SmoothQuantAlpha", 0.5), extra_options.get("SmoothQuantFolding", True)) sq_path = tempfile.TemporaryDirectory(prefix="ort.quant.") - model_input = Path(sq_path).joinpath("sq_model.onnx").as_posix() + model_input = Path(sq_path.name).joinpath("sq_model.onnx").as_posix() model.save(model_input) nodes_to_exclude.extend([i.name for i in model.model.graph.node if i.name not in orig_nodes]) model = load_model_with_shape_infer(Path(model_input)) # use smooth quant model for calibration From 1e69b612382205b0588f08d2b808b12e32a50a51 Mon Sep 17 00:00:00 2001 From: cloudhan Date: Tue, 27 Feb 2024 16:06:06 +0800 Subject: [PATCH 154/207] Make version string detection more robust (#19615) `/opt/rocm/.info/version-dev` is only available if the `rocm-dev` metapackage is installed. This will bring a lot of unused packages which are not needed by the users, they may opt for fine grained control. Fallback to `rocm_version.h` in case `rocm-dev` is not installed. 
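As a rough sketch of the fallback order this patch introduces (paths and patterns here are illustrative assumptions; the authoritative logic is the CMake change in the diff below), detection now probes `.info/version-dev` first and falls back to the `rocm_version.h` headers:

```python
# Sketch of the ROCm version probing order; not the shipped CMake logic.
import re
from pathlib import Path

def detect_rocm_version(rocm_home: str) -> str:
    candidates = [
        (Path(rocm_home, ".info", "version-dev"), r"^(\d+)\.(\d+)\.(\d+)-"),
        (Path(rocm_home, "include", "rocm_version.h"), r"\"(\d+)\.(\d+)\.(\d+)"),
        (Path(rocm_home, "include", "rocm-core", "rocm_version.h"), r"\"(\d+)\.(\d+)\.(\d+)"),
    ]
    for path, pattern in candidates:
        if path.exists():
            match = re.search(pattern, path.read_text())
            if match:
                return ".".join(match.groups())
    raise RuntimeError("Cannot determine ROCm version string")

# Example: detect_rocm_version("/opt/rocm") might return "6.0.0" on a typical install.
```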
--- cmake/CMakeLists.txt | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ed9043f2adc4a..1376c90fbcefe 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -324,15 +324,27 @@ if (onnxruntime_USE_ROCM) endif() # replicate strategy used by pytorch to get ROCM_VERSION - # https://github.com/pytorch/pytorch/blob/8eb21488fdcdb8b0e6fa2e46179b5fa6c42e75af/cmake/public/LoadHIP.cmake#L153-L173 - file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - if (ROCM_VERSION_DEV_MATCH) + # https://github.com/pytorch/pytorch/blob/5c5b71b6eebae76d744261715231093e62f0d090/cmake/public/LoadHIP.cmake + # with modification + if (EXISTS "${onnxruntime_ROCM_HOME}/.info/version-dev") + file(READ "${onnxruntime_ROCM_HOME}/.info/version-dev" ROCM_VERSION_DEV_RAW) + string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+)-.*$" ROCM_VERSION_MATCH ${ROCM_VERSION_DEV_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + elseif (EXISTS "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h") + file(READ "${onnxruntime_ROCM_HOME}/include/rocm-core/rocm_version.h" ROCM_VERSION_H_RAW) + string(REGEX MATCH "\"([0-9]+)\.([0-9]+)\.([0-9]+).*\"" ROCM_VERSION_MATCH ${ROCM_VERSION_H_RAW}) + endif() + + if (ROCM_VERSION_MATCH) set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + else() + message(FATAL_ERROR "Cannot determine ROCm version string") endif() message("\n***** ROCm version from ${onnxruntime_ROCM_HOME}/.info/version-dev ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") From 4838cb6b3e98273fcdd6a3e54da74cd584167780 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Tue, 27 Feb 2024 02:27:35 -0800 Subject: [PATCH 155/207] [QNN Quantization] Ensure fused nodes have names (#19650) ### Description - Updates the `qnn_preprocess_model()` method to set a name for any new nodes added to the graph (due to fusion). - Updates the `qnn_preprocess_model()` method to set a name for any unnamed nodes that previously existed in the original graph. - Adds unit tests for fusions (previously missing) - Checks that fused node names exist and are unique - Checks that fused graph is equivalent to original graph ### Motivation and Context Nodes are not strictly required to have names. However, a planned/upcoming feature to support mixed-precision (integer) quantized models needs nodes to have names. 
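A condensed sketch of the naming scheme described above (the actual implementation, `get_largest_node_name_suffix` plus the preprocessing loop, is in the diff that follows): find the largest numeric suffix already used with a given prefix, then hand out fresh suffixes so generated names never collide with existing ones.

```python
# Sketch of the prefix+suffix unique node naming; mirrors the helpers added in the diff below.
import onnx

def assign_missing_node_names(graph: onnx.GraphProto, prefix: str = "qnn_preproc_node_") -> None:
    # Find the largest integer suffix already used with this prefix.
    largest = -1
    for node in graph.node:
        if node.name.startswith(prefix):
            try:
                largest = max(largest, int(node.name[len(prefix):]))
            except ValueError:
                continue
    # Name any unnamed (non-Constant) nodes with the next free suffixes.
    next_suffix = largest + 1
    for node in graph.node:
        if node.op_type != "Constant" and not node.name:
            node.name = f"{prefix}{next_suffix}"
            next_suffix += 1
```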
--- .../execution_providers/qnn/fusion_lpnorm.py | 7 +- .../execution_providers/qnn/preprocess.py | 11 + .../tools/quantization/fusions/fusion.py | 15 + .../tools/quantization/fusions/fusion_gelu.py | 25 +- .../quantization/fusions/fusion_layernorm.py | 1 + .../python/tools/quantization/onnx_model.py | 17 + .../test/python/quantization/test_fusions.py | 401 ++++++++++++++++++ 7 files changed, 465 insertions(+), 12 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_fusions.py diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py index 9ebf400498e0e..fbf954febdda4 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/fusion_lpnorm.py @@ -122,6 +122,11 @@ def fuse( self.nodes_to_remove.extend(subgraph_nodes) fused_node = onnx.helper.make_node( - self.fused_op_type, inputs=[subgraph_input], outputs=[subgraph_output], p=2, axis=-1 + self.fused_op_type, + name=self.create_unique_node_name(), + inputs=[subgraph_input], + outputs=[subgraph_output], + p=2, + axis=-1, ) self.nodes_to_add.append(fused_node) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index becbaceab184e..b1c114fe1f9fd 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -44,6 +44,17 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if fusion_layernorm.apply(): modified = True + # Make sure all nodes have a name. + unnamed_node_prefix = "qnn_preproc_node_" + available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1 + for node in onnx_model.model.graph.node: + if node.op_type != "Constant" and not node.name: + new_node_name = f"{unnamed_node_prefix}{available_suffix!s}" + available_suffix += 1 + node.name = new_node_name + modified = True + logging.warning(f"Node of type {node.op_type} does not have a name. Renamed to {new_node_name}.") + if modified: onnx_model.topological_sort() onnx.save_model(model, model_output) diff --git a/onnxruntime/python/tools/quantization/fusions/fusion.py b/onnxruntime/python/tools/quantization/fusions/fusion.py index b54b421226f1a..4bdc5c26cc946 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion.py @@ -24,6 +24,9 @@ def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str): self.nodes_to_remove: list = [] self.nodes_to_add: list = [] + self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_" + self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops. 
+ def fuse( self, node: onnx.NodeProto, @@ -57,6 +60,18 @@ def apply(self) -> bool: return graph_updated + def create_unique_node_name(self): + prefix = self._new_node_name_prefix + + if self._new_node_name_suffix is None: + largest_suffix: int = self.model.get_largest_node_name_suffix(prefix) + self._new_node_name_suffix = largest_suffix + 1 + + new_name = f"{prefix}{self._new_node_name_suffix!s}" + self._new_node_name_suffix += 1 + + return new_name + @staticmethod def is_safe_to_fuse_nodes( nodes_to_remove: list[onnx.NodeProto], diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py index a20d6dbffd7a7..42c4a11833641 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_gelu.py @@ -112,7 +112,9 @@ def fuse_1( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -173,11 +175,9 @@ def fuse_2( if not self.has_constant_input(sqrt_node, 2.0): return False - root_node = self.model.get_parent(div, 0, output_name_to_node) - if root_node is None: - return False + subgraph_input = div.input[0] - if root_node.output[0] not in mul.input: + if subgraph_input not in mul.input: return False subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul] @@ -188,7 +188,9 @@ def fuse_2( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True @@ -239,9 +241,8 @@ def fuse_3( if i < 0: return False - root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node) - if root_node is None: - return False + root_input_index = 1 - i + subgraph_input = first_mul.input[root_input_index] if mul_half.output[0] not in input_name_to_nodes: return False @@ -250,7 +251,7 @@ def fuse_3( return False last_mul = children[0] - if not (last_mul.input[0] == root_node.output[0] or last_mul.input[1] == root_node.output[0]): + if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input): return False subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] @@ -263,7 +264,9 @@ def fuse_3( return False self.nodes_to_remove.extend(subgraph_nodes) - fused_node = onnx.helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = onnx.helper.make_node( + "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]] + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) return True diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py index d7fb89236d3d2..7d58c1c180822 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -127,6 +127,7 @@ def fuse( normalize_node = onnx.helper.make_node( 
"LayerNormalization", + name=self.create_unique_node_name(), inputs=[reduce_mean_node.input[0], weight_input, bias_input], outputs=[last_add_node.output[0]], ) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index 4591c9c950e6e..46d245d353a07 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -283,6 +283,23 @@ def find_node_by_name(self, node_name, new_nodes_list, graph): node = find_by_name(node_name, graph_nodes_list) return node + def get_largest_node_name_suffix(self, node_name_prefix): + """ + Gets the largest node name (int) suffix for all node names that begin with `node_name_prefix`. + Example: for nodes my_prefix_0 and my_prefix_3, this method returns 3. + """ + suffix = -1 + + for node in self.model.graph.node: + if node.name and node.name.startswith(node_name_prefix): + try: + index = int(node.name[len(node_name_prefix) :]) + suffix = max(index, suffix) + except ValueError: + continue + + return suffix + def find_nodes_by_initializer(self, graph, initializer): """ Find all nodes with given initializer as an input. diff --git a/onnxruntime/test/python/quantization/test_fusions.py b/onnxruntime/test/python/quantization/test_fusions.py new file mode 100644 index 0000000000000..bea110e566fb9 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_fusions.py @@ -0,0 +1,401 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest + +import numpy as np +import onnx + +import onnxruntime +from onnxruntime.quantization.execution_providers.qnn.fusion_lpnorm import FusionLpNormalization +from onnxruntime.quantization.fusions import FusionGelu, FusionLayerNormalization +from onnxruntime.quantization.onnx_model import ONNXModel + + +class TestFusions(unittest.TestCase): + def check_fused_model_correctness(self, orig_model, fused_model, inputs, rtol=1e-7, atol=0): + """ + Checks that the output of the fused model matches the output of the original model. + """ + orig_session = onnxruntime.InferenceSession(orig_model.SerializeToString(), providers=["CPUExecutionProvider"]) + orig_results = orig_session.run(None, inputs) + + fused_session = onnxruntime.InferenceSession( + fused_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + fused_results = fused_session.run([], inputs) + + self.assertEqual(len(orig_results), len(fused_results), "Number of outputs for fused model differs") + for idx, expected_output in enumerate(orig_results): + actual_output = fused_results[idx] + np.testing.assert_allclose( + expected_output, + actual_output, + rtol=rtol, + atol=atol, + err_msg=f"Fused model output {idx} differs", + ) + + def build_erf_sequence_1_model(self, shape): + """ + Erf sequence that fuses into Gelu: + +-------Mul(0.5)---------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul --> + (B=1.4142...) 
(1) + + This method builds 2 of these Erf sequences: + + [root] -> ERF_SEQUENCE1 -> ERF_SEQUENCE2 -> output + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + # First Erf sequence + mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["mul0_out"]) + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "mul0_out"], ["seq1_output"]) + + # Second Erf sequence + mul0_node_dup = onnx.helper.make_node("Mul", ["seq1_output", "half_const"], ["mul0_out_dup"]) + div_node_dup = onnx.helper.make_node("Div", ["seq1_output", "root2_const"], ["div_out_dup"]) + erf_node_dup = onnx.helper.make_node("Erf", ["div_out_dup"], ["erf_out_dup"]) + add_node_dup = onnx.helper.make_node("Add", ["erf_out_dup", "one_const"], ["add_out_dup"]) + mul1_node_dup = onnx.helper.make_node("Mul", ["add_out_dup", "mul0_out_dup"], ["output"]) + + graph = onnx.helper.make_graph( + [ + mul0_node, + div_node, + erf_node, + add_node, + mul1_node, + mul0_node_dup, + div_node_dup, + erf_node_dup, + add_node_dup, + mul1_node_dup, + ], + "two_erf_sequences", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_2_model(self, shape): + """ + +------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul --> + (B=1.4142...) 
(1) (0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "root"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "half_const"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_2", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_3_model(self, shape): + """ + +------------------------------------------+ + | | + | v + [root] --> Div -----> Erf --> Add --> Mul -->Mul + (B=1.4142...) (A=1) (A=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["div_out"]) + erf_node = onnx.helper.make_node("Erf", ["div_out"], ["erf_out"]) + add_node = onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul0_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul0_out"]) + mul1_node = onnx.helper.make_node("Mul", ["mul0_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [div_node, erf_node, add_node, mul0_node, mul1_node], + "erf_sequence_3", + [root_inp], + [output], + initializer=[one_const, half_const, root2_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_erf_sequence_4_model(self, shape): + """ + +----------------------------------------------+ + | | + | v + [root] --> Mul -----> Erf --> Add --> Mul -->Mul + (A=0.7071067690849304) (B=1) (B=0.5) + + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + frac_const = onnx.numpy_helper.from_array(np.array(0.7071067690849304, dtype=np.float32), "frac_const") + + mul0_node = onnx.helper.make_node("Mul", ["root", "frac_const"], ["mul0_out"]) + erf_node = onnx.helper.make_node("Erf", ["mul0_out"], ["erf_out"]) + add_node 
= onnx.helper.make_node("Add", ["erf_out", "one_const"], ["add_out"]) + mul1_node = onnx.helper.make_node("Mul", ["add_out", "half_const"], ["mul1_out"]) + mul2_node = onnx.helper.make_node("Mul", ["mul1_out", "root"], ["output"]) + + graph = onnx.helper.make_graph( + [mul0_node, erf_node, add_node, mul1_node, mul2_node], + "erf_sequence_4", + [root_inp], + [output], + initializer=[one_const, half_const, frac_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + onnx.helper.make_opsetid("com.microsoft", 1), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_mean_sequence_model(self, shape, scale_val, bias_val, axis=-1): + """ + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (Y=2) (axis=2 or -1) (E-6 or E-12 or 0) ^ ^ ^ + | | | | + +-------------------------------------------------+ [Scale] [Bias] + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + eps_const = onnx.numpy_helper.from_array(np.array(1.0e-8, dtype=np.float32), "eps_const") + + rm0_node = onnx.helper.make_node("ReduceMean", ["root", "axes_const"], ["rm0_out"]) + sub_node = onnx.helper.make_node("Sub", ["root", "rm0_out"], ["sub_out"]) + pow_node = onnx.helper.make_node("Pow", ["sub_out", "two_const"], ["pow_out"]) + rm1_node = onnx.helper.make_node("ReduceMean", ["pow_out", "axes_const"], ["rm1_out"]) + add0_node = onnx.helper.make_node("Add", ["rm1_out", "eps_const"], ["add0_out"]) + sqrt_node = onnx.helper.make_node("Sqrt", ["add0_out"], ["sqrt_out"]) + div_node = onnx.helper.make_node("Div", ["sub_out", "sqrt_out"], ["div_out"]) + mul_node = onnx.helper.make_node("Mul", ["div_out", "scale_const"], ["mul_out"]) + add1_node = onnx.helper.make_node("Add", ["mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [rm0_node, sub_node, pow_node, rm1_node, add0_node, sqrt_node, div_node, mul_node, add1_node], + "reduce_mean_sequence", + [root_inp], + [output], + initializer=[scale_const, bias_const, axes_const, two_const, eps_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def build_reduce_l2_sequence_model(self, shape, epsilon_val, axis=-1): + """ + [root] --> ReduceL2 -----> Clip --> Expand ----> Div --> + | (axis=-1) (min=epsilon) (shape=root) ^ + | (keepdims=True) | + | | + +-----------------------------------------------+ + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + axes_const = onnx.numpy_helper.from_array(np.array([axis], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(epsilon_val, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + rl2_node = 
onnx.helper.make_node("ReduceL2", ["root", "axes_const"], ["rl2_out"], keepdims=1) + clip_node = onnx.helper.make_node("Clip", ["rl2_out", "eps_const"], ["clip_out"]) + expand_node = onnx.helper.make_node("Expand", ["clip_out", "shape_const"], ["expand_out"]) + div_node = onnx.helper.make_node("Div", ["root", "expand_out"], ["output"]) + + graph = onnx.helper.make_graph( + [rl2_node, clip_node, expand_node, div_node], + "reducel2_sequence", + [root_inp], + [output], + initializer=[axes_const, eps_const, shape_const], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return ONNXModel(model) + + def test_fuse_erf_to_gelu_1(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_1_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 2 Gelu nodes. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 2) + + gelu_node_0 = model.model.graph.node[0] + gelu_node_1 = model.model.graph.node[1] + self.assertEqual(gelu_node_0.op_type, "Gelu") + self.assertEqual(gelu_node_1.op_type, "Gelu") + + self.assertTrue(gelu_node_0.name) + self.assertTrue(gelu_node_1.name) + self.assertNotEqual(gelu_node_0.name, gelu_node_1.name) # Generated names should not be equal + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_2(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_2_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_3(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_3_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_erf_to_gelu_4(self): + shape = (1, 2, 3) + model = self.build_erf_sequence_4_model(shape) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 Gelu node. + modified = FusionGelu(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + gelu_node = model.model.graph.node[0] + self.assertEqual(gelu_node.op_type, "Gelu") + self.assertTrue(gelu_node.name) + + # Check that fusion is equivalent to original Erf model. 
+ inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + def test_fuse_reduce_l2_to_lpnorm(self): + shape = (1, 2, 3) + model = self.build_reduce_l2_sequence_model(shape, 1e-12, axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LpNormalization node. + modified = FusionLpNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + lpnorm_node = model.model.graph.node[0] + self.assertEqual(lpnorm_node.op_type, "LpNormalization") + self.assertTrue(lpnorm_node.name) + + # LpNorm's p attribute should be set to 2 + p_attr = next(attr for attr in lpnorm_node.attribute if attr.name == "p") + self.assertEqual(p_attr.i, 2) + + def test_fuse_reduce_mean_to_layer_norm(self): + shape = (1, 2, 3) + model = self.build_reduce_mean_sequence_model(shape, [2.0, 2.0, 2.0], [1.0, 1.0, 1.0], axis=-1) + orig_model = onnx.ModelProto() + orig_model.CopyFrom(model.model) + + # Check that fusion simplified model to 1 LayerNormalization node. + modified = FusionLayerNormalization(model).apply() + self.assertTrue(modified) + self.assertEqual(len(model.model.graph.node), 1) + + layer_norm_node = model.model.graph.node[0] + self.assertEqual(layer_norm_node.op_type, "LayerNormalization") + self.assertTrue(layer_norm_node.name) + + # Check that fused model is equivalent to original model. + inputs = {"root": np.ones(shape, dtype=np.float32)} + self.check_fused_model_correctness(orig_model, model.model, inputs) + + +if __name__ == "__main__": + unittest.main() From 3b46ab643944a3bcc9e4d9eb2c155ead0bad5cdb Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 28 Feb 2024 00:46:29 +0800 Subject: [PATCH 156/207] Re-add testing removed by mistake. 
(#19647) --- .../azure-pipelines/linux-ci-pipeline.yml | 42 ++++++++++++++++++- .../docker/scripts/manylinux/requirements.txt | 1 + 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index a4bd24b4dd18b..02147c321fab3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -115,6 +115,7 @@ stages: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() + - job: Linux_Release timeoutInMinutes: 180 workspace: @@ -243,7 +244,46 @@ stages: ln -s /data/models $(Build.BinariesDirectory)/models displayName: link model dir - + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + pushd /onnxruntime_src/csharp; \ + dotnet restore /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln; \ + dotnet build /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release; \ + dotnet test /onnxruntime_src/csharp/OnnxRuntime.DesktopOnly.CSharp.sln -c Release -f net6.0 --no-build -l \"console;verbosity=normal\"; \ + popd + " + displayName: 'Dotnet build C# sln and Test' + + - bash: | + mkdir -p $HOME/.onnx + docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + onnxruntimecpubuild \ + /bin/bash -c " + set -ex; \ + /bin/bash /onnxruntime_src/tools/scripts/python_test.sh /onnxruntime_src /build Release && \ + /bin/bash /onnxruntime_src/tools/scripts/symbolic_shape_infer_test.sh /build + " + displayName: 'Run Release tests and symbolic shape infer test' - task: PublishTestResults@2 displayName: 'Publish unit test results' diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 94f52f476579b..886f19388d01e 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -10,3 +10,4 @@ protobuf==4.21.12 sympy==1.12 flatbuffers neural-compressor>=2.2.1 +triton From 580ee20dfce2849029229eb213dc8c7c87a89483 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 28 Feb 2024 02:56:16 +1000 Subject: [PATCH 157/207] Tweak Windows build parallelization settings (#19664) ### Description Use UseMultiToolTask and limit the number of cl.exe instances running. MultiToolTask info: https://devblogs.microsoft.com/cppblog/improved-parallelism-in-msbuild/ Info on why limiting CL_MPCount can help: https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows The current CIs have 4 cores (both physical and logical). Hardcoded the GPU build in win-ci.yml to use CL_MPCount of 2 as that seems to work fine. 
Can adjust if needed to base it on the actual number of cores or to use build.py to build. Caveat: I've run about 16 builds and haven't seen a slow build yet, but as the root cause of the slow builds isn't really known this isn't guaranteed to be a fix. ### Motivation and Context Try and prevent super slow GPU builds by reducing number of tasks potentially running in parallel. --- tools/ci_build/build.py | 15 ++++++++++++++- .../github/azure-pipelines/templates/win-ci.yml | 3 ++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 74c473d34f548..1056c4ed84510 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1451,6 +1451,13 @@ def generate_build_tree( # tools need to use the symbols. add_default_definition(cmake_extra_defines, "CMAKE_MSVC_DEBUG_INFORMATION_FORMAT", "ProgramDatabase") + if number_of_parallel_jobs(args) > 0: + # https://devblogs.microsoft.com/cppblog/improved-parallelism-in-msbuild/ + # NOTE: this disables /MP if set (according to comments on blog post). + # By default, MultiProcMaxCount and CL_MPCount value are equal to the number of CPU logical processors. + # See logic around setting CL_MPCount below + cmake_args += ["-DCMAKE_VS_GLOBALS=UseMultiToolTask=true;EnforceProcessCountAcrossBuilds=true"] + cmake_args += [f"-D{define}" for define in cmake_extra_defines] cmake_args += cmake_extra_args @@ -1662,11 +1669,17 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe build_tool_args = [] if num_parallel_jobs != 1: if is_windows() and args.cmake_generator != "Ninja" and not args.build_wasm: + # https://github.com/Microsoft/checkedc-clang/wiki/Parallel-builds-of-clang-on-Windows suggests + # not maxing out CL_MPCount + # Start by having one less than num_parallel_jobs (default is num logical cores), + # limited to a range of 1..3 + # that gives maxcpucount projects building using up to 3 cl.exe instances each build_tool_args += [ f"/maxcpucount:{num_parallel_jobs}", + # one less than num_parallel_jobs, at least 1, up to 3 + f"/p:CL_MPCount={min(max(num_parallel_jobs - 1, 1), 3)}", # if nodeReuse is true, msbuild processes will stay around for a bit after the build completes "/nodeReuse:False", - f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": build_tool_args += [ diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 8ed22153fd947..e32956d6eb913 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -162,10 +162,11 @@ stages: platform: ${{ parameters.msbuildPlatform }} configuration: RelWithDebInfo msbuildArchitecture: ${{ parameters.buildArch }} - maximumCpuCount: true + maximumCpuCount: true # default is num logical cores worth of projects building concurrently logProjectEvents: true workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' createLogFile: true + msbuildArgs: "/p:CL_MPCount=2" # 2x cl.exe per project building. 
- task: PythonScript@0 displayName: 'test' From 1c468a03b90aa8122d49b3148152a67b0519d36e Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 28 Feb 2024 03:27:43 +1000 Subject: [PATCH 158/207] Improve Nuget-CUDA-Packaging-Pipeline (#19668) ### Description * Publish the artifacts as late as possible * once published the artifacts are immutable, and any retry will fail if they exist * if any step fails after publishing the stage cannot be retried * use powershell to cleanup * DeleteFiles is taking >30 mins and causing the stage to timeout * powershell took < 1s ### Motivation and Context Make pipeline more robust --- .../stages/nuget-combine-cuda-stage.yml | 13 ++++++------- ...mponent-governance-component-detection-steps.yml | 7 ++----- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 8ca3d9148b514..064e2ea91d194 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -213,13 +213,6 @@ stages: PlatformsSupported: 'linux-x64' VerifyNugetSigning: false - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline NuGet Artifact' - inputs: - artifactName: 'drop-signed-nuget-GPU' - targetPath: '$(Build.ArtifactStagingDirectory)' - - - task: MSBuild@1 displayName: 'Clean C#' inputs: @@ -241,6 +234,12 @@ stages: parameters: condition: 'succeeded' + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline NuGet Artifact' + inputs: + artifactName: 'drop-signed-nuget-GPU' + targetPath: '$(Build.ArtifactStagingDirectory)' + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 displayName: 'Clean Agent Directories' condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index f1418e75bffa2..3d128fdb78eee 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -6,11 +6,8 @@ parameters: steps: - ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - - task: DeleteFiles@1 - inputs: - SourceFolder: '$(Build.BinariesDirectory)' - contents: | - **/* + - powershell: | + Remove-Item $(Build.BinariesDirectory)/* -Recurse -Force displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 From 2e4d1b8f1ba928fe5879077eced9cd5191760cfb Mon Sep 17 00:00:00 2001 From: zesongw Date: Wed, 28 Feb 2024 02:01:12 +0800 Subject: [PATCH 159/207] [WebNN EP] Add support for Op MatMul of WebNN CPU backend (#19413) Enable MatMul support for WebNN CPU backend to support more models. 
--- onnxruntime/core/providers/webnn/builders/helper.h | 2 +- .../webnn/builders/impl/gemm_op_builder.cc | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index d94729e60d029..d7892fe02c1ba 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -195,7 +195,7 @@ static const InlinedHashMap op_map = { {"LessOrEqual", {"lesserOrEqual", false}}, {"Log", {"log", false}}, {"LpPool", {"l2Pool2d", false}}, - {"MatMul", {"matmul", false}}, + {"MatMul", {"matmul", true}}, {"MatMulInteger", {"matmulInteger", false}}, {"Max", {"max", true}}, {"MaxPool", {"maxPool2d", true}}, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 4bf991a1b0105..d5f84f853f7de 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder { // Add operator related. Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C @@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name()); emscripten::val output = emscripten::val::object(); if (op_type == "MatMul") { - output = model_builder.GetBuilder().call("matmul", a, b); + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A."); + } + // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case. + // TODO: Remove this workaround when it is fixed in Chromium. + if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) { + output = model_builder.GetBuilder().call("gemm", a, b); + } else { + output = model_builder.GetBuilder().call("matmul", a, b); + } } else if (op_type == "MatMulInteger") { emscripten::val a_zero_point = emscripten::val::null(); emscripten::val b_zero_point = emscripten::val::null(); From 3cb81cdde25d059af5674506f6a5b899c9c0f5ee Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:07:15 -0800 Subject: [PATCH 160/207] [js/common] move 'env.wasm.trace' to 'env.trace' (#19617) ### Description Try to move 'env.wasm.trace' to 'env.trace' to make it less confusing, because it also works in webgpu. Marked 'env.wasm.trace' as deprecated. --- js/common/lib/env.ts | 9 +++++++++ js/common/lib/trace.ts | 6 +++--- js/web/lib/wasm/jsep/backend-webgpu.ts | 3 ++- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 6299c26159400..73a47d1a4f937 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -36,6 +36,7 @@ export declare namespace Env { /** * set or get a boolean value indicating whether to enable trace. * + * @deprecated Use `env.trace` instead. If `env.trace` is set, this property will be ignored. 
* @defaultValue `false` */ trace?: boolean; @@ -167,6 +168,7 @@ export interface Env { * @defaultValue `'warning'` */ logLevel?: 'verbose'|'info'|'warning'|'error'|'fatal'; + /** * Indicate whether run in debug mode. * @@ -174,6 +176,13 @@ export interface Env { */ debug?: boolean; + /** + * set or get a boolean value indicating whether to enable trace. + * + * @defaultValue `false` + */ + trace?: boolean; + /** * Get version of the current package. */ diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts index 404f7ef8089af..7e0487b350198 100644 --- a/js/common/lib/trace.ts +++ b/js/common/lib/trace.ts @@ -4,7 +4,7 @@ import {env} from './env-impl.js'; export const TRACE = (deviceType: string, label: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } // eslint-disable-next-line no-console @@ -30,14 +30,14 @@ const TRACE_FUNC = (msg: string, extraMsg?: string) => { }; export const TRACE_FUNC_BEGIN = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('BEGIN', extraMsg); }; export const TRACE_FUNC_END = (extraMsg?: string) => { - if (!env.wasm.trace) { + if (typeof env.trace === 'undefined' ? !env.wasm.trace : !env.trace) { return; } TRACE_FUNC('END', extraMsg); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 3e3a191ec3ead..27c5566ab9fed 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -710,7 +710,8 @@ export class WebGpuBackend { } setQueryType(): void { this.queryType = 'none'; - if (this.env.webgpu.profiling?.mode === 'default' || this.env.wasm.trace) { + if (this.env.webgpu.profiling?.mode === 'default' || + (typeof this.env.trace === 'undefined' ? 
this.env.wasm.trace : this.env.trace)) { if (this.device.features.has('chromium-experimental-timestamp-query-inside-passes')) { this.queryType = 'inside-passes'; } else if (this.device.features.has('timestamp-query')) { From c20ced4132d111e3e63844e292f0d8e318cffab2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Tue, 27 Feb 2024 20:26:48 +0100 Subject: [PATCH 161/207] Use CMake's find package for CUDA libs (#19673) ### Description Answers issue #19640 More details are in the issue, basically I am changing all the include directory and link directory usage to CMake's `CUDA::*` targets --- cmake/CMakeLists.txt | 4 ++++ cmake/adjust_global_compile_flags.cmake | 2 +- .../external/onnxruntime_external_deps.cmake | 3 +-- cmake/onnxruntime_providers_cuda.cmake | 20 +++++++++---------- cmake/onnxruntime_providers_tensorrt.cmake | 11 +++++----- cmake/onnxruntime_python.cmake | 5 +---- cmake/onnxruntime_unittests.cmake | 4 ++-- .../core/providers/cuda/nvtx_profile.cc | 5 ----- 8 files changed, 25 insertions(+), 29 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 1376c90fbcefe..8453da19ce3a6 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1412,6 +1412,10 @@ endif() if (onnxruntime_USE_CUDA) set(CMAKE_CUDA_RUNTIME_LIBRARY Shared) set(CMAKE_CUDA_STANDARD 17) + if(onnxruntime_CUDA_HOME) + file(TO_CMAKE_PATH CUDAToolkit_ROOT ${onnxruntime_CUDA_HOME}) + endif() + find_package(CUDAToolkit REQUIRED) if(onnxruntime_CUDNN_HOME) file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME) endif() diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 8161ea574b8cc..d3f9256105127 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -205,7 +205,7 @@ endif() macro(check_nvcc_compiler_flag _FLAG _RESULT) - execute_process(COMMAND ${onnxruntime_CUDA_HOME}/bin/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) + execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR) message("NVCC_ERROR = ${NVCC_ERROR}") message("NVCC_OUT = ${NVCC_OUT}") if ("${NVCC_OUT}" MATCHES "0") diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 22d12b128dc1f..09d57164b4ee1 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -556,16 +556,15 @@ message("Finished fetching external dependencies") set(onnxruntime_LINK_DIRS ) if (onnxruntime_USE_CUDA) #TODO: combine onnxruntime_CUDNN_HOME and onnxruntime_CUDA_HOME, assume they are the same + find_package(CUDAToolkit REQUIRED) if (WIN32) if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib/x64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/x64/lib64) else() if(onnxruntime_CUDNN_HOME) list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDNN_HOME}/lib ${onnxruntime_CUDNN_HOME}/lib64) endif() - list(APPEND onnxruntime_LINK_DIRS ${onnxruntime_CUDA_HOME}/lib64) endif() endif() diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 9887d615c92d7..0f6d48bdb6ec8 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -178,15 +178,16 @@ add_dependencies(${target} onnxruntime_providers_shared 
${onnxruntime_EXTERNAL_DEPENDENCIES}) if(onnxruntime_CUDA_MINIMAL) target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL) - target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface CUDA::cudart) else() - target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) + target_link_libraries(${target} PRIVATE CUDA::cublasLt CUDA::cublas cudnn CUDA::curand CUDA::cufft CUDA::cudart + ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface) if(onnxruntime_CUDNN_HOME) target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include) target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) endif() endif() - + if (onnxruntime_USE_TRITON_KERNEL) # compile triton kernel, generate .a and .h files include(onnxruntime_compile_triton_kernel.cmake) @@ -196,25 +197,24 @@ target_include_directories(${target} PRIVATE ${triton_kernel_header_dir}) target_link_libraries(${target} PUBLIC -Wl,--whole-archive ${triton_kernel_obj_file} -Wl,--no-whole-archive) # lib cuda needed by cuLaunchKernel - target_link_libraries(${target} PRIVATE cuda) + target_link_libraries(${target} PRIVATE CUDA::cuda_driver) endif() include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) - target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) # ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA) set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime") if (onnxruntime_ENABLE_CUDA_PROFILING) # configure cupti for cuda profiling - target_include_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/include) - target_link_directories(${target} PRIVATE ${onnxruntime_CUDA_HOME}/extras/CUPTI/lib64) - target_link_libraries(${target} PRIVATE cupti) + target_link_libraries(${target} PRIVATE CUDA::cupti) endif() - if (onnxruntime_ENABLE_NVTX_PROFILE AND NOT WIN32) - target_link_libraries(${target} PRIVATE nvToolsExt) + if (onnxruntime_ENABLE_NVTX_PROFILE) + target_link_libraries(${target} PRIVATE CUDA::nvtx3) endif() if (onnxruntime_ENABLE_TRAINING_OPS) diff --git a/cmake/onnxruntime_providers_tensorrt.cmake b/cmake/onnxruntime_providers_tensorrt.cmake index 686a993de3a4a..15ffc29e79ff4 100644 --- a/cmake/onnxruntime_providers_tensorrt.cmake +++ b/cmake/onnxruntime_providers_tensorrt.cmake @@ -8,7 +8,7 @@ set(BUILD_LIBRARY_ONLY 1) add_definitions("-DONNX_ML=1") add_definitions("-DONNX_NAMESPACE=onnx") - set(CUDA_INCLUDE_DIRS ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + set(CUDA_INCLUDE_DIRS ${CUDAToolkit_INCLUDE_DIRS}) set(TENSORRT_ROOT ${onnxruntime_TENSORRT_HOME}) set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) set(PROTOBUF_LIBRARY ${PROTOBUF_LIB}) @@ -58,7 +58,7 @@ URL_HASH SHA1=${DEP_SHA1_onnx_tensorrt} ) if (NOT CUDA_INCLUDE_DIR) - set(CUDA_INCLUDE_DIR ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) # onnx-tensorrt repo needs this 
variable to build + set(CUDA_INCLUDE_DIR ${CUDAToolkit_INCLUDE_DIRS}) # onnx-tensorrt repo needs this variable to build endif() # The onnx_tensorrt repo contains a test program, getSupportedAPITest, which doesn't support Windows. It uses # unistd.h. So we must exclude it from our build. onnxruntime_fetchcontent_makeavailable is for the purpose. @@ -102,11 +102,12 @@ onnxruntime_add_include_to_target(onnxruntime_providers_tensorrt onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tensorrt onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES}) if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER) - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface ${ABSEIL_LIBS} PUBLIC CUDA::cudart) else() - target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} cudart ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS}) + target_link_libraries(onnxruntime_providers_tensorrt PRIVATE ${onnxparser_link_libs} ${trt_link_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} ${PROTOBUF_LIB} flatbuffers::flatbuffers ${ABSEIL_LIBS} PUBLIC CUDA::cudart) endif() - target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} + PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) if(onnxruntime_CUDNN_HOME) target_include_directories(onnxruntime_providers_tensorrt PRIVATE ${onnxruntime_CUDNN_HOME}/include) endif() diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 3f20787e87425..23c6e5e430875 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -282,10 +282,7 @@ if (WIN32) get_filename_component(CUDNN_DLL_NAME ${CUDNN_DLL_PATH} NAME_WE) string(REPLACE "cudnn64_" "" CUDNN_VERSION "${CUDNN_DLL_NAME}") if(NOT onnxruntime_CUDA_VERSION) - message("Reading json file ${onnxruntime_CUDA_HOME}/version.json") - set(CUDA_SDK_JSON_FILE_PATH "${onnxruntime_CUDA_HOME}/version.json") - file(READ ${CUDA_SDK_JSON_FILE_PATH} CUDA_SDK_JSON_CONTENT) - string(JSON onnxruntime_CUDA_VERSION GET ${CUDA_SDK_JSON_CONTENT} "cuda" "version") + set(onnxruntime_CUDA_VERSION ${CUDAToolkit_VERSION}) message("onnxruntime_CUDA_VERSION=${onnxruntime_CUDA_VERSION}") endif() file(APPEND "${VERSION_INFO_FILE}" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 3ed695327c183..88f662075e177 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -67,7 +67,7 @@ function(AddTest) if(onnxruntime_USE_CUDA) #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, # otherwise it will impact when CUDA DLLs can be unloaded. 
- target_link_libraries(${_UT_TARGET} PRIVATE cudart) + target_link_libraries(${_UT_TARGET} PRIVATE CUDA::cudart) endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -1268,7 +1268,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS CUDA::cudart) endif() if (onnxruntime_USE_ROCM) list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) diff --git a/onnxruntime/core/providers/cuda/nvtx_profile.cc b/onnxruntime/core/providers/cuda/nvtx_profile.cc index 6c7c594066b86..867e7c1f24584 100644 --- a/onnxruntime/core/providers/cuda/nvtx_profile.cc +++ b/onnxruntime/core/providers/cuda/nvtx_profile.cc @@ -4,13 +4,8 @@ #ifdef ENABLE_NVTX_PROFILE #include "nvtx_profile.h" #include "core/common/common.h" -#if defined(_WIN32) || defined(WIN32) || defined(__CYGWIN__) || defined(__MINGW32__) || defined(__BORLANDC__) #include #include -#else -#include -#include -#endif namespace onnxruntime { namespace profile { From f95c0773a129a4605b2161f5f9fddb8116c948d0 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Wed, 28 Feb 2024 10:40:40 +0800 Subject: [PATCH 162/207] Add share memory Flag in docker (#19672) ### Description ### Motivation and Context Ref: https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem Co-authored-by: Your Name --- tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 822bc559d992d..165bd804a8ad5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -241,7 +241,7 @@ stages: script: | set -e -x mkdir -p $HOME/.onnx - docker run --gpus all --rm \ + docker run --gpus all --shm-size=1g --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --rm \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory)/Release:/build/Release \ --volume /data/models:/build/models:ro \ From 026e3178ae71cfcc5cfa2decde9a7d64b935d255 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 28 Feb 2024 15:57:05 +0800 Subject: [PATCH 163/207] Improve memory matrix for ORTModule (#19620) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Memory matrix for ORTModule Collect parameter/gradient/buffers sizes also. Exposed as a function, can be used externally for debugging purpose. 
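For external debugging, a minimal usage sketch of the new `log_memory_usage` helper could look like the following (illustration only, not part of the patch: it assumes a CUDA device and an arbitrary toy module, and the phase/step labels are free-form strings chosen for the example):

```python
import logging

import torch
from onnxruntime.training.utils import log_memory_usage

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("memory_debug")

# Arbitrary toy module used only for this sketch; any CUDA-resident nn.Module works.
model = torch.nn.Linear(1024, 1024).cuda()

for step in range(3):
    log_memory_usage("pre_forward", step_info=f"step {step}", logger=logger, module=model)
    loss = model(torch.randn(8, 1024, device="cuda")).sum()
    log_memory_usage("post_forward", step_info=f"step {step}", logger=logger, module=model)
    loss.backward()
    log_memory_usage("post_backward", step_info=f"step {step}", logger=logger, module=model)
```

The log excerpt below shows the kind of per-phase output these calls produce inside ORTModule.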
``` 2024-02-27 07:18:55,283 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,322 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,358 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 816 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,438 orttraining.rank-0 [INFO] - rank-0 step 1 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 2/3200 [01:27<32:05:11, 36.12s/it]2024-02-27 07:18:55,498 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,537 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,576 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,657 orttraining.rank-0 [INFO] - rank-0 step 2 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 3/3200 [01:27<17:30:57, 19.72s/it]2024-02-27 07:18:55,711 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,750 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,786 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,867 orttraining.rank-0 [INFO] - rank-0 step 3 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 [2024-02-27 07:18:55,886] [INFO] [loss_scaler.py:190:update_scale] [deepspeed] OVERFLOW! Rank 0 Skipping step. Attempted loss scale: 65536, but hysteresis is 2. 
Reducing hysteresis to 1 0%|▎ | 4/3200 [01:28<10:39:52, 12.01s/it]2024-02-27 07:18:55,902 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,944 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:55,979 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,060 orttraining.rank-0 [INFO] - rank-0 step 4 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 5/3200 [01:28<6:53:04, 7.76s/it]2024-02-27 07:18:56,115 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,154 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,190 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,270 orttraining.rank-0 [INFO] - rank-0 step 5 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|■| 6/3200 [01:28<4:36:19, 5.19s/it]2024-02-27 07:18:56,323 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,365 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,398 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,478 orttraining.rank-0 [INFO] - rank-0 step 6 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 7/3200 [01:28<3:09:33, 3.56s/it]2024-02-27 07:18:56,533 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,572 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | 
max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,608 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,727 orttraining.rank-0 [INFO] - rank-0 step 7 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▌ | 8/3200 [01:28<2:13:48, 2.52s/it]2024-02-27 07:18:56,806 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,846 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,882 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: pre_backward | allocated: 8926 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:56,962 orttraining.rank-0 [INFO] - rank-0 step 8 memory (MiB) | phase: post_backward | allocated: 6098 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 218 | max inactive: 831 | param: 5314 | grad: 12 | buffer: 8 0%|▋ | 9/3200 [01:29<1:36:03, 1.81s/it]2024-02-27 07:18:57,053 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: pre_forward | allocated: 5331 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 219 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 2024-02-27 07:18:57,094 orttraining.rank-0 [INFO] - rank-0 step 9 memory (MiB) | phase: post_forward | allocated: 8162 | max allocated: 9039 | cached: 9382 | max cached: 9382 | inactive: 400 | max inactive: 831 | param: 5314 | grad: 0 | buffer: 8 ``` --- .../training/ortmodule/_runtime_inspector.py | 37 +++------ .../python/training/utils/__init__.py | 2 + .../training/utils/torch_profile_utils.py | 76 +++++++++++++++++++ 3 files changed, 88 insertions(+), 27 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 078ce4d27cd6f..772b9bd9e31ae 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -14,7 +14,7 @@ from sympy import Symbol, simplify from sympy.parsing.sympy_parser import parse_expr -from onnxruntime.training.utils import PTable +from onnxruntime.training.utils import PTable, log_memory_usage from ._execution_agent import TrainingAgent from .options import _MemoryOptimizationLevel, _RuntimeOptions @@ -509,6 +509,8 @@ def __init__(self, m: torch.nn.Module, logger: Logger): self._is_first_inspect = True + self._m = m + def is_enabled(self) -> bool: """Check if memory inspector is enabled.""" return self._is_enabled @@ -621,29 +623,13 @@ def inspect_memory(self, cur_phase: Phase): need_print = self._current_step < 10 or (self._current_step & (self._current_step - 1) == 0) if need_print: - cur_mem_allocated = self._normalize(torch.cuda.memory_allocated()) - max_mem_allocated = 
self._normalize(torch.cuda.max_memory_allocated()) - cur_mem_cached = self._normalize(torch.cuda.memory_reserved()) - max_mem_cached = self._normalize(torch.cuda.max_memory_reserved()) - torch_mem_stat = torch.cuda.memory_stats() - cur_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) - max_mem_inactive = self._normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) - - mem_stats = [ - ["phase", _convert_phase_to_string(cur_phase)], - ["allocated", cur_mem_allocated], # current memory allocated for tensors - ["max allocated", max_mem_allocated], # peak memory allocated for tensors - ["cached", cur_mem_cached], # current memory cached for the caching allocator - ["max cached", max_mem_cached], # peak memory cached for caching allocator. - ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory - ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory - ] - - summ = f"{self._rank_info} step {self._current_step} memory ({MemoryObserver.NORMALIZER_UNIT})" - for stat in mem_stats: - summ += f" | {stat[0]}: {stat[1]}" - - self._logger.info(summ) + log_memory_usage( + _convert_phase_to_string(cur_phase), + rank_0_only=True, + step_info=f"step {self._current_step}", + logger=self._logger, + module=self._m, + ) if cur_phase == self._last_phase: self._increase_step() @@ -655,9 +641,6 @@ def inspect_memory(self, cur_phase: Phase): def _increase_step(self): self._current_step += 1 - def _normalize(self, mem_size_in_bytes: Union[float, int]) -> str: - return f"{float(mem_size_in_bytes) / MemoryObserver.NORMALIZER_FACTOR:.0f}" - def display_memory_optimization_plans(self, memory_optimizer_config, details=False) -> Tuple[List[str], PTable]: mem_plan_count = len(self.cluster_id_combination_to_saving_symbolics_map) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index b4a518d573998..ecfb7d7907f3c 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -12,6 +12,7 @@ unflatten_data_using_schema, ) from onnxruntime.training.utils.torch_profile_utils import ( + log_memory_usage, nvtx_function_decorator, torch_nvtx_range_pop, torch_nvtx_range_push, @@ -31,6 +32,7 @@ "torch_nvtx_range_push", "torch_nvtx_range_pop", "nvtx_function_decorator", + "log_memory_usage", "pytorch_type_to_onnx_dtype", "onnx_dtype_to_pytorch_dtype", "pytorch_scalar_type_to_pytorch_dtype", diff --git a/orttraining/orttraining/python/training/utils/torch_profile_utils.py b/orttraining/orttraining/python/training/utils/torch_profile_utils.py index 382d7dac142fe..9e8a41e0dc7c8 100644 --- a/orttraining/orttraining/python/training/utils/torch_profile_utils.py +++ b/orttraining/orttraining/python/training/utils/torch_profile_utils.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + import torch @@ -26,3 +28,77 @@ def wrapped_fn(*args, **kwargs): return ret_val return wrapped_fn + + +def log_memory_usage(cur_phase: str, rank_0_only=True, step_info="", logger=None, module=None): + """Log memory usage for the current phase. + Args: + cur_phase (str): The current phase. + rank_0_only (bool, optional): Only log the memory usage for rank 0. Defaults to True. + step_info (str, optional): The step information. Defaults to "". 
+ logger (logging.Logger, optional): The logger to log the memory usage. Defaults to None, which means print to stdout. + module (torch.nn.Module, optional): The module to get parameter, buffer and grad sizes. Defaults to None. + """ + rank = 0 + if rank_0_only is True: + if torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + if rank != 0: + return + + _normalizer_factor = float(1024 * 1024) + _normalizer_unit = "MiB" + + def _normalize(mem_size_in_bytes: float | int) -> str: + return f"{float(mem_size_in_bytes) / _normalizer_factor:.0f}" + + cur_mem_allocated = _normalize(torch.cuda.memory_allocated()) + max_mem_allocated = _normalize(torch.cuda.max_memory_allocated()) + cur_mem_cached = _normalize(torch.cuda.memory_reserved()) + max_mem_cached = _normalize(torch.cuda.max_memory_reserved()) + torch_mem_stat = torch.cuda.memory_stats() + cur_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.current", 0)) + max_mem_inactive = _normalize(torch_mem_stat.get("inactive_split_bytes.all.peak", 0)) + + mem_stats = [ + ["phase", cur_phase], + ["allocated", cur_mem_allocated], # current memory allocated for tensors + ["max allocated", max_mem_allocated], # peak memory allocated for tensors + ["cached", cur_mem_cached], # current memory cached for the caching allocator + ["max cached", max_mem_cached], # peak memory cached for caching allocator. + ["inactive", cur_mem_inactive], # amount of inactive, non-releasable memory + ["max inactive", max_mem_inactive], # peak of inactive, non-releasable memory + ] + + # Calculate the total size of parameters and gradients in the model + if module: + param_total_size = 0 + grad_total_size = 0 + for p in module.parameters(): + if p.is_cuda: + param_total_size += p.numel() * p.element_size() + if p.grad is not None and p.grad.is_cuda: + grad_total_size += p.grad.numel() * p.grad.element_size() + + # Calculate the total size of buffers in the model + buffer_total_size = 0 + for b in module.buffers(): + if b.is_cuda: + buffer_total_size += b.numel() * b.element_size() + + mem_stats.extend( + [ + ["param", _normalize(param_total_size)], + ["grad", _normalize(grad_total_size)], + ["buffer", _normalize(buffer_total_size)], + ] + ) + + summ = f"rank-{rank} {step_info} memory ({_normalizer_unit})" + for stat in mem_stats: + summ += f" | {stat[0]}: {stat[1]}" + + if logger is None: + print(summ) + else: + logger.info(summ) From 7a147fc6f76a30b8d5875352afced515431ec7e5 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 02:20:53 -0800 Subject: [PATCH 164/207] Remove a bash task from webgpu CI pipeline (#19682) ### Description It is a "Bash" task that requires running bash on Windows. Most Windows operating systems do not have Bash installed. Given this task is only debugging purposes, we can remove it for now. ### Motivation and Context I am making this change because I am regenerating the VM image in a different manner, and the new image does not contain bash. Once this PR is in, I can switch the images. 
--- .../github/azure-pipelines/templates/win-web-ci.yml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 8ba3517530edd..043da233cc674 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -155,12 +155,7 @@ jobs: path: $(Build.SourcesDirectory)/js/test/ cacheHitVar: CACHE_RESTORED displayName: 'Cache ONNX node test data' - - task: Bash@3 - inputs: - targetType: 'inline' - script: find "$(Build.SourcesDirectory)/js/test/" -type f - condition: and(not(canceled()), eq(variables.CACHE_RESTORED, 'true')) - displayName: 'List ONNX node test data' + - task: PowerShell@2 inputs: filePath: '$(Build.SourcesDirectory)\tools\ci_build\github\js\pack-npm-packages.ps1' From 913bdc7306e11b65644f733861684a3a460e8db0 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 28 Feb 2024 08:30:12 -0800 Subject: [PATCH 165/207] [QNN Quant] Handle external data for QNN preprocessing/quant (#19670) ### Description - Adds parameters to `qnn_preprocess_model()` to allow saving the new model with external data. - Updates `get_qnn_qdq_config()` to: - Load model without external data (it is not needed) - Return a quantization configuration with `use_external_data_format` set to `True` if the model has external data or if the model is >= 2GB. ### Motivation and Context Update QNN quantization to better handle large models that use external data. --- .../execution_providers/qnn/preprocess.py | 51 +++++- .../execution_providers/qnn/quant_config.py | 15 +- .../quantization/test_qnn_preprocess_model.py | 170 ++++++++++++++++++ .../test_tensor_quant_overrides_option.py | 30 ++++ 4 files changed, 261 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/test/python/quantization/test_qnn_preprocess_model.py diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index b1c114fe1f9fd..b0dab81830c8b 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -3,6 +3,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +from __future__ import annotations + import logging from pathlib import Path @@ -13,7 +15,44 @@ from .fusion_lpnorm import FusionLpNormalization -def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: bool = False) -> bool: +def qnn_preprocess_model( + model_input: Path, + model_output: Path, + fuse_layernorm: bool = False, + save_as_external_data: bool = False, + all_tensors_to_one_file: bool = False, + external_data_location: str | None = None, + external_data_size_threshold: int = 1024, + external_data_convert_attribute: bool = False, +) -> bool: + """ + If necessary, this method creates a new "pre-processed" model in preparation for + quantization of a model to be used in QNN EP. Returns true if a new model was created. + + This method perfoms the following operations: + - Fuse Erf sequence into a single Gelu node. + - Fuse ReduceL2 sequence into a single LpNormalization node (p == 2). + - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node. 
+ + Args: + model_input: Path to the input model file. + model_output: Path the output model file, which is only created if this method returns True. + fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes. + Defaults to False. + save_as_external_data: True if output model should be saved with external data. Defaults to false. + all_tensors_to_one_file: Effective only if save_as_external_data is true. Defaults to false. + If true, save all tensors to one external file specified by external_data_location. + If false, save each tensor to a file named with the tensor name. + external_data_location: Effective only if save_as_external_data is true. Defaults to None. + Specify the external file to which all tensors are saved. Path is relative + to the model path. If not specified, the model's name is used. + external_data_size_threshold: Effective only if save_as_external_data is true. Defaults to 1024. + Tensors with a data size >= external_data_size_threshold are converted to external data. + To convert every tensor with raw data to external data, set to 0. + external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false. + If true, convert all tensors to external data. + If false, convert only non-attribute tensors to external data. + """ modified = False model = onnx.load_model(model_input) onnx_model = ONNXModel(model) @@ -57,6 +96,14 @@ def qnn_preprocess_model(model_input: Path, model_output: Path, fuse_layernorm: if modified: onnx_model.topological_sort() - onnx.save_model(model, model_output) + onnx.save_model( + model, + model_output, + save_as_external_data=save_as_external_data, + all_tensors_to_one_file=all_tensors_to_one_file, + location=external_data_location, + size_threshold=external_data_size_threshold, + convert_attribute=external_data_convert_attribute, + ) return modified diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py index 7c2fa4f65ae1b..e9affae7ac263 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/quant_config.py @@ -15,6 +15,7 @@ Q16_TYPES = {QuantType.QInt16, QuantType.QUInt16} Q8_TYPES = {QuantType.QInt8, QuantType.QUInt8} OP_TYPES_TO_EXCLUDE = {"Cast"} +MODEL_SIZE_THRESHOLD = 2147483648 # Quant model should use external data if >= 2GB def get_qnn_qdq_config( @@ -28,14 +29,21 @@ def get_qnn_qdq_config( if per_channel: raise ValueError("QNN EP does not yet support per-channel quantization.") - # Process model nodes to setup overrides. - model = onnx.load_model(model_input) + model = onnx.load_model(model_input, load_external_data=False) op_types = set() tensor_quant_overrides = {} + model_has_external_data = False + name_to_initializer = {} - name_to_initializer = {initializer.name: initializer for initializer in model.graph.initializer} + # Build map of initializers (name -> initializer) and + # check if the model has external data. 
+ for initializer in model.graph.initializer: + name_to_initializer[initializer.name] = initializer + if onnx.external_data_helper.uses_external_data(initializer): + model_has_external_data = True + # Setup quantization overrides for specific operator types for node in model.graph.node: op_types.add(node.op_type) @@ -89,5 +97,6 @@ def get_qnn_qdq_config( activation_type=activation_type, weight_type=weight_type, op_types_to_quantize=list(op_types.difference(OP_TYPES_TO_EXCLUDE)), + use_external_data_format=(model_has_external_data or model.ByteSize() >= MODEL_SIZE_THRESHOLD), extra_options=extra_options, ) diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py new file mode 100644 index 0000000000000..9b67fd41caac3 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import math +import unittest +from pathlib import Path + +import numpy as np +import onnx + +from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model +from onnxruntime.quantization.quant_utils import model_has_external_data, ms_domain + + +class TestQnnPreprocessModel(unittest.TestCase): + def build_model(self, shape, scale_val, bias_val): + """ + Build a model that supports 3 kinds of fusions: + - Erf sequence to Gelu + - ReduceL2 sequence to LpNormalization + - ReduceMean sequence to LayerNormalization + """ + root_inp = onnx.helper.make_tensor_value_info("root", onnx.TensorProto.FLOAT, shape) + output = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, shape) + + # Erf sequence + one_const = onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), "one_const") + half_const = onnx.numpy_helper.from_array(np.array(0.5, dtype=np.float32), "half_const") + root2_const = onnx.numpy_helper.from_array(np.array(math.sqrt(2.0), dtype=np.float32), "root2_const") + + e_mul0_node = onnx.helper.make_node("Mul", ["root", "half_const"], ["e_mul0_out"]) + e_div_node = onnx.helper.make_node("Div", ["root", "root2_const"], ["e_div_out"]) + e_erf_node = onnx.helper.make_node("Erf", ["e_div_out"], ["e_erf_out"]) + e_add_node = onnx.helper.make_node("Add", ["e_erf_out", "one_const"], ["e_add_out"]) + e_mul1_node = onnx.helper.make_node("Mul", ["e_add_out", "e_mul0_out"], ["erf_seq_output"]) + + # ReduceL2 sequence + axes_const = onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), "axes_const") + eps_const = onnx.numpy_helper.from_array(np.array(1e-12, dtype=np.float32), "eps_const") + shape_const = onnx.numpy_helper.from_array(np.array(list(shape), dtype=np.int64), "shape_const") + + l2_rl2_node = onnx.helper.make_node("ReduceL2", ["erf_seq_output", "axes_const"], ["l2_rl2_out"], keepdims=1) + l2_clip_node = onnx.helper.make_node("Clip", ["l2_rl2_out", "eps_const"], ["l2_clip_out"]) + l2_expand_node = onnx.helper.make_node("Expand", ["l2_clip_out", "shape_const"], ["l2_expand_out"]) + l2_div_node = onnx.helper.make_node("Div", ["erf_seq_output", "l2_expand_out"], ["l2_seq_output"]) + + # ReduceMean sequence + scale_const = onnx.numpy_helper.from_array(np.array(scale_val, dtype=np.float32), "scale_const") + 
bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") + two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") + + m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"]) + m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"]) + m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"]) + m_rm1_node = onnx.helper.make_node("ReduceMean", ["m_pow_out", "axes_const"], ["m_rm1_out"]) + m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"]) + m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"]) + m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"]) + m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"]) + m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"]) + + graph = onnx.helper.make_graph( + [ + e_mul0_node, + e_div_node, + e_erf_node, + e_add_node, + e_mul1_node, + l2_rl2_node, + l2_clip_node, + l2_expand_node, + l2_div_node, + m_rm0_node, + m_sub_node, + m_pow_node, + m_rm1_node, + m_add0_node, + m_sqrt_node, + m_div_node, + m_mul_node, + m_add1_node, + ], + "qnn_f32_model", + [root_inp], + [output], + initializer=[ + one_const, + half_const, + root2_const, + axes_const, + eps_const, + shape_const, + scale_const, + bias_const, + two_const, + ], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_all_fusions(self): + """ + Test calling qnn_preprocess_model() with a model that supports all 3 fusions. + """ + model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]) + onnx.save_model(model, "model.onnx") + modified = qnn_preprocess_model("model.onnx", "model.qnn_pp.onnx", fuse_layernorm=True) + + self.assertTrue(modified) + + fused_model = onnx.load_model("model.qnn_pp.onnx") + + # 3 fused Ops: Gelu, LpNorm, LayerNorm + self.assertEqual(len(fused_model.graph.node), 3) + expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} + for node in fused_model.graph.node: + self.assertIn(node.op_type, expected_op_types) + + # Should have added "com.microsoft" opset import because we added a Gelu. + ms_domain_opset = next((opset for opset in fused_model.opset_import if opset.domain == ms_domain), None) + self.assertIsNotNone(ms_domain_opset) + self.assertEqual(ms_domain_opset.version, 1) + + def test_external_data(self): + """ + Test calling qnn_preprocess_model() with a model that uses external data. + The new preprocessed model should also have external data. + """ + model = self.build_model((1, 2, 3), [2.0, 2.0, 2.0], [1.0, 1.0, 1.0]) + onnx.save_model( + model, + "model.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location="weights.bin", + size_threshold=0, + ) + modified = qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + fuse_layernorm=True, + save_as_external_data=True, + all_tensors_to_one_file=True, + external_data_location="weights2.bin", + external_data_size_threshold=0, + ) + + self.assertTrue(modified) + + # Model should still have external data. 
+ self.assertTrue(model_has_external_data(Path("model.qnn_pp.onnx"))) + + fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False) + + # 3 fused Ops: Gelu, LpNorm, LayerNorm + self.assertEqual(len(fused_model.graph.node), 3) + expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} + for node in fused_model.graph.node: + self.assertIn(node.op_type, expected_op_types) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index 0470953e385b6..cbb6b3ae2e776 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -555,6 +555,36 @@ def test_get_qnn_qdq_config(self): self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) self.assertEqual(sig_out_sc.float_data[0], np.float32(1.0 / 65536.0)) + def test_get_qnn_qdq_config_ext_data(self): + """ + Test that get_qnn_qdq_config() returns a config that enables external data + if the input model has external data. + """ + + # Create model with a weight large enough (> 1024 bytes) to be stored externally. + large_weight = onnx.numpy_helper.from_array(np.random.random((1, 32, 32)).astype(np.float32), "weight") + graph = onnx.helper.make_graph( + [onnx.helper.make_node("Add", ["input", "weight"], ["output"])], + "add_ext_data", + [onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, (1, 32, 32))], + [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, (1, 32, 32))], + initializer=[large_weight], + ) + model = onnx.helper.make_model( + graph, + opset_imports=[onnx.helper.make_opsetid("", 18)], + ) + onnx.save_model( + model, + "add_ext_data.onnx", + save_as_external_data=True, + all_tensors_to_one_file=True, + location="add_ext_data.bin", + ) + + qnn_config = get_qnn_qdq_config("add_ext_data.onnx", DummyDataReader(self.activations)) + self.assertTrue(qnn_config.use_external_data_format) + if __name__ == "__main__": t = TestTensorQuantOverridesOption() From a93c31e3c9971063d8dfe45a627a80cbdcf99ed9 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 12:03:17 -0800 Subject: [PATCH 166/207] Update dml-vs-2022.yml (#19687) ### Description Fix a build error in "Zip-Nuget-Java-Nodejs Packaging Pipeline" which deletes files too early. 
--- .../nuget/templates/dml-vs-2022.yml | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 9393fb07d718a..d6bb415a68ee6 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -55,6 +55,9 @@ stages: - checkout: self clean: true submodules: recursive + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() - powershell: | if($env:TELEMETRYGUID) @@ -231,14 +234,7 @@ stages: searchPattern: '**/*.pdb' symbolServerType: teamServices - - ${{ if eq(parameters['DoCompliance'], 'true') }}: - - template: ../../templates/compliance.yml - parameters : - msbuildPlatform: ${{ parameters.sln_platform }} - - template: ../../templates/component-governance-component-detection-steps.yml - parameters : - condition : 'succeeded' # Node.js Publish - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: @@ -294,6 +290,12 @@ stages: targetPath: '$(Build.SourcesDirectory)\js\node\bin\napi-v3\win32\${{ parameters.sln_platform }}' artifactName: 'drop-onnxruntime-nodejs-win-${{ parameters.sln_platform }}-dml' - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() + + - ${{ if eq(parameters['DoCompliance'], 'true') }}: + - template: ../../templates/compliance.yml + parameters : + msbuildPlatform: ${{ parameters.sln_platform }} + + - template: ../../templates/component-governance-component-detection-steps.yml + parameters : + condition : 'succeeded' From e30618d05535d3fe0fdc34d350d78e8ad01b64d5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:05:08 -0800 Subject: [PATCH 167/207] [js/webgpu] use Headless for webgpu test by default (#19702) ### Description use Chromium Headless for webgpu test by default. Still use normal Chromium with window when debug=true or perfMode=true. Use the [`--headless=new`](https://developer.chrome.com/docs/chromium/new-headless) mode. ### Motivation and Context try to use a more stable way to launch npm tests to avoid a "chrome not found" issue in pipeline, which may potentially caused by windowed application. --- js/web/karma.conf.js | 4 ++-- js/web/script/test-runner-cli.ts | 29 +++++++---------------------- 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/js/web/karma.conf.js b/js/web/karma.conf.js index 8fce79843f617..9e44d9c0d9652 100644 --- a/js/web/karma.conf.js +++ b/js/web/karma.conf.js @@ -86,11 +86,11 @@ module.exports = function(config) { hostname, listenAddress, customLaunchers: { - // the following flags are used to make sure Edge on CI agents to initialize WebGPU correctly. 
+ // Chromium-based browsers EdgeTest: {base: 'Edge', flags: chromiumFlags}, ChromeTest: {base: 'Chrome', flags: chromiumFlags}, - ChromeTestHeadless: {base: 'ChromeHeadless', flags: chromiumFlags}, ChromeCanaryTest: {base: 'ChromeCanary', flags: chromiumFlags}, + // // ==== BrowserStack browsers ==== // diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 9105c02412e34..59bd0d5f6313a 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -542,14 +542,13 @@ async function main() { npmlog.info('TestRunnerCli.Run', '(4/4) Running karma to start test runner...'); const webgpu = args.backends.indexOf('webgpu') > -1; const webnn = args.backends.indexOf('webnn') > -1; - const browser = getBrowserNameFromEnv( - args.env, - args.bundleMode === 'perf' ? 'perf' : - args.debug ? 'debug' : - 'test', - webgpu); + const browser = getBrowserNameFromEnv(args.env); const karmaArgs = ['karma', 'start', `--browsers ${browser}`]; const chromiumFlags = ['--enable-features=SharedArrayBuffer', ...args.chromiumFlags]; + if (args.bundleMode === 'dev' && !args.debug) { + // use headless for 'test' mode (when 'perf' and 'debug' are OFF) + chromiumFlags.push('--headless=new'); + } if (args.debug) { karmaArgs.push('--log-level info --timeout-mocha 9999999'); chromiumFlags.push('--remote-debugging-port=9333'); @@ -662,10 +661,10 @@ async function main() { fs.writeJSONSync(path.join(TEST_ROOT, './testdata-config.json'), config); } - function getBrowserNameFromEnv(env: TestRunnerCliArgs['env'], mode: 'debug'|'perf'|'test', webgpu: boolean) { + function getBrowserNameFromEnv(env: TestRunnerCliArgs['env']) { switch (env) { case 'chrome': - return selectChromeBrowser(mode, webgpu); + return 'ChromeTest'; case 'edge': return 'EdgeTest'; case 'firefox': @@ -680,20 +679,6 @@ async function main() { throw new Error(`env "${env}" not supported.`); } } - - function selectChromeBrowser(mode: 'debug'|'perf'|'test', webgpu: boolean) { - if (webgpu) { - return 'ChromeTest'; - } else { - switch (mode) { - case 'debug': - case 'perf': - return 'ChromeTest'; - default: - return 'ChromeTestHeadless'; - } - } - } } void main(); From 250779474de0ce50f0ef4b39f7b050755e1019ba Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 28 Feb 2024 19:36:26 -0800 Subject: [PATCH 168/207] Change "onnxruntime-Linux-CPU-For-Android-CI" machine pool to "onnxruntime-Ubuntu2204-AMD-CPU" (#19698) ### Description The original one reports "out of disk space", which needs to be investigated. 
--- .../android-x86_64-crosscompile-ci-pipeline.yml | 6 +++--- .../azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml | 2 +- .../github/azure-pipelines/mac-react-native-ci-pipeline.yml | 2 +- .../templates/android-binary-size-check-stage.yml | 3 ++- 4 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml index 9136b21aec626..d0a22aae07741 100644 --- a/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-x86_64-crosscompile-ci-pipeline.yml @@ -53,7 +53,7 @@ stages: Codeql.Enabled: false jobs: - job: Build_CPU_EP - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU workspace: clean: all timeoutInMinutes: 30 @@ -140,7 +140,7 @@ stages: jobs: - job: Build_NNAPI_EP - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU timeoutInMinutes: ${{ variables.JobsTimeout }} workspace: clean: all @@ -456,7 +456,7 @@ stages: variables: - name: skipComponentGovernanceDetection value: true - pool: 'onnxruntime-Linux-CPU-For-Android-CI' + pool: 'onnxruntime-Ubuntu2204-AMD-CPU' condition: and(succeeded(), in(variables['Build.Reason'], 'IndividualCI', 'BatchedCI')) dependsOn: - NNAPI_EP_MASTER diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml index 1053a2518125f..bbea7a0d114e8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-minimal-build-ci-pipeline.yml @@ -59,7 +59,7 @@ jobs: timeoutInMinutes: 120 workspace: clean: all - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU variables: ORT_CACHE_DIR: $(Pipeline.Workspace)/ort_ccache TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] diff --git a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml index e8f4931d5ad9f..886bacf5aac4d 100644 --- a/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/mac-react-native-ci-pipeline.yml @@ -61,4 +61,4 @@ stages: parameters: NpmPackagingMode: ${{ variables.NpmPackagingMode }} BuildConfig: 'Release' - PoolName: 'onnxruntime-Linux-CPU-For-Android-CI' + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml index 733cafdeeb8c0..9822950127112 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-binary-size-check-stage.yml @@ -31,7 +31,7 @@ stages: timeoutInMinutes: 60 workspace: clean: all - pool: onnxruntime-Linux-CPU-For-Android-CI + pool: onnxruntime-Ubuntu2204-AMD-CPU steps: - checkout: self clean: true @@ -49,6 +49,7 @@ stages: - task: PythonScript@0 displayName: 'Set variables from config file "${{ parameters.BuildConfigFile }}"' inputs: + pythonInterpreter: /usr/bin/python3 scriptSource: inline script: | import json From 7455dd1f32af760984f42e8e6d752b675a4a0852 Mon Sep 17 00:00:00 2001 From: Sophie Schoenmeyer 
<107952697+sophies927@users.noreply.github.com> Date: Wed, 28 Feb 2024 21:10:25 -0800 Subject: [PATCH 169/207] Update labeler.yml to change permissions (#19709) ### Description Updated github/issue-labeler permissions to give write access for issues. Tried to submit the same PR last week, but the checks kept failing, so I couldn't merge. ### Motivation and Context Enables issue labeling again, which has been broken since GitHub Actions permissions were changed a couple weeks ago. --- .github/workflows/labeler.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 936ab0de899a2..a196226a4b836 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -3,6 +3,9 @@ on: issues: types: [opened, edited] +permissions: + issues: write + jobs: triage: runs-on: ubuntu-latest From d2e6dd25ea8bd528f614250ba0165a535734305e Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 29 Feb 2024 13:45:58 +0800 Subject: [PATCH 170/207] Merge GatherToSplitFusion and #19218 to a General Fusion (#19600) #19218 tried to fuse Gather/Slice to Split, but the logic has problem. Scalar value or 1-dim value of indices in Gather node will produce different result, scalar value will produce a result tensor by removing the axis dim, will 1-dim indices value will keep that dim, even when the dim value is 1. For example, Node |-> Gather(indices=[0], axis=axis) |-> Gather(indices=[1], axis=axis) |-> Slice(index=2, axis=axis) is same as Node |-> Split(axis=axis) But Node |-> Gather(indices=0, axis=axis) |-> Gather(indices=1, axis=axis) |-> Slice(index=2, axis=axis) is same as Node |-> Split(axis=axis) ||-> Squeeze(axis=axis) ||-> Squeeze(axis=axis) ||-> Previous PR doesn't take such case related to Squeeze/Unsqueeze into account. This PR merges #19218 and GatherToSplitFusion to a general fusion, which relaxes the limit the number of Gather and Slice node number, check all Gather and Slice consumers, if the indices of Gather and start/end of Slice can cover the specific dim of the input tensor, then we can fuse them to a Split, and adding Squeeze if necessary according to the dim count of the indices tensor in Gather. @rui-ren, please check if the fix can still be applied to your model. 
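To make the scalar-vs-1-dim indices difference concrete, here is a small numpy illustration (not part of the change itself; `np.take` follows the same output-shape rule as ONNX Gather, splicing the indices shape into the data shape at the gather axis):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)       # data of shape (3, 4)

# Scalar index: the gathered axis is removed, so the fusion needs Split + Squeeze(axis=0).
print(np.take(x, 0, axis=0).shape)    # (4,)

# 1-dim index of length 1: the axis is kept with size 1, matching a plain Split output.
print(np.take(x, [0], axis=0).shape)  # (1, 4)
```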
--- onnxruntime/core/optimizer/gather_fusion.cc | 318 ++++++---- onnxruntime/core/optimizer/gather_fusion.h | 16 +- .../core/optimizer/gather_slice_fusion.cc | 344 ----------- .../core/optimizer/gather_slice_fusion.h | 32 - .../core/optimizer/graph_transformer_utils.cc | 4 +- .../test/optimizer/graph_transform_test.cc | 550 +++++------------- .../core/optimizer/graph_transformer_utils.cc | 4 +- 7 files changed, 352 insertions(+), 916 deletions(-) delete mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.cc delete mode 100644 onnxruntime/core/optimizer/gather_slice_fusion.h diff --git a/onnxruntime/core/optimizer/gather_fusion.cc b/onnxruntime/core/optimizer/gather_fusion.cc index 4903bc1d6b961..90cabff88122c 100644 --- a/onnxruntime/core/optimizer/gather_fusion.cc +++ b/onnxruntime/core/optimizer/gather_fusion.cc @@ -9,55 +9,144 @@ namespace onnxruntime { -bool GatherToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || +namespace { +static int64_t GetGatherAxis(const Node& node, int64_t rank) { + int64_t axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) { + axis = axis_attr.i(); + if (axis < 0) axis += rank; + } + } + return axis; +} + +static bool GetScalarInt64Initializer(const Graph& graph, const NodeArg& node_arg, int64_t& value, int64_t& rank) { + if (!optimizer_utils::IsScalar(node_arg)) return false; + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name()); + if (!tensor_proto || tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + Initializer init_const{*tensor_proto, graph.ModelPath()}; + value = *(init_const.data()); + rank = tensor_proto->dims_size(); + return true; +} + +static bool GetSliceAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) { + if (node.InputDefs().size() < 4) return false; + int64_t unused = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[3], axis, unused)) return false; + if (axis < 0) axis += rank; + return true; +} + +static bool GetAxis(const Graph& graph, const Node& node, int64_t rank, int64_t& axis) { + if (node.OpType() == "Gather") { + axis = GetGatherAxis(node, rank); + return true; + } + if (node.OpType() == "Slice") { + return GetSliceAxis(graph, node, rank, axis); + } + return false; +} + +} // namespace + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, + int64_t target_axis, int64_t dim_size, InlinedVector& consumed, + int64_t& start, bool& need_squeeze) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {13}) || !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { return false; } - const NodeArg& input_arg = *(node.InputDefs()[1]); - if (!optimizer_utils::IsScalar(input_arg)) return false; - const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - if (!tensor_proto) return false; - if (tensor_proto->data_type() != ONNX_NAMESPACE::TensorProto_DataType_INT64) return false; - Initializer init_const{*tensor_proto, graph.ModelPath()}; - index = *(init_const.data()); - axis = 0; // Default value. 
- auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + if (GetGatherAxis(node, rank) != target_axis) return false; + // Require the indices input to be a scalar tensor for now. Normally if not, the exporter will choose Slice. + // We can relax this later if needed. + int64_t indices_n_dims = 0; + if (!GetScalarInt64Initializer(graph, *(node.InputDefs()[1]), start, indices_n_dims)) return false; + if (start < 0) start += dim_size; + if (start < 0 || start >= dim_size || consumed[static_cast(start)]) return false; + consumed[static_cast(start)] = true; + need_squeeze = indices_n_dims == 0; + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, + int64_t dim_size, InlinedVector& consumed, int64_t& start, + int64_t& end) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + int64_t axis = 0; + if (!GetSliceAxis(graph, node, rank, axis) || axis != target_axis) return false; + int64_t unused = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[1], start, unused) || + !GetScalarInt64Initializer(graph, *node.InputDefs()[2], end, unused)) { + return false; + } + // Handling start and end according to schema definition. + if (start < 0) start += dim_size; + if (end < 0) end += dim_size; + if (start < 0) + start = 0; + else if (start > dim_size) + start = dim_size; + if (end < 0) + end = 0; + else if (end > dim_size) + end = dim_size; + if (start >= end) return false; + if (node.InputDefs().size() >= 5) { + int64_t step = 0; + if (!GetScalarInt64Initializer(graph, *node.InputDefs()[4], step, unused) || step != 1) return false; + } + for (int64_t i = start; i < end; ++i) { + if (consumed[static_cast(i)]) return false; + consumed[static_cast(i)] = true; } - indices_n_dims = tensor_proto->dims_size(); return true; } /* -GatherToSplitFusion is to fuse: -Node -> Gather(index=0, axis=axis) - |-> Gather(index=1, axis=axis) - |-> Gather(index=2, axis=axis) +GatherSliceToSplitFusion is to fuse: +Node -> Gather(indices=0, axis=axis) + |-> Gather(indices=[1], axis=axis) + |-> Slice(start=2, end=3, axes=[axis]) |... To Node -> Split -> Squeeze(axis=axis) - |-> Squeeze(axis=axis) - |-> Squeeze(axis=axis) + |-> + |-> |... So that we can use one kernel to finish the job. +The fusion requires that the indices of Gather nodes and start/end of Slice nodes are not overlapping and cover +all the elements in the target axis. Step of Slice node should be 1. */ -Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + // Squeeze, Gather, Slice and Split have different schemas before and after OpSet 13. + // To make code simple, support OpSet >= 13 only. 
+ int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + if (onnx_opset_version < 13) return Status::OK(); + GraphViewer graph_viewer(graph); const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - InlinedVector node_args; + InlinedVector candidate_args; for (auto node_arg : graph.GetInputs()) { if (node_arg && graph.GetConsumerNodes(node_arg->Name()).size() > 1) { - node_args.push_back(node_arg); + candidate_args.push_back(node_arg); } } @@ -65,7 +154,7 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le if (graph.GetConsumerNodes(entry.first).size() > 1) { auto node_arg = graph.GetNodeArg(entry.first); if (node_arg) { - node_args.push_back(node_arg); + candidate_args.push_back(node_arg); } } } @@ -90,129 +179,108 @@ Status GatherToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le size_t output_count = node.GetOutputEdgesCount(); if (output_count <= 1) continue; - node_args.push_back(node.OutputDefs()[0]); + candidate_args.push_back(node.OutputDefs()[0]); } - for (const NodeArg* node_arg : node_args) { + for (const NodeArg* node_arg : candidate_args) { auto shape = node_arg->Shape(); if (!shape) continue; int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - InlinedVector gather_outputs(consumer_count, nullptr); - InlinedVector> nodes_to_fuse; + InlinedVector condidate_consumers; for (auto consumer : consumers) { - int64_t index, axis, dims; - if (!consumer || consumer->InputDefs()[0] != node_arg || - !IsSupportedGather(graph, *consumer, index, axis, dims)) { - can_fuse = false; - break; - } - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if (indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. 
- can_fuse = false; - break; + if (consumer && consumer->InputDefs()[0] == node_arg && + (consumer->OpType() == "Gather" || consumer->OpType() == "Slice")) { + condidate_consumers.emplace_back(consumer); } - if (axis < 0) axis += rank; - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - if (!utils::HasDimValue(dim) || dim.dim_value() != static_cast(consumer_count)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { + } + if (condidate_consumers.size() < 2) continue; + int64_t axis = 0; + if (!GetAxis(graph, *condidate_consumers[0], rank, axis)) continue; + auto dim = shape->dim(static_cast(axis)); + if (!utils::HasDimValue(dim)) continue; + int64_t dim_size = dim.dim_value(); + InlinedVector consumed(static_cast(dim_size), false); + bool can_fuse = true; + InlinedVector> nodes_to_fuse; + InlinedVector starts; + InlinedHashMap> output_info_map; + for (auto consumer : condidate_consumers) { + if (!consumer || consumer->InputDefs()[0] != node_arg) { can_fuse = false; break; } - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count) || gather_outputs[static_cast(index)]) { + int64_t start = 0, end = 0; + bool need_squeeze = false; + if (IsSupportedGather(graph, *consumer, rank, axis, dim_size, consumed, start, need_squeeze)) { + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.emplace_back(gather_node); + starts.emplace_back(start); + output_info_map[start] = std::make_tuple(gather_node.MutableOutputDefs()[0], 1, need_squeeze); + } else if (IsSupportedSlice(graph, *consumer, rank, axis, dim_size, consumed, start, end)) { + Node& slice_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.emplace_back(slice_node); + starts.emplace_back(start); + output_info_map[start] = std::make_tuple(slice_node.MutableOutputDefs()[0], end - start, false); + } else { can_fuse = false; break; } - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.emplace_back(gather_node); - gather_outputs[static_cast(index)] = gather_node.MutableOutputDefs()[0]; - } - - if (!can_fuse) continue; - - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = - static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); - split_output_type.mutable_tensor_type()->set_elem_type(element_type); - for (int64_t i = 0; i < rank; ++i) { - if (i == split_axis) { - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - } else { - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } } + if (!can_fuse || std::find(consumed.begin(), consumed.end(), false) != consumed.end()) continue; + std::sort(starts.begin(), starts.end()); InlinedVector split_outputs; - bool add_squeeze_node = indices_n_dims == 0; - if (add_squeeze_node) { - for (size_t i = 0; i < consumer_count; ++i) { - split_outputs.emplace_back( - &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split" + std::to_string(i)), &split_output_type)); - } - } - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", - {graph.GetNodeArg(node_arg->Name())}, add_squeeze_node ? split_outputs : gather_outputs); - split_node.AddAttribute("axis", split_axis); - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - // Squeeze-11, Squeee-13, Split-13, Split-18 have different schemas. 
- int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - if (onnx_opset_version < 13) { - if (add_squeeze_node) { - for (size_t i = 0; i < consumer_count; ++i) { - Node& squeeze_node = graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i]}, {gather_outputs[i]}); - squeeze_node.AddAttribute("axes", std::vector{split_axis}); - squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + InlinedVector split_values; + for (int64_t start : starts) { + auto& output_info = output_info_map[start]; + NodeArg* original_output_arg = std::get<0>(output_info); + int64_t split_value = std::get<1>(output_info); + split_values.emplace_back(split_value); + if (std::get<2>(output_info)) { + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = + static_cast(node_arg->TypeAsProto()->tensor_type().elem_type()); + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + for (int64_t i = 0; i < rank; ++i) { + if (i == axis) { + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(split_value); + } else { + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } } - } - } else { - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); - } - - if (add_squeeze_node) { + NodeArg* split_output_arg = + &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("split_output"), &split_output_type); ONNX_NAMESPACE::TensorProto axes_initializer_proto; - axes_initializer_proto.set_name(graph.GenerateNodeName("SqueezeAxesInitializer")); + axes_initializer_proto.set_name(graph.GenerateNodeName("squeeze_axes")); axes_initializer_proto.add_dims(static_cast(1)); axes_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector axes_value{split_axis}; - axes_initializer_proto.set_raw_data(axes_value.data(), axes_value.size() * sizeof(int64_t)); + axes_initializer_proto.add_int64_data(axis); NodeArg* axes_arg = &graph_utils::AddInitializer(graph, axes_initializer_proto); - - for (size_t i = 0; i < consumer_count; ++i) { - Node& squeeze_node = - graph.AddNode(graph.GenerateNodeName("Squeeze" + std::to_string(i)), "Squeeze", - "Squeeze for Fused Gather nodes", {split_outputs[i], axes_arg}, {gather_outputs[i]}); - squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - } + Node& squeeze_node = + graph.AddNode(graph.GenerateNodeName("Squeeze"), "Squeeze", "Squeeze for Fused Gather nodes", + {split_output_arg, axes_arg}, {original_output_arg}); + squeeze_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + split_outputs.emplace_back(split_output_arg); + } else { + split_outputs.emplace_back(original_output_arg); } } - for (Node& n : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, n); - graph.RemoveNode(n.Index()); + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("splits")); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + split_initializer_proto.add_dims(static_cast(split_values.size())); + split_initializer_proto.mutable_int64_data()->Add(split_values.begin(), split_values.end()); + NodeArg* 
split_initializer_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + Node& split_node = graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for Fused Gather nodes", + {graph.GetNodeArg(node_arg->Name()), split_initializer_arg}, split_outputs); + split_node.AddAttribute("axis", axis); + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + for (Node& node : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node); + graph.RemoveNode(node.Index()); } modified = true; diff --git a/onnxruntime/core/optimizer/gather_fusion.h b/onnxruntime/core/optimizer/gather_fusion.h index 44c235915b6cc..098278a77dafe 100644 --- a/onnxruntime/core/optimizer/gather_fusion.h +++ b/onnxruntime/core/optimizer/gather_fusion.h @@ -8,19 +8,23 @@ namespace onnxruntime { /** -@Class GatherToSplitFusion +@Class GatherSliceToSplitFusion -Fuse multiple Gather nodes that comsuming one output to one Split node. +Fuse multiple Gather/Slice nodes that comsuming one output to one Split node. */ -class GatherToSplitFusion : public GraphTransformer { +class GatherSliceToSplitFusion : public GraphTransformer { public: - GatherToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherToSplitFusion", compatible_execution_providers) {} + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, int64_t& indices_n_dims) const; + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size, + InlinedVector& consumed, int64_t& start, bool& need_squeeze) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, int64_t rank, int64_t target_axis, int64_t dim_size, + InlinedVector& consumed, int64_t& start, int64_t& end) const; }; /** diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc deleted file mode 100644 index 21266d356a020..0000000000000 --- a/onnxruntime/core/optimizer/gather_slice_fusion.cc +++ /dev/null @@ -1,344 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/optimizer/gather_slice_fusion.h" -#include "core/graph/graph_utils.h" -#include "core/optimizer/initializer.h" -#include "core/optimizer/utils.h" - -namespace onnxruntime { - -bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, - int64_t& axis, int64_t& indices_n_dims) const { - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } - - const NodeArg& input_arg = *(node.InputDefs()[1]); - - if (!optimizer_utils::IsScalar(input_arg)) return false; - - const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); - - if (!indices_init) return false; - - if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; - - // get the index value - Initializer init_const(*indices_init, graph.ModelPath()); - index = *(init_const.data()); - - // get attributes value - axis = 0; - auto& attrs = node.GetAttributes(); - if (attrs.find("axis") != attrs.end()) { - auto& axis_attr = attrs.at("axis"); - if (utils::HasInt(axis_attr)) axis = axis_attr.i(); - } - - indices_n_dims = indices_init->dims_size(); - return true; -} - -bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const { - // check the version of Slice ops - if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || - !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { - return false; - } - - // get the opset version - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - // If Slice op of opset version 1 - if (onnx_opset_version == 1) { - if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || - !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || - starts.size() != ends.size()) { - return false; - } - - if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { - return false; - } - } - - // If Slice op of opset version >= 10 - if (onnx_opset_version >= 10) { - // node inputs include: starts - ends - axes - steps - - // return a pointer to the corresponding NodeArg if input of the node at the index exists - auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { - const auto& input_defs = node.InputDefs(); - const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; - return (input == nullptr || !input->Exists()) ? nullptr : input; - }; - - // return a pointer to the initializer if it is constant; otherwise, a nullptr - auto get_initializer_if_constant = - [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { - const NodeArg* input = get_input_if_exists(input_index); - return input ? 
graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; - }; - - // return the initialization data if it is constant - auto get_initializer_data = - [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { - Initializer init(*slice_initializer, graph.ModelPath()); - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { - int32_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - - if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { - int64_t* init_data = init.data(); - return InlinedVector(init_data, init_data + init.size()); - } - return {}; - }; - - // starts and ends inputs have to exist, be constants and be of the same size. - const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); - const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); - const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); - const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); - - if (!starts_init || !ends_init || !axes_init || !steps_init) { - return false; - } - - starts = get_initializer_data(starts_init); - ends = get_initializer_data(ends_init); - axes = get_initializer_data(axes_init); - steps = get_initializer_data(steps_init); - - if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { - return false; - } - - if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { - return false; - } - - // if steps exists, it should be constant and all value should be 1 - if (steps.size() != starts.size()) { - return false; - } - - for (int64_t step : steps) { - if (step != 1) { - return false; - } - } - } - - return true; -} - -/* -GatherToSplitFusion is to fuse: - Node - |-> Gather(index=0, axis=axis) - |-> Gather(index=1, axis=axis) - |-> Slice(index=2, axis=axis) -To - Node - |-> Split(index=0) -So that we can use one kernel to finish the job. -*/ - -Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, - const logging::Logger& logger) const { - GraphViewer graph_viewer(graph); - - const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); - - InlinedVector output_args; - - // Iterate the topological order and get Reshape ops - for (auto node_index : node_topology_list) { - auto* p_node = graph.GetNode(node_index); - - if (p_node == nullptr) continue; - - Node& node = *p_node; - - ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); - - // Currently only catch after Reshape ops, optimize in the future - if (node.OpType() != "Reshape") continue; - - size_t output_count = node.GetOutputEdgesCount(); - - // We only catch 1 scenario for Multi Query Attention for now. - // |---> Gather - // Reshape |---> Gather - // |---> Slice - // |... 
or (other ops) - - // Get the output into node args - if (output_count < 3) continue; - - output_args.push_back(node.OutputDefs()[0]); - } - - // iterate the children of Reshape node - for (const NodeArg* node_arg : output_args) { - auto shape = node_arg->Shape(); - if (!shape) continue; - - auto consumers = graph.GetConsumerNodes(node_arg->Name()); - size_t consumer_count = consumers.size(); - - // get the tensor rank - int64_t rank = static_cast(shape->dim_size()); - - bool can_fuse = true; - bool first_edge = true; - int64_t split_axis = 0; - int64_t indices_n_dims = -1; - - // Fuse 2 Gathers and 1 slice to Split - // Get those outputs as Split outputs - InlinedVector split_outputs(3); - - InlinedVector> nodes_to_fuse; - size_t gather_node_count = 2, slice_node_count = 0; - - // find the nodes to be merged - for (auto consumer : consumers) { - int64_t index, axis, dims; - InlinedVector starts, ends, axes, steps; - - bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); - bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); - - if ((!consumer || consumer->InputDefs()[0] != node_arg) || - (!IsSupportedGatherOps && !IsSupportedSliceOps)) { - break; - } - - if (IsSupportedGatherOps) { - if (indices_n_dims == -1) { - indices_n_dims = dims; - } else if (indices_n_dims != dims) { - // Not the same number of dimensions (0 or 1) for all scalar indices. - can_fuse = false; - break; - } - - if (axis < 0) axis += rank; - - if (first_edge) { - auto dim = shape->dim(static_cast(axis)); - // dim.dim_value() = 73 - if (!utils::HasDimValue(dim)) { - can_fuse = false; - break; - } - split_axis = axis; - first_edge = false; - } else if (axis != split_axis) { - can_fuse = false; - break; - } - - if (index < 0) index += static_cast(consumer_count); - if (index < 0 || index >= static_cast(consumer_count)) { - can_fuse = false; - break; - } - - Node& gather_node = *graph.GetNode(consumer->Index()); - nodes_to_fuse.push_back(gather_node); - NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; - split_outputs[gather_node_count--] = gather_output_args; - } - - // check the Slice Ops - if (IsSupportedSliceOps) { - if (axes[0] != axis && !first_edge) { - can_fuse = false; - break; - } - - Node& slice_node = *graph.GetNode(consumer->Index()); - NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; - nodes_to_fuse.push_back(slice_node); - split_outputs[slice_node_count++] = slice_output_args; - } - } - - // condition check - if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; - - // generate the split node and merge the kernel - ONNX_NAMESPACE::TypeProto split_output_type; - const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( - node_arg->TypeAsProto()->tensor_type().elem_type()); - - split_output_type.mutable_tensor_type()->set_elem_type(element_type); - - for (int64_t i = 0; i < rank; i++) { - if (i == split_axis) - split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); - else - *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); - } - - InlinedVector split_output_types; - - for (size_t i = 0; i < consumer_count; ++i) { - split_output_types.push_back( - &graph.GetOrCreateNodeArg( - graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); - } - - // Generate the Split Node - ONNX_NAMESPACE::TensorProto split_initializer_proto; - 
split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); - split_initializer_proto.add_dims(static_cast(3)); - split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - - auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); - // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 - int64_t slice_dim = static_cast(dim_value - 2); - InlinedVector split_value{{slice_dim, 1, 1}}; - split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); - NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); - - Node& split_node = - graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", - {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); - - split_node.AddAttribute("axis", split_axis); - - split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); - - int onnx_opset_version = -1; - if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { - onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); - } - - if (onnx_opset_version >= 18) { - split_node.AddAttribute("num_outputs", static_cast(consumer_count)); - } - - for (Node& node_to_fuse : nodes_to_fuse) { - graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); - graph.RemoveNode(node_to_fuse.Index()); - } - modified = true; - } - - return Status::OK(); -} -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h deleted file mode 100644 index 1c5c307efed7f..0000000000000 --- a/onnxruntime/core/optimizer/gather_slice_fusion.h +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include "core/optimizer/graph_transformer.h" - -namespace onnxruntime { - -/** -@class GatherSliceToSplitFusion -Fuse (2 Gather nodes + 1 Slice) to 1 split node. 
-*/ - -class GatherSliceToSplitFusion : public GraphTransformer { - private: - bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, - int64_t& indices_n_dims) const; - - bool IsSupportedSlice(const Graph& graph, const Node& node, - InlinedVector& starts, - InlinedVector& ends, - InlinedVector& axes, - InlinedVector& steps) const; - - public: - GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept - : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} - - Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; -}; -} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4e939fe3c7b6b..8376b87aee6b2 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,7 +37,6 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -307,9 +306,8 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); - transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index e1fcf835c6043..16f38bac62713 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,7 +42,6 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7059,13 +7058,13 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { } } -TEST_F(GraphTransformationTests, GatherToSplitFusion) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllGather) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); auto* gather_out_1 = builder.MakeIntermediate(); auto* gather_out_2 = builder.MakeIntermediate(); @@ -7082,7 +7081,8 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { builder.AddNode("Gather", 
{reshape_out, gather_index_3}, {gather_out_3}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); }; @@ -7091,27 +7091,16 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { return Status::OK(); }; - // OpSet-12 + // OpSet-12, not support { auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); - } - } + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } @@ -7121,7 +7110,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { auto post_graph_checker = [&](Graph& graph) { TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 2); for (auto& node : graph.Nodes()) { if (node.OpType() == "Split") { auto& attrs = node.GetAttributes(); @@ -7140,249 +7129,140 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == 
ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } } -TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_AllSlice_GraphInput) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({1}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(2)}); - auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); - auto* gather_out_3 = builder.MakeIntermediate(); + auto* data_arg = builder.MakeInput({{2, 3, 8, 3}}); + auto* starts_1 = builder.MakeInitializer({1}, {0}); + auto* ends_1 = builder.MakeInitializer({1}, {2}); + auto* axes_1 = builder.MakeInitializer({1}, {2}); + auto* steps_1 = builder.MakeInitializer({1}, {1}); + auto* starts_2 = builder.MakeInitializer({1}, {2}); + auto* ends_2 = builder.MakeInitializer({1}, {-2}); + auto* axes_2 = builder.MakeInitializer({1}, {-2}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* starts_3 = builder.MakeInitializer({1}, {-2}); + auto* ends_3 = builder.MakeInitializer({1}, {16}); + auto* axes_3 = builder.MakeInitializer({1}, {2}); + auto* slice_out_1 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); + auto* slice_out_3 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(-2)); - builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1}); + builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3}); + builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 3); 
return Status::OK(); }; - // OpSet-12 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-14 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); } - return Status::OK(); - }; + } + return Status::OK(); + }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Input) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Combined) { auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); + auto* data_arg = builder.MakeInput({{144}}); + auto* shape_arg = builder.MakeInput({{4}}); + auto* reshape_out = 
builder.MakeIntermediate({{2, 8, 3, 3}}); + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(5)}); + auto* starts_2 = builder.MakeInitializer({1}, {6}); + auto* ends_2 = builder.MakeInitializer({1}, {8}); + auto* axes_2 = builder.MakeInitializer({1}, {-3}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* gather_index_3 = builder.MakeInitializer({1}, {static_cast(4)}); + auto* starts_4 = builder.MakeInitializer({1}, {-16}); + auto* ends_4 = builder.MakeInitializer({1}, {4}); + auto* axes_4 = builder.MakeInitializer({1}, {1}); auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); auto* gather_out_3 = builder.MakeIntermediate(); + auto* slice_out_4 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); + auto* transpose_out_4 = builder.MakeOutput(); - builder.AddNode("Gather", {data_arg, gather_index_1}, {gather_out_1}).AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {data_arg, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(-2)); - builder.AddNode("Gather", {data_arg, gather_index_3}, {gather_out_3}).AddAttribute("axis", static_cast(2)); + builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(1)); + builder.AddNode("Slice", {reshape_out, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) + .AddAttribute("axis", static_cast(-3)); + builder.AddNode("Slice", {reshape_out, starts_4, ends_4, axes_4}, {slice_out_4}); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}).AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}).AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_4}, {transpose_out_4}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 2); return Status::OK(); }; - // OpSet-12 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axes") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axes").ints().at(0))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - 
ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-14 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } - - // OpSet-18 - { - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 3); - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = node.GetAttributes(); - TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); - TEST_RETURN_IF_NOT(2 == static_cast(attrs.at("axis").i())); - } else if (node.OpType() == "Squeeze") { - const NodeArg& input_arg = *(node.InputDefs()[1]); - const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, input_arg.Name()); - TEST_RETURN_IF_NOT(tensor_proto != nullptr); - Initializer init_const{*tensor_proto, graph.ModelPath()}; - TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); - TEST_RETURN_IF_NOT(2 == static_cast(*(init_const.data()))); - } + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 1); + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + TEST_RETURN_IF_NOT(1 == static_cast(attrs.at("axis").i())); + } else if (node.OpType() == "Squeeze") { + const NodeArg& input_arg = *(node.InputDefs()[1]); + const ONNX_NAMESPACE::TensorProto* tensor_proto = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + TEST_RETURN_IF_NOT(tensor_proto != nullptr); + Initializer init_const{*tensor_proto, graph.ModelPath()}; + TEST_RETURN_IF_NOT(tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_INT64); + TEST_RETURN_IF_NOT(1 == static_cast(*(init_const.data()))); } - return Status::OK(); - }; + } + return Status::OK(); + }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, 
post_graph_checker)); - } + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Consume_Initializer) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInitializer({2, 3, 3, 3}, std::vector(54)); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); @@ -7430,31 +7310,31 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Consume_Initializer) { return Status::OK(); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } -TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0); return Status::OK(); }; auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 3); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] > 0 || CountOpsInGraph(graph)["Slice"] > 0); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Squeeze"] == 0); return Status::OK(); }; - // Invalid shape. + // Not cover all elements of specific dimension. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{72}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 4, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); + auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(2)}); auto* gather_out_1 = builder.MakeIntermediate(); auto* gather_out_2 = builder.MakeIntermediate(); @@ -7467,63 +7347,65 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); + .AddAttribute("axis", static_cast(-2)); builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) .AddAttribute("axis", static_cast(2)); builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) .AddAttribute("perm", std::vector{0, 2, 1}); builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); + .AddAttribute("perm", std::vector{0, 2, 1, 3}); builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) .AddAttribute("perm", std::vector{0, 2, 1}); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 12, *logger_, std::move(transformer), + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - // Invalid Gather indices. + // Has overlap. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); - auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_index_3 = builder.MakeInitializer({}, {static_cast(1)}); - auto* gather_out_1 = builder.MakeIntermediate(); - auto* gather_out_2 = builder.MakeIntermediate(); - auto* gather_out_3 = builder.MakeIntermediate(); + auto* data_arg = builder.MakeInput({{2, 3, 8, 3}}); + auto* starts_1 = builder.MakeInitializer({1}, {0}); + auto* ends_1 = builder.MakeInitializer({1}, {3}); + auto* axes_1 = builder.MakeInitializer({1}, {2}); + auto* steps_1 = builder.MakeInitializer({1}, {1}); + auto* starts_2 = builder.MakeInitializer({1}, {2}); + auto* ends_2 = builder.MakeInitializer({1}, {-2}); + auto* axes_2 = builder.MakeInitializer({1}, {-2}); + auto* steps_2 = builder.MakeInitializer({1}, {1}); + auto* starts_3 = builder.MakeInitializer({1}, {-2}); + auto* ends_3 = builder.MakeInitializer({1}, {16}); + auto* axes_3 = builder.MakeInitializer({1}, {2}); + auto* slice_out_1 = builder.MakeIntermediate(); + auto* slice_out_2 = builder.MakeIntermediate(); + auto* slice_out_3 = builder.MakeIntermediate(); auto* transpose_out_1 = builder.MakeOutput(); auto* transpose_out_2 = builder.MakeOutput(); auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Reshape", {data_arg, shape_arg}, {reshape_out}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Gather", {reshape_out, gather_index_3}, {gather_out_3}) - .AddAttribute("axis", static_cast(2)); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1}); - builder.AddNode("Transpose", {gather_out_3}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1}); + builder.AddNode("Slice", {data_arg, starts_1, ends_1, axes_1, steps_1}, {slice_out_1}); + builder.AddNode("Slice", {data_arg, starts_2, ends_2, axes_2, steps_2}, {slice_out_2}); + builder.AddNode("Slice", {data_arg, starts_3, ends_3, axes_3}, {slice_out_3}); + builder.AddNode("Transpose", {slice_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + builder.AddNode("Transpose", {slice_out_3}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); }; - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 18, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } - // Invalid Gather axis. + // Invalid axis. 
{ auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); @@ -7550,7 +7432,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion_Invalid) { .AddAttribute("perm", std::vector{0, 2, 1}); }; - std::unique_ptr transformer = std::make_unique(); + std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); } @@ -7643,143 +7525,5 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } -TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { - { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* reshape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); - builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); - - // Create Gather-1 Ops - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); - auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - - // Create Transpose 1-Ops - auto* transpose_out_1 = builder.MakeOutput(); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - - // Create Gather-2 Ops - auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); - auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); - builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) - .AddAttribute("axis", static_cast(2)); - - // Create Transpose-2 Ops - auto* transpose_out_2 = builder.MakeOutput(); - builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - - // Create Slice Ops - auto* slice_output = builder.MakeIntermediate(); - auto* starts = builder.MakeInitializer({1}, {0}); - auto* ends = builder.MakeInitializer({1}, {-2}); - auto* axes = builder.MakeInitializer({1}, {2}); - auto* steps = builder.MakeInitializer({1}, {1}); - builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); - - // Create Shape-1 Ops - auto* shape_output_1 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_1}); - - // Create Shape-2 Ops - auto* shape_output_2 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_2}); - - // Create Transpose-3 Ops - auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - }; - - auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); - - for (auto& node : graph.Nodes()) { - if (node.OpType() == "Split") { - auto& attrs = 
node.GetAttributes(); - TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); - } - } - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } -} - -TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { - { - auto build_test_case = [&](ModelTestBuilder& builder) { - auto* data_arg = builder.MakeInput({{54}}); - auto* reshape_arg = builder.MakeInput({{4}}); - auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); - builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); - - // Create Gather-1 Ops - auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); - auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); - builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) - .AddAttribute("axis", static_cast(2)); - - // Create Transpose 1-Ops - auto* transpose_out_1 = builder.MakeOutput(); - builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - - // Create Slice Ops - auto* slice_output = builder.MakeIntermediate(); - auto* starts = builder.MakeInitializer({1}, {0}); - auto* ends = builder.MakeInitializer({1}, {-2}); - auto* axes = builder.MakeInitializer({1}, {2}); - auto* steps = builder.MakeInitializer({1}, {1}); - builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); - - // Create Shape-1 Ops - auto* shape_output_1 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_1}); - - // Create Shape-2 Ops - auto* shape_output_2 = builder.MakeOutput(); - builder.AddNode("Shape", {slice_output}, {shape_output_2}); - - // Create Transpose-3 Ops - auto* transpose_out_3 = builder.MakeOutput(); - builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) - .AddAttribute("perm", std::vector{0, 2, 1, 3}); - }; - - auto pre_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); - return Status::OK(); - }; - - auto post_graph_checker = [&](Graph& graph) { - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); - TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); - return Status::OK(); - }; - - std::unique_ptr transformer = std::make_unique(); - ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), - TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); - } -} - } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 0b68dc65e41cd..5d527369a1b75 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,7 +24,6 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" -#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -139,9 +138,8 @@ std::vector> GeneratePreTrainingTransformers( 
transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps)); - transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); From c1bf7fcd2fb105e067dc1f2edd408c399a61a1fe Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Thu, 29 Feb 2024 01:19:25 -0800 Subject: [PATCH 171/207] [QNN Quant] Ensure 16bit tensor quant overrides set MS domain (#19684) ### Description Ensures that DQ and Q ops use the msft domain if tensor quantization overrides specify 16-bit integer types. ### Motivation and Context ONNX does not yet support 16bit integer types for QuantizeLinear and DequantizeLinear ops (coming soon). For now, DQ/Q ops must use the MSFT domain. We have to also check if tensor quantization overrides force the use of 16-bit quantization types. If so, we must correctly set the domain for Q/DQ ops. --- .../tools/quantization/onnx_quantizer.py | 11 ++++--- .../tools/quantization/qdq_quantizer.py | 5 ++- .../test_tensor_quant_overrides_option.py | 32 ++++++++++++++++++- 3 files changed, 42 insertions(+), 6 deletions(-) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 9450426f12444..19a72e38dea33 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -154,7 +154,7 @@ def __init__( if self.mode not in QuantizationMode: raise ValueError(f"unsupported quantization mode {self.mode}") - self.tensor_quant_overrides = self._get_and_check_tensor_quant_overrides() + self.tensor_quant_overrides, self.tensor_quant_override_types = self._get_and_check_tensor_quant_overrides() self.quantization_params = self.calculate_quantization_params() # QuantizeRange tensor name and zero tensor name for scale and zero point calculation. @@ -177,8 +177,10 @@ def __init__( def _get_and_check_tensor_quant_overrides(self): """ Get tensor quantization overrides and check correctness. + Also returns a set of quantization types (as TensorProto) specified across all overrides. """ tensor_quant_overrides = self.extra_options.get("TensorQuantOverrides", {}) + tensor_quant_override_types = set() # Validate that compatible/valid overrides are provided. if tensor_quant_overrides: @@ -211,6 +213,8 @@ def _get_and_check_tensor_quant_overrides(self): # other channels. if index == 0: quant_type = quant_overrides.get("quant_type") + if quant_type is not None: + tensor_quant_override_types.add(quant_type.tensor_type) elif quant_type != quant_overrides.get("quant_type"): raise ValueError( "Channel quantization types for tensor '{tensor_name}' do not match at index {index}." 
@@ -231,7 +235,7 @@ def _get_and_check_tensor_quant_overrides(self): f"Tensor override option '{key}' is invalid with 'scale' and 'zero_point'" ) - return tensor_quant_overrides + return tensor_quant_overrides, tensor_quant_override_types def get_per_tensor_quant_overrides(self, tensor_name): quant_overrides_list = self.tensor_quant_overrides.get(tensor_name, [{}]) @@ -747,8 +751,7 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non raise ValueError(f"Unexpected type {type(params['scale'])} and param_name={param_name!r}") scale_values = np.array([params["scale"]]) assert scale_values.dtype != np.float64 - # zero_point_type = params["quant_type"] - assert zero_point_type == params["quant_type"] + zero_point_type = params["quant_type"] else: zero_point_values = np.array([use_zeropoint]) scale_values = np.array([use_scale]) diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 775a3e8b8b588..76cd0d21fca37 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -116,7 +116,10 @@ def __init__( # if the activation or weight types are 16-bit integers. # TODO: Remove this override (and use only the 'UseQDQContribOps' option) if/when ONNX adds 16-bit support. int16_types = (TensorProto.UINT16, TensorProto.INT16) - if not self.qdq_op_domain and (self.activation_qType in int16_types or self.weight_qType in int16_types): + overrides_have_int16 = any(t in int16_types for t in self.tensor_quant_override_types) + if not self.qdq_op_domain and ( + self.activation_qType in int16_types or self.weight_qType in int16_types or overrides_have_int16 + ): logging.warning( "ONNX QuantizeLinear and DequantizeLinear operators do not support 16-bit integer quantization types. " f"The domain of QuantizeLinear and DequantizeLinear operators will be set to '{ms_domain}' to " diff --git a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py index cbb6b3ae2e776..9ea4719f3c595 100644 --- a/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py +++ b/onnxruntime/test/python/quantization/test_tensor_quant_overrides_option.py @@ -13,7 +13,7 @@ from onnxruntime import quantization from onnxruntime.quantization.execution_providers.qnn import get_qnn_qdq_config -from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType +from onnxruntime.quantization.quant_utils import compute_scale_zp, get_qmin_qmax_for_qType, ms_domain class DummyDataReader(quantization.CalibrationDataReader): @@ -423,6 +423,36 @@ def test_qdq_overrides_per_channel2(self): self.assertEqual(zp, expected_zp) self.assertEqual(scale, np.float32(expected_scale)) + def test_16bit_overrides_set_ms_domain(self): + """ + Test that overriding a tensor to 16bit (when default is 8bit) automatically sets the 'com.microsoft' + domain on DQ and Q ops. 
+ """ + qdq_model_name = "model_quant_overrides_to_16bit.onnx" + inp_zp, _, sig_out_zp, _, _, _, _, _, out_zp, _ = self.perform_qdq_quantization( + qdq_model_name, + activation_type=onnx.TensorProto.UINT8, # Default to 8bit activations + extra_options={ + "TensorQuantOverrides": { + "INP": [{"quant_type": quantization.QuantType.QUInt16}], + "SIG_OUT": [{"quant_type": quantization.QuantType.QUInt16}], + } + }, + ) + + # Input and Sigmoid's output should be overridden to 16bit + self.assertEqual(inp_zp.data_type, onnx.TensorProto.UINT16) + self.assertEqual(sig_out_zp.data_type, onnx.TensorProto.UINT16) + + # Output should the default uint8 type + self.assertEqual(out_zp.data_type, onnx.TensorProto.UINT8) + + # Q/DQ ops should all have the 'com.microsoft' domain + qdq_model = onnx.load_model(qdq_model_name) + for node in qdq_model.graph.node: + if node.op_type in {"QuantizeLinear", "DequantizeLinear"}: + self.assertEqual(node.domain, ms_domain) + def test_override_validation_nonexisting_tensor(self): """ Test that specifying a non-existing tensor should fail. From c311d1faf50167e38613927e44c8a430ffcc8e89 Mon Sep 17 00:00:00 2001 From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com> Date: Thu, 29 Feb 2024 17:51:29 +0800 Subject: [PATCH 172/207] [ROCm] Update dockerfile (#19661) Update dockerfile to ROCm6.0 --- dockerfiles/Dockerfile.migraphx | 43 +++------------------------------ dockerfiles/Dockerfile.rocm | 4 +-- dockerfiles/README.md | 4 +-- 3 files changed, 8 insertions(+), 43 deletions(-) diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index bc513a8e8ba6d..c3541a8bd3425 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -5,57 +5,22 @@ # Dockerfile to run ONNXRuntime with MIGraphX integration #-------------------------------------------------------------------------- -FROM ubuntu:20.04 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main -ARG ROCM_VERSION=5.4 -# MIGraphX version should be the same as ROCm version -ARG MIGRAPHX_VERSION=rocm-5.4.0 -ENV DEBIAN_FRONTEND noninteractive -ENV MIGRAPHX_DISABLE_FAST_GELU=1 -RUN apt-get clean && apt-get update && apt-get install -y locales -RUN locale-gen en_US.UTF-8 -RUN update-locale LANG=en_US.UTF-8 -ENV LC_ALL C.UTF-8 -ENV LANG C.UTF-8 +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} -# Install rocm -RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ - curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ - sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/${ROCM_VERSION}/ ubuntu main > /etc/apt/sources.list.d/rocm.list' - -RUN apt-get update &&\ - apt-get install -y sudo git bash build-essential rocm-dev python3-dev python3-pip miopen-hip \ - rocblas half aria2 libnuma-dev pkg-config - -RUN aria2c -q -d /tmp -o cmake-3.27.3-linux-x86_64.tar.gz \ -https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-x86_64.tar.gz &&\ -tar -zxf /tmp/cmake-3.27.3-linux-x86_64.tar.gz --strip=1 -C /usr - -# Install rbuild -RUN pip3 install https://github.com/RadeonOpenCompute/rbuild/archive/master.tar.gz numpy yapf==0.28.0 - -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} - -# Install MIGraphX from source -RUN mkdir -p /migraphx -RUN cd /migraphx && git clone --depth=1 --branch ${MIGRAPHX_VERSION} https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src -RUN cd /migraphx && rbuild 
package --cxx /opt/rocm/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 -RUN dpkg -i /migraphx/build/*.deb -RUN rm -rf /migraphx - -# Install rocm ep dependencies RUN apt-get update &&\ - apt-get install -y rocrand rccl hipsparse hipfft hipcub hipblas rocthrust + apt-get install -y migraphx WORKDIR /code # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ /bin/sh onnxruntime/dockerfiles/scripts/install_common_deps.sh &&\ - cd onnxruntime &&\ + cd onnxruntime && pip install --upgrade pip &&\ /bin/sh ./build.sh --allow_running_as_root --cmake_extra_defines ONNXRUNTIME_VERSION=`cat ./VERSION_NUMBER` --config Release --parallel \ --skip_tests --build_wheel --use_rocm --rocm_version=${ROCM_VERSION} --rocm_home /opt/rocm --use_migraphx &&\ pip install /code/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index 35a676383337b..c242933f677f0 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -5,14 +5,14 @@ # Dockerfile to run ONNXRuntime with ROCm integration #-------------------------------------------------------------------------- -FROM rocm/pytorch:rocm5.4_ubuntu20.04_py3.7_pytorch_1.12.1 +FROM rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime ARG ONNXRUNTIME_BRANCH=main WORKDIR /code -ENV PATH /opt/miniconda/bin:/code/cmake-3.27.3-linux-x86_64/bin:${PATH} +ENV PATH /code/cmake-3.27.3-linux-x86_64/bin:${PATH} # Prepare onnxruntime repository & build onnxruntime RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXRUNTIME_REPO} onnxruntime &&\ diff --git a/dockerfiles/README.md b/dockerfiles/README.md index f226ebfe8b193..a2e99d66d4654 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -277,7 +277,7 @@ Nothing else from ONNX Runtime source tree will be copied/installed to the image Note: When running the container you built in Docker, please either use 'nvidia-docker' command instead of 'docker', or use Docker command-line options to make sure NVIDIA runtime will be used and appropiate files mounted from host. Otherwise, CUDA libraries won't be found. You can also [set NVIDIA runtime as default in Docker](https://github.com/dusty-nv/jetson-containers#docker-default-runtime). ## MIGraphX -**Ubuntu 20.04, ROCm5.4, AMDMIGraphX v1.2** +**Ubuntu 20.04, ROCm6.0, MIGraphX** 1. Build the docker image from the Dockerfile in this repository. ``` @@ -291,7 +291,7 @@ Note: When running the container you built in Docker, please either use 'nvidia- ``` ## ROCm -**Ubuntu 20.04, ROCm5.4** +**Ubuntu 20.04, ROCm6.0** 1. Build the docker image from the Dockerfile in this repository. ``` From 937cdd651e4f656e65053d027c71b51f1e1411ec Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Thu, 29 Feb 2024 23:03:57 +0800 Subject: [PATCH 173/207] [ORTMODULE] Support Register Custom Triton Kernel (#19690) Add support for registering custom Triton kernel function. 
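For illustration only, a minimal sketch of how the new hook might be used from user code: a wrapper decorated with `register_triton_kernel` is stored under its function name and can then be invoked through `call_triton_by_name`. The import path, kernel body, and launch parameters below are assumptions for the example, not part of this change, and presume a working CUDA + Triton environment.

```python
import torch
import triton
import triton.language as tl

# Import path may differ depending on how onnxruntime-training is installed.
from onnxruntime.training.ort_triton import register_triton_kernel


@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Plain element-wise add, used only to keep the example small.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


@register_triton_kernel
def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # call_triton_by_name passes torch tensors (converted from DLPack) and
    # expects a tensor, or a tuple of tensors, as the return value.
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = (triton.cdiv(n_elements, 1024),)
    _add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out
```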
--- .../python/training/ort_triton/__init__.py | 1 + .../python/training/ort_triton/triton_op_executor.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/orttraining/orttraining/python/training/ort_triton/__init__.py b/orttraining/orttraining/python/training/ort_triton/__init__.py index fbb59d1354ae7..5f2d0c62ffa50 100644 --- a/orttraining/orttraining/python/training/ort_triton/__init__.py +++ b/orttraining/orttraining/python/training/ort_triton/__init__.py @@ -9,6 +9,7 @@ from onnxruntime.capi import _pybind_state as _C from .kernel import * # noqa: F403 +from .triton_op_executor import register_triton_kernel # noqa: F401 from .triton_op_executor import call_triton_by_name, call_triton_by_onnx, get_config diff --git a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py index f16abc71251ed..e104ea13c59a3 100644 --- a/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py +++ b/orttraining/orttraining/python/training/ort_triton/triton_op_executor.py @@ -23,6 +23,8 @@ _DEBUG_MODE = "ORTMODULE_TRITON_DEBUG" in os.environ and int(os.getenv("ORTMODULE_TRITON_DEBUG")) == 1 +_CUSTOM_KERNELS = dict() + @functools.lru_cache(None) def _gen_module_internal(sorted_graph: SortedGraph) -> Tuple[str, str, ModuleType]: @@ -113,7 +115,10 @@ def call_triton_by_name(func_name: str, *tensors, **kwargs): """ torch_tensors = [_from_dlpack(tensor) if tensor is not None else None for tensor in tensors] - func = getattr(sys.modules[".".join(__name__.split(".")[:-1])], func_name) + func = getattr(sys.modules[".".join(__name__.split(".")[:-1])], func_name, None) + if func is None: + func = _CUSTOM_KERNELS.get(func_name) + assert func is not None, f"Function {func_name} is not found in the registered kernels." output = func(*torch_tensors, **kwargs) if output is not None: if isinstance(output, tuple): @@ -138,3 +143,8 @@ def call_triton_by_onnx(onnx_key: int, onnx_str: bytes, *tensors): if isinstance(output, tuple): return tuple([to_dlpack(tensor) for tensor in output]) return to_dlpack(output) + + +def register_triton_kernel(fn): + _CUSTOM_KERNELS[fn.__name__] = fn + return fn From ec0e4d3b6572c18a3462eb6efb3bb007ec3a2962 Mon Sep 17 00:00:00 2001 From: Yi-Hong Lyu Date: Thu, 29 Feb 2024 10:31:57 -0800 Subject: [PATCH 174/207] Parallel Transpose_BSNH_to_BNSH (#19406) Achieved a speedup of 1.098 in MultiHeadAttention and an end-to-end speedup of 1.021 in the OCR model through parallelization of the Transpose_BSNH_to_BNSH operation. 
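For context, the operation being parallelized only changes the memory layout of the packed Q/K/V tensor from [batch, sequence, num_heads, head_size] to [batch, num_heads, sequence, head_size]. A NumPy reference of the same permutation, with illustrative shapes (a semantic sketch, not the C++ implementation):

```python
import numpy as np

# Example dimensions only: batch, sequence length, attention heads, head size.
batch, seq_len, num_heads, head_size = 2, 128, 12, 64
qkv_bsnh = np.random.rand(batch, seq_len, num_heads, head_size).astype(np.float32)

# Same axis permutation the kernel applies: {0, 2, 1, 3} (BSNH -> BNSH).
qkv_bnsh = np.transpose(qkv_bsnh, (0, 2, 1, 3))
assert qkv_bnsh.shape == (batch, num_heads, seq_len, head_size)
```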
--- onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index eb25d0fd7cc1e..c4e4b4ec707fb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -58,11 +58,12 @@ Status Reshape_BSD_to_BSNH(Tensor* qkv, // Transpose Q/K/V from BxSxNxH to BxNxSxH Status Transpose_BSNH_to_BNSH(const Tensor* qkv, - OrtValue& qkv_transposed) { + OrtValue& qkv_transposed, + concurrency::ThreadPool* tp = nullptr) { std::vector permutations({0, 2, 1, 3}); gsl::span permutations_span{permutations}; size_t from = 2, to = 1; - SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to); + SingleAxisTranspose(permutations_span, *qkv, *qkv_transposed.GetMutable(), from, to, nullptr, tp); return Status::OK(); } @@ -143,7 +144,8 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat ORT_RETURN_IF_ERROR(Reshape_BSD_to_BSNH(qkv_with_bias.GetMutable(), batch_size, sequence_length, num_heads, head_size)); // Transpose Q from BxSxNxH to BxNxSxH - ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed)); + auto tp = context->GetOperatorThreadPool(); + ORT_RETURN_IF_ERROR(Transpose_BSNH_to_BNSH(qkv_with_bias.GetMutable(), qkv_with_bias_transposed, tp)); return Status::OK(); } From d5606cd7ee394ba9444ef509021720ebe63c9856 Mon Sep 17 00:00:00 2001 From: Adam Louly Date: Thu, 29 Feb 2024 13:40:56 -0800 Subject: [PATCH 175/207] Introducing customizable input names for loss in generate_artifacts. (#19705) # loss function extra inputs. Currently, the loss functions in onnxblock expect exactly two inputs in their build method. Occasionally, models may pass additional inputs, causing the build function to fail. To solve this issue, we can let users pass a list of loss input names to be used in the loss function. --- .../orttraining/python/training/artifacts.py | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index 7a4eb251bc5bc..4e76174d8255e 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -48,6 +48,7 @@ def generate_artifacts( custom_op_library: Optional[Union[str, bytes, os.PathLike]] = None, additional_output_names: Optional[List[str]] = None, nominal_checkpoint: bool = False, + loss_input_names: Optional[List[str]] = None, ) -> None: """Generates artifacts required for training with ORT training api. @@ -77,7 +78,9 @@ def generate_artifacts( Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model parameters. It can be used on the device to reduce overhead while constructing the training model as well as to reduce the size of the checkpoint packaged with the on-device application. - + loss_input_names: Specifies a list of input names to be used specifically for the loss computation. When provided, + only these inputs will be passed to the loss function. If `None`, all graph outputs are passed to + the loss function. Raises: RuntimeError: If the loss provided is neither one of the supported losses nor an instance of `onnxblock.Block` RuntimeError: If the optimizer provided is not one of the supported optimizers. 
@@ -111,11 +114,16 @@ def generate_artifacts( logging.info("Custom loss block provided: %s", loss.__class__.__name__) class _TrainingBlock(onnxblock.TrainingBlock): - def __init__(self, _loss): + def __init__(self, _loss, _loss_input_names=None): super().__init__() self._loss = _loss + self._loss_input_names = _loss_input_names def build(self, *inputs_to_loss): + # If loss_input_names is passed, only pass the specified input names to the loss function. + if self._loss_input_names: + inputs_to_loss = self._loss_input_names + if additional_output_names: # If additional output names is not a list, raise an error if not isinstance(additional_output_names, list): @@ -132,7 +140,7 @@ def build(self, *inputs_to_loss): return self._loss(*inputs_to_loss) - training_block = _TrainingBlock(loss_block) + training_block = _TrainingBlock(loss_block, loss_input_names) if requires_grad is not None and frozen_params is not None and set(requires_grad).intersection(set(frozen_params)): raise RuntimeError( @@ -157,9 +165,11 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(model), onnxblock.custom_op_library( - custom_op_library_path - ) if custom_op_library is not None else contextlib.nullcontext(): + with onnxblock.base(model), ( + onnxblock.custom_op_library(custom_op_library_path) + if custom_op_library is not None + else contextlib.nullcontext() + ): _ = training_block(*[output.name for output in model.graph.output]) training_model, eval_model = training_block.to_model_proto() model_params = training_block.parameters() From 5ee62a6bcc228e63704f64f2de46d61d2c57a281 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 29 Feb 2024 14:46:42 -0800 Subject: [PATCH 176/207] CUDA Resize-18 implementation (#19595) ### Description Implement Resize-18 on CUDA. ### Motivation and Context Performance --- docs/OperatorKernels.md | 3 +- .../providers/cpu/cpu_execution_provider.cc | 6 +- .../core/providers/cpu/cpu_provider_shared.cc | 8 + .../core/providers/cpu/cpu_provider_shared.h | 5 + .../core/providers/cpu/tensor/upsample.cc | 79 +- .../core/providers/cpu/tensor/upsample.h | 14 +- .../providers/cpu/tensor/upsample_antialias.h | 95 +- .../core/providers/cpu/tensor/upsamplebase.h | 191 ++- .../core/providers/cuda/cu_inc/common.cuh | 12 +- .../providers/cuda/cuda_execution_provider.cc | 30 +- .../core/providers/cuda/tensor/resize.cc | 14 +- .../cuda/tensor/resize_antialias_impl.cu | 1179 +++++++++++++++++ .../core/providers/cuda/tensor/resize_impl.cu | 254 ++-- .../core/providers/cuda/tensor/resize_impl.h | 111 ++ .../core/providers/cuda/tensor/upsample.cc | 254 +++- .../core/providers/cuda/tensor/upsample.h | 10 +- .../providers/rocm/rocm_execution_provider.cc | 40 +- .../provider_bridge_provider.cc | 7 +- .../core/providers/xnnpack/tensor/resize.cc | 2 +- .../providers/cpu/tensor/resize_op_test.cc | 171 ++- 20 files changed, 2090 insertions(+), 395 deletions(-) create mode 100644 onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index b0ed68d595c42..1eaf0fb6dad76 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -734,7 +734,8 @@ Do not modify directly.* |||13|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[5, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**shape** = tensor(int64)| |||[1, 4]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|18+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| +|||[13, 17]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||[11, 12]|**T1** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |||10|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(uint8)| |ReverseSequence|*in* input:**T**
*in* sequence_lens:**tensor(int64)**
*out* Y:**T**|10+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 48e4617b33b4d..37e7e42150413 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -2008,8 +2008,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Greater)>, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo namespace onnxruntime { // The suppressed warning is: "The type with a virtual function needs either public virtual or protected nonvirtual destructor." @@ -292,6 +294,12 @@ struct ProviderHostCPUImpl : ProviderHostCPU { Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const override { + p->AdjustOutputSizeAsPolicy(output_dims, input_dims, scales); + } + #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index f33eec4b93e98..c0e674827e4d1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -24,6 +24,7 @@ class SliceOp__PrepareForComputeMetadata; // Directly maps to SliceOp::PrepareF class UnsqueezeBase__Prepare; // Directly maps to UnsqueezeBase::Prepare class contrib__AdamWOptimizerBase__Prepare; class contrib__SGDOptimizerV2Base__Prepare; +class UpsampleBase; using PadsVector = InlinedVector; @@ -202,6 +203,10 @@ struct ProviderHostCPU { virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual void UpsampleBase__AdjustOutputSizeAsPolicy(const UpsampleBase* p, TensorShapeVector& output_dims, + gsl::span input_dims, + InlinedVector& scales) const = 0; + #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.cc b/onnxruntime/core/providers/cpu/tensor/upsample.cc index fa69e144be554..babbac0b7be17 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/cpu/tensor/upsample.cc @@ -1,10 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include "core/providers/cpu/tensor/upsample.h" + +#include + +#include "core/common/inlined_containers.h" #include "core/common/safeint.h" #include "core/platform/threadpool.h" -#include "core/providers/cpu/tensor/upsample.h" #include "core/providers/cpu/tensor/upsample_antialias.h" + using namespace onnxruntime::common; using namespace std; using onnxruntime::narrow; @@ -30,6 +35,46 @@ REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int8_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + // AspectRatioPolicy::STRETCH is default policy when opset < 18 + if (keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH) { + return; + } + + InlinedHashSet axes_set(axes_.begin(), axes_.end()); + + float scale_in_policy = 0.0f; + if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { + scale_in_policy = std::numeric_limits::max(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::min(scale_in_policy, scales[i]); + } + } + } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { + scale_in_policy = std::numeric_limits::min(); + + for (size_t i = 0; i < scales.size(); i++) { + if (axes_set.empty() || axes_set.count(i) > 0) { + scale_in_policy = std::max(scale_in_policy, scales[i]); + } + } + } + + for (size_t i = 0; i < scales.size(); i++) { + // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes + if (axes_set.empty() || axes_set.count(i) > 0) { + scales[i] = scale_in_policy; + output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); + } else { + scales[i] = 1.0f; + output_dims[i] = input_dims[i]; + } + } +} + template void UpsampleNearest2x(int64_t batch_size, int64_t num_channels, @@ -94,8 +139,8 @@ UpsampleNearestSetupInputMappings(int64_t n_dim, const TensorShape& input_shape, const TensorShape& output_shape, const std::vector& input_dim_factor, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const GetOriginalCoordinateFunc& get_original_coordinate, const GetNearestPixelFunc& get_nearest_pixel) { @@ -141,8 +186,8 @@ static Status UpsampleNearestImpl(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool extrapolation_enabled, const T extrapolation_value, const GetOriginalCoordinateFunc& get_original_coordinate, @@ -285,8 +330,8 @@ static Status UpsampleNearest(const T* input, T* output, const TensorShape& input_shape, const TensorShape& output_shape, - const vector& scales, - const vector& roi, + gsl::span scales, + gsl::span roi, bool is_resize, bool extrapolation_enabled, T extrapolation_value, @@ -412,7 +457,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw) { @@ -518,7 +563,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const 
bool is_nchw) { @@ -650,7 +695,7 @@ static TrilinearParams SetupUpsampleTrilinear(int64_t input_depth, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate) { TrilinearParams p; @@ -796,7 +841,7 @@ void UpsampleTrilinear(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, const T* XdataBase, @@ -929,7 +974,7 @@ void ResizeBiCubic(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const T* Xdata, T* Ydata, const GetOriginalCoordinateFunc& get_original_coordinate) { @@ -1067,9 +1112,9 @@ void ResizeBiCubic(int64_t batch_size, template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const auto* X = context->Input(0); auto dims = X->Shape().GetDims(); ORT_RETURN_IF_NOT(output_dims.size() == dims.size(), "Rank of input and output tensor should be same."); @@ -1327,7 +1372,7 @@ Status Upsample::Compute(OpKernelContext* context) const { // Initialize the roi array to all zeros as this will be the most common case // Roi data is needed only when coordinate transformation mode is set to tf_crop_and_resize // for all other cases we need a 0 initialized roi array - std::vector roi_array(roi_); + InlinedVector roi_array(roi_); if (!roi_cached_) { bool use_default_roi = true; @@ -1353,7 +1398,7 @@ Status Upsample::Compute(OpKernelContext* context) const { ComputeROIWithAxes(roi_array, input_dims.size()); // Get scales data - std::vector scales_array(input_dims.size()); + InlinedVector scales_array(input_dims.size()); if (OpKernel::Node().InputDefs().size() == 1) { // Compute output shape from scales and input dims diff --git a/onnxruntime/core/providers/cpu/tensor/upsample.h b/onnxruntime/core/providers/cpu/tensor/upsample.h index 3046ee4b8260d..8ff04781f6ad0 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample.h @@ -66,8 +66,8 @@ class Upsample : public UpsampleBase, public OpKernel { Status Compute(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; }; BilinearParams SetupUpsampleBilinear(const int32_t input_height, @@ -76,7 +76,7 @@ BilinearParams SetupUpsampleBilinear(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -90,7 +90,7 @@ void UpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, const T* const XdataBase, @@ -144,7 +144,7 @@ void NhwcUpsampleBilinear(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, 
const T* const XdataBase, T* const YdataBase, @@ -227,7 +227,7 @@ BilinearParamsInteger SetupUpsampleBilinearInteger(const int32_t input_height, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, const bool is_nchw); @@ -241,7 +241,7 @@ void NhwcUpsampleBilinearInteger(const int32_t batch_size, const int32_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const float extrapolation_value, const T* const XdataBase, T* const YdataBase, diff --git a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h index e1dcaf500a325..1e32b7e874b1a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h +++ b/onnxruntime/core/providers/cpu/tensor/upsample_antialias.h @@ -21,32 +21,6 @@ namespace onnxruntime { -namespace ConstValue { -constexpr int32_t mag_factor = 1 << (22 - 1); -} - -namespace { -const uint8_t* GetLookupTableShared() { - // initialized once - static const auto* lookup_table = []() { - // if we have already initialized the lookup table, just return - // ideally we could have a global lookup table, but that account for too much space. - /* Handles values form -640 to 639. */ - static uint8_t table[1280] = {0}; - - // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 - // we need to handle negative values - // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] - // we will accept a negative x for (&table[640])[x] means table +640 -x - for (int i = 0; i < 1280; ++i) { - table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); - } - return table; - }(); - return lookup_table; -} -} // namespace - template struct FilterParamsBaseAntiAlias { std::vector bound; @@ -57,15 +31,15 @@ struct FilterParamsBaseAntiAlias { template struct FilterParamsAntiAlias { - float support_size = 2.0f; - float cubic_coeff_a = -0.75f; + float support_size = antialias_constants::kSupportSize; + float cubic_coeff_a = antialias_constants::kCubicCoeffA; FilterParamsBaseAntiAlias dim_x; FilterParamsBaseAntiAlias dim_y; FilterParamsBaseAntiAlias dim_z; const uint8_t* GetClip8LookupTable() const { - return GetLookupTableShared(); + return UpsampleBase::GetLookupTableShared(); } virtual ~FilterParamsAntiAlias() = default; virtual float Filter(float x) const = 0; @@ -89,7 +63,7 @@ struct BilinearParamsAntiAlias : FilterParamsAntiAlias { template struct BiCubicParamsAntiAlias : FilterParamsAntiAlias { BiCubicParamsAntiAlias() { - this->support_size = 4.0f; + this->support_size = antialias_constants::kBiCubicSupportSize; } // taken from @@ -124,27 +98,6 @@ struct TriLinearParamsAntiAlias : FilterParamsAntiAlias { } }; -template -struct AccumulateType { - using type = int32_t; - using Dtype = T; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = float; -}; - -template <> -struct AccumulateType { - using type = double; -}; - // The following method supports a 3/4/5-D input in 'Linear mode, cubic mode' // that amounts to 'Bilinear,TriLinear, Bicubic/Tricubic' Upsampling/Resizing in the sense that it assumes // A N-D tensor has @@ -156,19 +109,20 @@ struct AccumulateType { // - [N, H, W, C] and the scales are [1.0, height_scale, width_scale, 1.0] template void 
SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, - const gsl::span input_h_w_c, - const gsl::span output_h_w_c, - const gsl::span scale_h_w_c, - const std::vector& roi, + gsl::span input_h_w_c, + gsl::span output_h_w_c, + gsl::span scale_h_w_c, + gsl::span roi, AllocatorPtr& alloc, const GetOriginalCoordinateFunc& get_original_coordinate, bool exclude_outside, const bool is_nchw) { - auto compute_weight_coefficients = [&alloc, &roi, &get_original_coordinate, exclude_outside](const FilterParamsAntiAlias& p, - const int64_t input_size, - const int64_t output_size, - size_t rindex, - FilterParamsBaseAntiAlias& param_base, - const float rscale) -> int64_t { + auto compute_weight_coefficients = [&alloc, roi, &get_original_coordinate, exclude_outside]( + const FilterParamsAntiAlias& p, + const int64_t input_size, + const int64_t output_size, + size_t rindex, + FilterParamsBaseAntiAlias& param_base, + const float rscale) -> int64_t { param_base.bound.reserve(static_cast(output_size) * 2); param_base.out_of_bound_idx.reserve(static_cast(output_size)); @@ -245,13 +199,14 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, // normalize the scale to 1 << 22 for int8/uint8 if constexpr (std::is_same::value) { - scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor * 2.f)); + scale_buffer_int[x] = static_cast(std::round(scale_buffer[x] * ConstValue::mag_factor_x_2)); } } /*for (; x < window_size; x++) { scale_buffer[x] = 0; }*/ } + return window_size; }; @@ -269,9 +224,6 @@ void SetupUpsampleFilterAntiAlias(FilterParamsAntiAlias& p, } } -template -inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; - /** * @brief To compute interpolation along with the last axis. * For brief,we assume the input tensor has 3 dimensions and we all it CHW for each character represent a dim. @@ -398,6 +350,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -444,6 +397,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in output += *Xdata_offset * (*weight_coeff_start++); Xdata_offset += output_width; } + if constexpr (is_8bit_v) { *Ydata_offset++ = static_cast(clip8_lookups[output >> 22]); } else if constexpr (std::is_same::value) { @@ -515,6 +469,7 @@ void UpsampleBaseAntiAlias(FilterParamsAntiAlias& p, narrow(input_height * num_channels * input_width)); auto ydata_span = gsl::make_span(image_temp_buffer.get(), narrow(input_height * num_channels * output_width)); + // This computes only the width direction.Thus height keeps unchanged. 
ComputeInterpolationAtLevel1(num_channels, input_height, input_width, input_height, output_width, xdata_span, ydata_span, p, p.dim_x, tp); } @@ -546,7 +501,7 @@ void UpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -575,7 +530,7 @@ void NhwcUpsampleBilinearAntiAlias(const int64_t batch_size, const int64_t output_width, const float height_scale, const float width_scale, - const std::vector& roi, + gsl::span roi, const bool use_extrapolation, const float extrapolation_value, bool exclude_outside, @@ -608,7 +563,7 @@ void NhwcResizeBiCubicAntiAlias(const int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -688,7 +643,7 @@ void ResizeBiCubicAntiAlias(int64_t batch_size, bool use_extrapolation, float extrapolation_value, bool exclude_outside, - const std::vector& roi, + gsl::span roi, const Tensor* X, T* Ydata_base, AllocatorPtr& alloc, @@ -719,7 +674,7 @@ void UpsampleTrilinearAntiAlias(int64_t batch_size, float depth_scale, float height_scale, float width_scale, - const std::vector& roi, + gsl::span roi, bool use_extrapolation, float extrapolation_value, bool exclude_outside, diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index a0e7ca1084fef..b768fedd8513a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -3,11 +3,13 @@ #pragma once +#include #include #include #include #include -#include + +#include #include "core/common/status.h" #include #include @@ -58,7 +60,73 @@ enum class AspectRatioPolicy { NOT_SMALLER, }; +// Antialias types +template +struct AccumulateType { + using type = int32_t; + using Dtype = T; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = float; +}; + +template <> +struct AccumulateType { + using type = double; +}; + +namespace antialias_constants { +constexpr float kCubicCoeffA = -0.75f; +constexpr float kSupportSize = 2.0f; +constexpr float kBiCubicSupportSize = 4.0f; +} // namespace antialias_constants + +namespace ConstValue { +constexpr int32_t mag_factor = 1 << (22 - 1); +// We use to multiply by 2, let's make a constant which is twice as big +constexpr int32_t mag_factor_x_2 = 1 << 22; +} // namespace ConstValue + +template +inline constexpr bool is_8bit_v = std::is_same::value || std::is_same::value; + +template +void PrintAntiAliasBuffers(std::ostream& os, gsl::span bounds, gsl::span out_of_bounds, + gsl::span weight_coefficients) { + os << "#### Bounds: "; + std::copy(bounds.begin(), bounds.end(), std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Out of Bounds: "; + std::copy(out_of_bounds.begin(), out_of_bounds.end(), + std::ostream_iterator(os, " ")); + os << std::endl; + + os << "#### Scale Buffer: "; + std::copy(weight_coefficients.begin(), weight_coefficients.end(), + std::ostream_iterator(os, " ")); + os << std::endl; +} + class UpsampleBase { + public: + // Make this available in other EP via provider bridge + // it works iff output_shape is specified + void AdjustOutputSizeAsPolicy(TensorShapeVector& 
output_dims, gsl::span input_dims, + InlinedVector& scales) const; + protected: explicit UpsampleBase(const OpKernelInfo& info) : scales_cached_(false), roi_cached_(false), use_extrapolation_(false) { @@ -69,23 +137,32 @@ class UpsampleBase { std::string mode; ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); mode_ = StringToUpsampleMode(mode); - antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; - if (antialias_) { - ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), - "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); - } auto input_count = info.GetInputCount(); if (input_count == 1) { // opset < 10 - ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales_)); - ORT_THROW_IF_ERROR(ScalesValidation(scales_, mode_)); + std::vector scales; + ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales)); + ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_)); + scales_.assign(scales.cbegin(), scales.cend()); scales_cached_ = true; } - std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); - keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + if (opset >= 18) { + antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; + + if (antialias_) { + ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), + "when anti-aliasing is set, Resize only supports mode `LINEAR` and `CUBIC`."); + } - axes_ = info.GetAttrsOrDefault("axes"); + // The attribute is absent in opset < 18, but the default value as if stretch. + std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); + keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); + + // guard against unit tests that can add an attribute + auto axes = info.GetAttrsOrDefault("axes"); + axes_.assign(axes.cbegin(), axes.cend()); + } extrapolation_value_ = info.GetAttrOrDefault("extrapolation_value", 0.0f); @@ -112,7 +189,7 @@ class UpsampleBase { nearest_mode_ = StringToNearestMode(nearest_mode_name); get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_); - cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", -0.75f); + cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA); exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? 
false : true; if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) { @@ -166,7 +243,7 @@ class UpsampleBase { ResizeCoordinateTransformationMode coordinate_transform_mode_; GetOriginalCoordinateFunc get_original_coordinate_; ResizeNearestMode nearest_mode_; - AspectRatioPolicy keep_aspect_ratio_policy_; + AspectRatioPolicy keep_aspect_ratio_policy_{AspectRatioPolicy::STRETCH}; GetNearestPixelFunc get_nearest_pixel_; float cubic_coeff_a_; bool exclude_outside_; @@ -174,9 +251,9 @@ class UpsampleBase { float extrapolation_value_; bool use_nearest2x_optimization_ = false; - std::vector scales_; - std::vector roi_; - std::vector axes_; + InlinedVector scales_; + InlinedVector roi_; + TensorShapeVector axes_; bool scales_cached_; bool roi_cached_; @@ -335,7 +412,7 @@ class UpsampleBase { } } - [[nodiscard]] Status ScalesValidation(const std::vector& scales, const UpsampleMode mode) const { + [[nodiscard]] Status ScalesValidation(gsl::span scales, const UpsampleMode mode) const { if (!is_resize_) { for (auto& scale : scales) { ORT_RETURN_IF_NOT(scale >= 1, "Scale value should be greater than or equal to 1."); @@ -372,7 +449,7 @@ class UpsampleBase { } [[nodiscard]] Status - ParseScalesData(const Tensor* scale, std::vector& scales, int64_t rank) const { + ParseScalesData(const Tensor* scale, InlinedVector& scales, int64_t rank) const { const auto* scale_data = scale->Data(); int64_t scales_size = scale->Shape().Size(); ORT_RETURN_IF_NOT(scales_size > 0, "scales size should be greater than 0."); @@ -387,19 +464,19 @@ class UpsampleBase { // in which case the other axes is ignored and use default scale of 1 // scales_size == axes_.size() should be guaranteed if axes is not empty if (rank > 0 && (scales_size != rank || axes_.size())) { - std::vector new_scales(size_t(rank), 1.0f); + InlinedVector new_scales(size_t(rank), 1.0f); ORT_RETURN_IF_NOT(*std::max_element(axes_.begin(), axes_.end()) < rank && (int64_t(axes_.size()) == scales_size), "all values in axes should be less than rank of the data"); for (size_t i = 0; i < axes_.size(); i++) { new_scales[static_cast(axes_[i])] = scales[i]; } - scales = new_scales; + scales.swap(new_scales); } return ScalesValidation(scales, mode_); } - void ParseRoiData(const Tensor* roi, std::vector& roi_array) const { + void ParseRoiData(const Tensor* roi, InlinedVector& roi_array) const { int64_t roi_size = roi->Shape().Size(); if (roi_size > 0) { roi_array.resize(onnxruntime::narrow(roi_size)); @@ -429,52 +506,11 @@ class UpsampleBase { return Status::OK(); } - // it works iff output_shape is specified - void AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { - std::unordered_set axes_set(axes_.begin(), axes_.end()); - - // AspectRatioPolicy::STRETCH is default policy when opset < 18 - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::STRETCH) { - return; - } - - float scale_in_policy = 0.0f; - if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_LARGER) { - scale_in_policy = std::numeric_limits::max(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::min(scale_in_policy, scales[i]); - } - } - } else if (keep_aspect_ratio_policy_ == AspectRatioPolicy ::NOT_SMALLER) { - scale_in_policy = std::numeric_limits::min(); - - for (size_t i = 0; i < scales.size(); i++) { - if (axes_set.empty() || axes_set.count(i) > 0) { - scale_in_policy = std::max(scale_in_policy, scales[i]); - } - } - } - - 
for (size_t i = 0; i < scales.size(); i++) { - // if axes is not specified (AKA axes_set.empty()), we apply the policy to all axes - if (axes_set.empty() || axes_set.count(i) > 0) { - scales[i] = scale_in_policy; - output_dims[i] = static_cast(std::round(scales[i] * input_dims[i])); - } else { - scales[i] = 1.0f; - output_dims[i] = input_dims[i]; - } - } - } - // It's different in Opset 18 and before. // we will modify output_shape by sorts of policy even if it's specified [[nodiscard]] Status ParseScalesDataAndAdjustOutputSize(TensorShapeVector& output_dims, gsl::span input_dims, - std::vector& scales) const { + InlinedVector& scales) const { for (size_t i = 0, end = input_dims.size(); i < end; ++i) { // Handle corner case to avoid dividing by zero in the next step if (input_dims[i] == 0) { @@ -507,9 +543,9 @@ class UpsampleBase { // Roi is redefined in Opset-18, we have a concept of axes. // So we need to update it accordingly. - void ComputeROIWithAxes(std::vector& roi_array, size_t rank) const { + void ComputeROIWithAxes(InlinedVector& roi_array, size_t rank) const { if (axes_.size()) { - std::vector roi_tmp(rank * 2, 0); + InlinedVector roi_tmp(rank * 2, 0); for (size_t i = rank; i < rank * 2; ++i) { roi_tmp[i] = 1; } @@ -518,9 +554,32 @@ class UpsampleBase { roi_tmp[v_in_axes] = (roi_array[i]); roi_tmp[rank + v_in_axes] = (roi_array[axes_.size() + i]); } - roi_array = roi_tmp; + roi_array.swap(roi_tmp); } } + + public: + static constexpr size_t kLookupTableSize = 1280; + + static const uint8_t* GetLookupTableShared() { + // initialized once + static const auto* lookup_table = []() { + // if we have already initialized the lookup table, just return + // ideally we could have a global lookup table, but that account for too much space. + /* Handles values form -640 to 639. 
*/ + static uint8_t table[kLookupTableSize] = {0}; + + // taken from https://github.com/python-pillow/Pillow/blob/66add095a50d76c35c7f58643461f2edf78a3f05/src/libImaging/Resample.c#L94 + // we need to handle negative values + // it's equivalent to :x = np.clip(x, 0, 255) where x \in [-640, 639] + // we will accept a negative x for (&table[640])[x] means table +640 -x + for (int i = 0; i < static_cast(kLookupTableSize); ++i) { + table[i] = static_cast(std::min(std::max(i - 640, 0), 255)); + } + return table; + }(); + return lookup_table; + } }; // UpsampleBase } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 0d9928baa86e0..66794f88d8670 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -194,13 +194,13 @@ template <> __device__ __inline__ half _Ceil(half a) { return half(ceilf((float)a)); } template -__device__ __inline__ T _Floor(T a); +__device__ __host__ __inline__ T _Floor(T a); template <> -__device__ __inline__ float _Floor(float a) { return floorf(a); } +__device__ __host__ __inline__ float _Floor(float a) { return floorf(a); } template <> -__device__ __inline__ double _Floor(double a) { return floor(a); } +__device__ __host__ __inline__ double _Floor(double a) { return floor(a); } template <> __device__ __inline__ half _Floor(half a) { return half(floorf((float)a)); } @@ -230,13 +230,13 @@ template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } template -__device__ __inline__ T _Round(T a); +__device__ __host__ __inline__ T _Round(T a); template <> -__device__ __inline__ float _Round(float a) { return rintf(a); } +__device__ __host__ __inline__ float _Round(float a) { return rintf(a); } template <> -__device__ __inline__ double _Round(double a) { return rint(a); } +__device__ __host__ __inline__ double _Round(double a) { return rint(a); } template <> __device__ __inline__ half _Round(half a) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 00783bcbc2665..1ce089fd93044 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1109,11 +1109,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceSumSquare); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1277,6 +1277,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, bool, Pad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); // Opset 19 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, float, Cast); @@ -2009,11 +2014,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2176,6 +2181,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // Opset 19 BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/tensor/resize.cc b/onnxruntime/core/providers/cuda/tensor/resize.cc index 764172a8d1fac..97d4eb71e970a 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize.cc +++ b/onnxruntime/core/providers/cuda/tensor/resize.cc @@ -28,10 +28,22 @@ namespace cuda { .InputMemoryType(OrtMemTypeCPUInput, 3) \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ Resize); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Resize, \ + kOnnxDomain, \ + 13, 17, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()), \ + Resize); \ ONNX_OPERATOR_TYPED_KERNEL_EX( \ Resize, \ kOnnxDomain, \ - 13, \ + 18, \ T, \ kCudaExecutionProvider, \ (*KernelDefBuilder::Create()) \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu new file mode 100644 index 
0000000000000..56b7c3f499303 --- /dev/null +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -0,0 +1,1179 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/tensor/resize_impl.h" + +#define FUNC_DEF __device__ + +namespace onnxruntime { +namespace cuda { + +using onnxruntime::ResizeCoordinateTransformationMode; +using onnxruntime::UpsampleMode; + +/// +/// Compute a buffer for bilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeBilinearScaleBufferSize( + int64_t output_height, int64_t output_width, + float height_rscale, float width_rscale, + float support_value, + float& scaled_support_height, float& scaled_support_width, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_height = ComputeScaledSupportValue(support_value, height_rscale); + scaled_support_width = ComputeScaledSupportValue(support_value, width_rscale); + window_size_height = ComputeWindowSize(scaled_support_height); + window_size_width = ComputeWindowSize(scaled_support_width); + + auto height_buffer_size = ComputeWeightedCoeffBufferSize(output_height, window_size_height); + auto width_buffer_size = ComputeWeightedCoeffBufferSize(output_width, window_size_width); + + return std::make_tuple(height_buffer_size, width_buffer_size); +} + +/// +/// Compute a buffer for btrilinear data for CUDA antialias resizing. +/// +static std::tuple ComputeTrilinearScaleBufferSize( + int64_t output_depth, int64_t output_height, int64_t output_width, + float depth_rscale, float height_rscale, float width_rscale, + float support_value, + float& scaled_support_depth, float& scaled_support_height, + float& scaled_support_width, int32_t& window_size_depth, + int32_t& window_size_height, int32_t& window_size_width) { + scaled_support_depth = ComputeScaledSupportValue(support_value, depth_rscale); + window_size_depth = ComputeWindowSize(scaled_support_depth); + auto depth_buffer_size = ComputeWeightedCoeffBufferSize(output_depth, window_size_depth); + + const auto [y_buffer_size, w_buffer_size] = ComputeBilinearScaleBufferSize(output_height, + output_width, height_rscale, + width_rscale, support_value, + scaled_support_height, + scaled_support_width, + window_size_height, window_size_width); + return std::make_tuple(depth_buffer_size, y_buffer_size, w_buffer_size); +} + +// Antialiasing filters +struct BilinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +struct BiCubicFilter { + __device__ __host__ float operator()(float x, float cubic_coeff_a) const { + /* https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm + */ + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return ((cubic_coeff_a + 2.0f) * x - (cubic_coeff_a + 3.0f)) * x * x + 1; + } + if (x < 2.0f) { + return (((x - 5.0f) * x + 8.f) * x - 4.f) * cubic_coeff_a; + } + return 0.0f; + } +}; + +struct TriLinearFilter { + __device__ __host__ float operator()(float x, float /* cubic_coeff_a */) const { + if (x < 0.0f) { + x = -x; + } + if (x < 1.0f) { + return 1.0f - x; + } + return 0.0f; + } +}; + +template +struct AccumTypeCaster { + static __device__ __host__ AccumType* cast(AccumType* p) { + return p; + } +}; + +template <> +struct AccumTypeCaster { + static __device__ __host__ float* cast(int32_t* p) { + return 
reinterpret_cast(p); + } +}; + +template +__global__ void _ComputeInterpolationAtLevel1( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, + const int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_width == input_width) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_y, output_x; + div_output_width.divmod(output_image_index, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_x; + int64_t xmin = bound[static_cast(output_x) * 2]; + int64_t xmax = bound[static_cast(output_x) * 2 + 1]; + + // Input window + const auto* Xdata_offset = Xdata + input_index + input_width * output_y + xmin; + + for (; xmin < xmax; ++xmin) { + if constexpr (std::is_same::value) { + // This cast is needed when we deal with half + output += static_cast((*Xdata_offset++)) * (*weight_coeff++); + } else { + output += (*Xdata_offset++) * (*weight_coeff++); + } + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = static_cast(output); + } +} + +template +__global__ void _ComputeInterpolationAtLevel2( + int64_t num_channels, + int64_t input_height, int64_t input_width, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (output_height == input_height) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * num_channels * input_height * input_width + + output_z * input_height * input_width); + CUDA_LONG output_index = static_cast(bxc * num_channels * output_height * output_width + + output_z * output_height * output_width); + + auto* Ydata_offset = Ydata + output_index + output_width * output_y + output_x; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = 
static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_y; + int64_t ymin = bound[static_cast(output_y) * 2]; + int64_t ymax = bound[static_cast(output_y) * 2 + 1]; + + const auto* Xdata_offset = Xdata + input_index + ymin * output_width + output_x; + + for (; ymin < ymax; ++ymin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += input_width; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +template +__global__ void _ComputeInterpolationAtLevel3( + int64_t input_depth, + int64_t input_height, int64_t input_width, + int64_t output_depth, + int64_t output_height, int64_t output_width, + const fast_divmod div_output_height, + const fast_divmod div_output_width, + const fast_divmod div_output_image, + int32_t window_size, + bool use_extrapolation, float extrapolation_value, + const uint8_t* clip8_table, + const int64_t* bound_data, + std::tuple outof_bounds_buffers, + const AccumType* weight_coefficients, + const T* Xdata, T* Ydata, int N) { + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + // No need to do scale + if (input_depth == output_depth) { + Ydata[id] = Xdata[id]; + return; + } + + int bxc, output_image_index; + div_output_image.divmod(id, bxc, output_image_index); + + int output_z, output_y, output_x, temp; + div_output_height.divmod(output_image_index, output_z, temp); + div_output_width.divmod(temp, output_y, output_x); + + CUDA_LONG input_index = static_cast(bxc * input_depth * input_height * input_width); + + auto* Ydata_offset = Ydata + id; + + if (use_extrapolation) { + const auto* w_outof_bounds = std::get<2>(outof_bounds_buffers); + // Extrapolate along the w dimension + if (w_outof_bounds[static_cast(output_x)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the y dimension + const auto* y_outof_bounds = std::get<1>(outof_bounds_buffers); + if (y_outof_bounds[static_cast(output_y)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + + // Extrapolate along the z dimension + const int64_t* z_outof_bounds = std::get<0>(outof_bounds_buffers); + if (z_outof_bounds != nullptr && z_outof_bounds[static_cast(output_z)] != -1) { + *Ydata_offset = static_cast(extrapolation_value); + return; + } + } + + const auto* bound = bound_data; + + AccumType output = onnxruntime::is_8bit_v ? 
ConstValue::mag_factor : 0; + + const auto* weight_coeff = weight_coefficients + window_size * output_z; + int64_t zmin = bound[static_cast(output_z) * 2]; + int64_t zmax = bound[static_cast(output_z) * 2 + 1]; + + const auto z_step = input_height * input_width; + const auto* Xdata_offset = Xdata + input_index + zmin * z_step + output_y * output_width + output_x; + + for (; zmin < zmax; ++zmin) { + if constexpr (std::is_same::value) { + // We cast to AccumType to resolve ambiguous call to operator* for half in CUDA + output += static_cast((*Xdata_offset)) * (*weight_coeff++); + } else { + output += (*Xdata_offset) * (*weight_coeff++); + } + Xdata_offset += z_step; + } + + if constexpr (onnxruntime::is_8bit_v) { + const uint8_t* clip8_lookups = &clip8_table[640]; + *Ydata_offset = static_cast(clip8_lookups[output >> 22]); + } else if constexpr (std::is_same::value) { + *Ydata_offset = static_cast(std::round(output)); + } else { + *Ydata_offset = output; + } +} + +/// +/// This function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] +/// 2. out_of_bounds: int64_t[output_size] +/// 3. scale_data: T[output_size * window_size] +/// +/// Template parameter AccumType +/// +template +FUNC_DEF void SetupUpsampleFilterAnitAliasImpl( + int64_t i, + int64_t input_size, int64_t output_size, + float rscale, + float roi_start, float roi_end, + float scaled_support, int32_t window_size, bool exclude_outside, + float cubic_coeff_a, + int64_t* bounds, + int64_t* out_of_bounds, + AccumType* scale_data) { + Filter filter{}; + CudaFunctionOriginalCoordinate get_original_coordinate{}; + + const auto scale = 1.f / rscale; + const float inv_scale = (scale >= 1.0f) ? 1.0f / scale : 1.0f; + + const float id = static_cast(i); + float center = 0.5f; + if (scale == 1.0f) { + center += id; + } else { + center += get_original_coordinate(id, rscale, + static_cast(output_size), + static_cast(input_size), + roi_start, roi_end); + } + + if (center - 0.5f < 0 || center - 0.5f > static_cast(input_size - 1)) { + out_of_bounds[i] = i; + } else { + out_of_bounds[i] = -1; + } + + float total_weight{0}; + + auto fmin = _Floor(center - scaled_support + 0.5f); + auto fmax = _Floor(center + scaled_support + 0.5f); + + int64_t min_real = static_cast(fmin); + int64_t max_real = static_cast(fmax); + int64_t min_cut = std::max(min_real, 0); + int64_t max_cut = std::min(max_real, input_size); + + int64_t min_val = exclude_outside ? min_cut : min_real; + int64_t max_val = exclude_outside ? max_cut : max_real; + bounds[i * 2] = min_cut; + bounds[i * 2 + 1] = max_cut; + + // This is done for int32_t case, when the final result is in int32_t, but + // we perform calculations in float. All other types as is. + auto* scale_buffer = AccumTypeCaster::cast(&scale_data[i * window_size]); + + max_val -= min_val; + for (int64_t x = 0; x < max_val; x++) { + const float arg = (x + min_val - center + 0.5f) * inv_scale; + const auto w = filter(arg, cubic_coeff_a); + scale_buffer[x] = w; + total_weight += w; + } + + if (!exclude_outside) { + int64_t neg_xsize = min_val < 0 ? -min_val : 0; + for (int64_t x = 0; x < neg_xsize; x++) { + scale_buffer[neg_xsize] += scale_buffer[x]; + } + + int64_t bound_size = + max_val + min_val > input_size ? 
max_val + min_val - input_size : 0; + for (int64_t x = max_val - bound_size; x < max_val; x++) { + scale_buffer[max_val - bound_size - 1] += + scale_buffer[x]; + } + + for (int64_t x = 0; (neg_xsize | bound_size) > 0 && x < max_cut - min_cut; x++) { + scale_buffer[x] = scale_buffer[x + neg_xsize]; + } + } + + const float total_weight_inv = (total_weight == 0) ? 1.f : (1.f / total_weight); + if constexpr (std::is_same::value) { + auto* scale_buffer_int = reinterpret_cast(scale_buffer); + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + // normalize the scale to 1 << 22 for int8/uint8 + scale_buffer_int[x] = static_cast(_Round(scale_buffer[x] * ConstValue::mag_factor_x_2)); + } + } else { + for (int64_t x = 0; x < max_cut - min_cut; x++) { + scale_buffer[x] *= total_weight_inv; + } + } +} + +/// This kernel computes antialias filter for bilinear or bicubic upsampling. +/// The function expects the following buffers to be pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the two dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the two dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the two dimensions +/// Buffers layout [h_data, w_data] +template +__global__ void _SetupBilinearUpsampleFilterAntiAlias( + std::tuple input_dims, // h, w + std::tuple output_dims, // h, w + std::tuple inv_scale_vals, // h, w + std::tuple roi_start_vals, // h, w + std::tuple roi_end_vals, // h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values h, w + std::tuple dim_window_size, // Pre-computed windows sizes h, w + float cubic_coeff_a, + bool exclude_outside, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients // y, h buffers +) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for y + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else { + // Setup for w + // w = id - output_height + + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto y_output_size = std::get<0>(output_dims); + + auto i = id - y_output_size; + bounds += (y_output_size * 2); + out_of_bounds += y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outside, + cubic_coeff_a, + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } +} + +/// +/// Compute AntiAlias filter for trilinear upsampling, all in one go +/// The function expects the following buffers to be 
pre-allocated on device +/// 1. bounds: int64_t[output_size * 2] for each of the three dimensions +/// 2. out_of_bounds: int64_t[output_size] for each of the three dimensions +/// 3. scale_data: AccumType[output_size * window_size] for each of the three dimensions +/// Each kind of buffer contains data for all 3 dims. +/// Buffers layout [d_data, h_data, w_data] +/// +template +__global__ void _SetupTrilinerarUpsampleFilterAntiAlias( + std::tuple input_dims, // d, h, w + std::tuple output_dims, // d, h, w + std::tuple inv_scale_vals, // d, h, w + std::tuple roi_start_vals, // d, h, w + std::tuple roi_end_vals, // d, h, w + std::tuple dim_scaled_support, // Pre-computed scaled support values d, h, w + std::tuple dim_window_size, // Pre-computed windows sizes d, h, w + bool exclude_outisde, + int64_t* bounds, + int64_t* out_of_bounds, + std::tuple weighted_coefficients) { + const auto N = std::get<0>(output_dims) + std::get<1>(output_dims) + std::get<2>(output_dims); + + CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N); + + if (id < std::get<0>(output_dims)) { + // Setup for d by default (id < output_depth) + int64_t input_size = std::get<0>(input_dims); + int64_t output_size = std::get<0>(output_dims); + float inv_scale = std::get<0>(inv_scale_vals); + float roi_start = std::get<0>(roi_start_vals); + float roi_end = std::get<0>(roi_end_vals); + float scaled_support = std::get<0>(dim_scaled_support); + int32_t window_size = std::get<0>(dim_window_size); + + SetupUpsampleFilterAnitAliasImpl( + id, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<0>(weighted_coefficients)); + + } else if (id >= std::get<0>(output_dims) && id < (std::get<0>(output_dims) + std::get<1>(output_dims))) { + int64_t input_size = std::get<1>(input_dims); + int64_t output_size = std::get<1>(output_dims); + float inv_scale = std::get<1>(inv_scale_vals); + float roi_start = std::get<1>(roi_start_vals); + float roi_end = std::get<1>(roi_end_vals); + + float scaled_support = std::get<1>(dim_scaled_support); + int32_t window_size = std::get<1>(dim_window_size); + + // Adjust buffer positions + const auto d_output_size = std::get<0>(output_dims); + + auto i = id - d_output_size; + bounds += d_output_size * 2; + out_of_bounds += d_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + bounds, + out_of_bounds, + std::get<1>(weighted_coefficients)); + } else { + int64_t input_size = std::get<2>(input_dims); + int64_t output_size = std::get<2>(output_dims); + float inv_scale = std::get<2>(inv_scale_vals); + float roi_start = std::get<2>(roi_start_vals); + float roi_end = std::get<2>(roi_end_vals); + float scaled_support = std::get<2>(dim_scaled_support); + int32_t window_size = std::get<2>(dim_window_size); + + // Adjust buffer positions + const auto d_y_output_size = std::get<0>(output_dims) + std::get<1>(output_dims); + + auto i = id - d_y_output_size; + bounds += (d_y_output_size * 2); + out_of_bounds += d_y_output_size; + + SetupUpsampleFilterAnitAliasImpl( + i, + input_size, output_size, + inv_scale, + roi_start, roi_end, + scaled_support, window_size, + exclude_outisde, + onnxruntime::antialias_constants::kCubicCoeffA, // Default value for trilinear + 
bounds, + out_of_bounds, + std::get<2>(weighted_coefficients)); + } +} + +#define CASEA_COORD_ANTIALIAS(coordinate_mode, TransformCoordType, ...) \ + case coordinate_mode: { \ + using coord_t = TransformCoordType; \ + return __VA_ARGS__(); \ + break; \ + } + +#define DISPATCH_ANTIALIAS_FILTER_SETUP(coord_enum, ...) \ + [&] { \ + const auto the_type = coord_enum; \ + switch (the_type) { \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::HALF_PIXEL, \ + TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ASYMMETRIC, \ + TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASEA_COORD_ANTIALIAS(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ + }() + +namespace { +template +IAllocatorUniquePtr AllocateTyped( + const TempSpaceAllocateFunc& alloc, + size_t elements) { + return alloc(elements * sizeof(T)); +} + +template +T* GetTyped(IAllocatorUniquePtr& bytes) { + return reinterpret_cast(bytes.get()); +} +} // namespace + +template +void ResizeTrilinearUpsample( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? 
*extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + static_cast(ceil((output_depth + output_height + output_width) / 32.0)); + + int blocksPerGrid = static_cast(ceil(static_cast(N) / GridDim::maxThreadsPerBlock)); + + constexpr float support_value = antialias_constants::kSupportSize; + float z_scale, h_scale, w_scale; + std::tie(z_scale, h_scale, w_scale) = inferred_dim_rscales; + + const auto& div_output_width = output_div_pitches[rank - 2]; + + SafeInt bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_depth) + output_height + output_width); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* z_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* y_bounds_buffer = z_bounds_buffer + output_depth * 2; + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* z_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* y_outof_bounds_buffer = z_outof_bounds_buffer + output_depth; + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + float z_scaled_support, h_scaled_support, w_scaled_support; + int32_t z_window_size, h_window_size, w_window_size; + const auto [z_buffer_size, y_buffer_size, w_buffer_size] = ComputeTrilinearScaleBufferSize( + output_depth, output_height, output_width, + z_scale, h_scale, w_scale, support_value, + z_scaled_support, h_scaled_support, w_scaled_support, + z_window_size, h_window_size, w_window_size); + + const int64_t weighted_buffer_size = SafeInt(z_buffer_size) + y_buffer_size + w_buffer_size; + + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + AccumType* z_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* y_weighted_buffer = z_weighted_buffer + z_buffer_size; + AccumType* w_weighted_buffer = y_weighted_buffer + y_buffer_size; + + const auto h_w_interpolate_temp_buf_size = SafeInt(batch_size) * num_channels * + input_depth * input_height * output_width; + auto h_w_interpolate_temp_buffer_ptr = AllocateTyped(allocate_temp_space, + narrow(h_w_interpolate_temp_buf_size)); + + const auto h_w_interpolate_result_buffer_size = SafeInt(batch_size) * num_channels * + input_depth * output_height * output_width; + auto h_w_interpolate_result_buffer_ptr = AllocateTyped(allocate_temp_space, h_w_interpolate_result_buffer_size); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupTrilinerarUpsampleFilterAntiAlias<<>>( + inferred_input_dims, + inferred_output_dims, + inferred_dim_rscales, + std::make_tuple(roi_vals[rank - 3], roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts d, h, w + std::make_tuple(roi_vals[rank - 3 + rank], roi_vals[rank - 2 + rank], // roi ends d, h, w + roi_vals[rank - 1 + rank]), + std::make_tuple(z_scaled_support, h_scaled_support, w_scaled_support), + std::make_tuple(z_window_size, h_window_size, w_window_size), + exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(z_weighted_buffer, y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const 
fast_divmod div_w_image(narrow(num_channels * input_depth * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels * input_depth, input_height, input_width, input_height, output_width, + div_output_width, + div_w_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, + GetTyped(h_w_interpolate_temp_buffer_ptr), + narrow(h_w_interpolate_temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + const fast_divmod div_h_w_image(narrow(num_channels * input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels * input_depth, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_h_w_image, + h_window_size, + false, 0.f, // No extrapolation + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(h_w_interpolate_temp_buffer_ptr), + GetTyped(h_w_interpolate_result_buffer_ptr), + narrow(h_w_interpolate_result_buffer_size)); + + // clang-format on + const fast_divmod div_z_h_w_image(narrow(input_depth * output_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel3<<>>( + input_depth, output_height, output_width, + output_depth, output_height, output_width, + div_output_height, + div_output_width, + div_z_h_w_image, + z_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + z_bounds_buffer, + std::make_tuple(z_outof_bounds_buffer, y_outof_bounds_buffer, w_outof_bounds_buffer), + z_weighted_buffer, GetTyped(h_w_interpolate_result_buffer_ptr), + output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeBiLinearUpsample(cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + // rank 2 or 4 + const fast_divmod div_output_image = (rank > 2) ? 
output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kSupportSize; + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, narrow(weighted_buffer_size)); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + // Data is d, h, w in tuples + + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + + // clang-format on + const fast_divmod div_step_image{narrow(num_channels * input_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + + // clang-format on + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); 
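The two kernel launches above implement a separable anti-alias resize: _ComputeInterpolationAtLevel1 filters every row along the width axis into a temporary (input_height x output_width) image, and _ComputeInterpolationAtLevel2 then filters that temporary image along the height axis into the final (output_height x output_width) output, applying extrapolation only in the second pass. Below is a minimal single-channel CPU sketch of the same two-pass structure; it is illustrative only, and the bounds/weights layout it assumes (contiguous [lo, hi) pairs plus window_size coefficients per output coordinate, as produced by the setup kernel) is an interpretation, not a definitive description of the patch's buffers.

#include <cstdint>
#include <vector>

// Width pass then height pass over one float channel, using precomputed
// per-output-coordinate bounds ([lo, hi) pairs) and filter weights.
static std::vector<float> SeparableResizeSketch(
    const std::vector<float>& src, int64_t in_h, int64_t in_w,
    int64_t out_h, int64_t out_w,
    const int64_t* x_bounds, const float* x_weights, int32_t x_window,
    const int64_t* y_bounds, const float* y_weights, int32_t y_window) {
  std::vector<float> tmp(static_cast<size_t>(in_h * out_w), 0.f);
  std::vector<float> dst(static_cast<size_t>(out_h * out_w), 0.f);
  // Pass 1: filter along width (the role of _ComputeInterpolationAtLevel1).
  for (int64_t y = 0; y < in_h; ++y) {
    for (int64_t ox = 0; ox < out_w; ++ox) {
      const int64_t lo = x_bounds[ox * 2], hi = x_bounds[ox * 2 + 1];
      const float* w = x_weights + ox * x_window;
      float acc = 0.f;
      for (int64_t x = lo; x < hi; ++x) acc += src[y * in_w + x] * w[x - lo];
      tmp[y * out_w + ox] = acc;
    }
  }
  // Pass 2: filter along height (the role of _ComputeInterpolationAtLevel2).
  for (int64_t oy = 0; oy < out_h; ++oy) {
    for (int64_t ox = 0; ox < out_w; ++ox) {
      const int64_t lo = y_bounds[oy * 2], hi = y_bounds[oy * 2 + 1];
      const float* w = y_weights + oy * y_window;
      float acc = 0.f;
      for (int64_t y = lo; y < hi; ++y) acc += tmp[y * out_w + ox] * w[y - lo];
      dst[oy * out_w + ox] = acc;
    }
  }
  return dst;
}

For 8-bit element types the kernels accumulate in fixed point rather than float: the setup code's comment states the weights are normalized to 1 << 22, the accumulator starts at ConstValue::mag_factor (presumably half that value, so the final shift rounds rather than truncates), and output >> 22 is clamped through the 640-offset clip8 lookup table.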
+ + // clang-format on +} + +template +void ResizeBicubicUpsample(cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + // const TArray& input_strides, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + const TempSpaceAllocateFunc& allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + using AccumType = typename onnxruntime::AccumulateType::type; + + const bool use_extrapolation = extrapolation.has_value(); + const float extrapolation_value = use_extrapolation ? *extrapolation : 0.f; + + int blocksPerGrid = narrow(CeilDiv(N, GridDim::maxThreadsPerBlock)); + const fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 4] + : fast_divmod(gsl::narrow_cast(N)); + const fast_divmod& div_output_width = output_div_pitches[rank - 2]; + + constexpr float support_value = antialias_constants::kBiCubicSupportSize; + + int64_t input_depth, input_height, input_width; + std::tie(input_depth, input_height, input_width) = inferred_input_dims; + + int64_t output_depth, output_height, output_width; + std::tie(output_depth, output_height, output_width) = inferred_output_dims; + + int blocksPerDimsMappingGrid = + narrow(CeilDiv((output_depth + output_height + output_width), 32)); + + float h_scale, w_scale; + std::tie(std::ignore, h_scale, w_scale) = inferred_dim_rscales; + + SafeInt bounds_buffer_size = (SafeInt(output_height) + output_width) * 2; + SafeInt out_of_bounds_buffer_size = (SafeInt(output_height) + output_width); + + float h_scaled_support, w_scaled_support; + int32_t h_window_size, w_window_size; + const auto [weighted_y_size, weighted_w_size] = + ComputeBilinearScaleBufferSize(output_height, output_width, + h_scale, w_scale, support_value, + h_scaled_support, w_scaled_support, h_window_size, w_window_size); + + auto bounds_buffer_ptr = AllocateTyped(allocate_temp_space, bounds_buffer_size); + auto out_of_bounds_buffer_ptr = AllocateTyped(allocate_temp_space, out_of_bounds_buffer_size); + + int64_t* y_bounds_buffer = GetTyped(bounds_buffer_ptr); + int64_t* w_bounds_buffer = y_bounds_buffer + output_height * 2; + + int64_t* y_outof_bounds_buffer = GetTyped(out_of_bounds_buffer_ptr); + int64_t* w_outof_bounds_buffer = y_outof_bounds_buffer + output_height; + + const int64_t weighted_buffer_size = SafeInt(weighted_y_size) + + weighted_w_size; + auto weighted_buffer_ptr = AllocateTyped(allocate_temp_space, weighted_buffer_size); + + AccumType* y_weighted_buffer = GetTyped(weighted_buffer_ptr); + AccumType* w_weighted_buffer = y_weighted_buffer + weighted_y_size; + + const auto temp_buf_size = SafeInt(batch_size) * num_channels * input_height * output_width; + auto image_temp_buffer = AllocateTyped(allocate_temp_space, narrow(temp_buf_size)); + + // clang-format off + DISPATCH_ANTIALIAS_FILTER_SETUP(coordinate_transform_mode, [&]() { + _SetupBilinearUpsampleFilterAntiAlias<<>>( + std::make_tuple(input_height, input_width), + std::make_tuple(output_height, output_width), + std::make_tuple(h_scale, w_scale), + std::make_tuple(roi_vals[rank - 2], roi_vals[rank - 1]), // roi starts h, w + std::make_tuple(roi_vals[rank - 2 + rank], roi_vals[rank - 1 + rank]), // roi ends h, w + 
std::make_tuple(h_scaled_support, w_scaled_support), + std::make_tuple(h_window_size, w_window_size), + onnxruntime::antialias_constants::kCubicCoeffA, exclude_outside, + GetTyped(bounds_buffer_ptr), + GetTyped(out_of_bounds_buffer_ptr), + std::make_tuple(y_weighted_buffer, w_weighted_buffer)); + }); + // clang-format on + const fast_divmod div_step_image(narrow(num_channels * input_height * output_width)); + // clang-format off + _ComputeInterpolationAtLevel1<<>>( + num_channels, input_height, input_width, input_height, output_width, + div_output_width, + div_step_image, + w_window_size, + clip8_lookups, + w_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + w_weighted_buffer, input_data, GetTyped(image_temp_buffer), + narrow(temp_buf_size)); + // clang-format on + + const fast_divmod div_output_height{narrow(output_height * output_width)}; + // clang-format off + _ComputeInterpolationAtLevel2<<>>( + num_channels, input_height, output_width, output_height, output_width, + div_output_height, + div_output_width, + div_output_image, + h_window_size, + use_extrapolation, extrapolation_value, + clip8_lookups, + y_bounds_buffer, + std::make_tuple(y_outof_bounds_buffer, w_outof_bounds_buffer), + y_weighted_buffer, GetTyped(image_temp_buffer), output_data, + narrow(N)); + // clang-format on +} + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, + const std::optional& extrapolation, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N) { + // We support a special case of bilinear or bicubic if the input data is 4D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_2D = (rank == 2 || rank == 4); + + // We support a special case of trilinear or tricubic if the input data is 5D with the outer 2 scales being 1.0 + // We would have validated the outer scale values by the time execution reaches this + const bool is_3D = (rank == 3 || rank == 5); + + // Should not hit this as we have already validated input rank/scales and we provide verbose error messages + // to the user. 
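The is_2D / is_3D checks above reflect that the anti-aliased path is only reached for rank-2 or rank-4 (NCHW with the two outer scales equal to 1.0) and rank-3 or rank-5 (NCDHW with the two outer scales equal to 1.0) inputs. A minimal sketch of that classification follows, with hypothetical names and assuming the scales are ordered like the input dimensions; it is not part of the patch.

#include <cstddef>
#include <vector>

enum class AntiAliasPath { kBilinearOrBicubic2D, kTrilinear3D, kUnsupported };

// Classify which anti-alias path a Resize call would take, given per-dimension
// scales. Only the innermost two (H, W) or three (D, H, W) dimensions may be
// resized; any outer (batch/channel) scales must be 1.0.
static AntiAliasPath ClassifyAntiAliasResize(const std::vector<float>& scales) {
  const std::size_t rank = scales.size();
  const bool outer_two_unit = rank >= 2 && scales[0] == 1.f && scales[1] == 1.f;
  if (rank == 2 || (rank == 4 && outer_two_unit)) return AntiAliasPath::kBilinearOrBicubic2D;
  if (rank == 3 || (rank == 5 && outer_two_unit)) return AntiAliasPath::kTrilinear3D;
  return AntiAliasPath::kUnsupported;
}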
+ ORT_ENFORCE(is_2D || is_3D, "Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + + switch (upsample_mode) { + case UpsampleMode::LINEAR: { + if (is_2D) { + ResizeBiLinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else if (is_3D) { + ResizeTrilinearUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D or 3-D in LINEAR mode."); + } + } break; + case CUBIC: { + if (is_2D) { + ResizeBicubicUpsample(stream, rank, upsample_mode, coordinate_transform_mode, + input_shape, output_shape, batch_size, num_channels, + inferred_input_dims, inferred_output_dims, inferred_dim_rscales, + output_div_pitches, roi_vals, extrapolation, exclude_outside, + allocate_temp_space, clip8_lookups, input_data, output_data, N); + } else { + ORT_NOT_IMPLEMENTED("Resize supports only 2-D in CUBIC mode."); + } + } break; + default: + ORT_NOT_IMPLEMENTED("Only bilinear/trilinear and bicubic modes are supported in Resize anti-alias mode"); + break; + } +} + +#define SPECIALIZED_ANTIALIAS_IMPL(T) \ + template void ResizeAntiAliasImpl( \ + cudaStream_t stream, \ + int rank, \ + const UpsampleMode upsample_mode, \ + ResizeCoordinateTransformationMode coordinate_transform_mode, \ + gsl::span input_shape, \ + gsl::span output_shape, \ + int64_t batch_size, int64_t num_channels, \ + std::tuple inferred_input_dims, \ + std::tuple inferred_output_dims, \ + std::tuple inferred_dim_rscales, \ + const TArray& output_div_pitches, \ + gsl::span roi_vals, \ + const std::optional& extrapolation_value, \ + bool exclude_outside, \ + TempSpaceAllocateFunc allocate_temp_space, \ + const uint8_t* clip8_lookups, \ + const T* input_data, \ + T* output_data, \ + const size_t N); + +SPECIALIZED_ANTIALIAS_IMPL(float) +SPECIALIZED_ANTIALIAS_IMPL(double) +SPECIALIZED_ANTIALIAS_IMPL(half) +SPECIALIZED_ANTIALIAS_IMPL(int32_t) +SPECIALIZED_ANTIALIAS_IMPL(uint8_t) + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 1a94c7705e913..0cde0ed8e8681 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -12,7 +12,7 @@ using onnxruntime::ResizeNearestMode; using onnxruntime::UpsampleMode; struct NearestPixel_SIMPLE { - __device__ __forceinline__ int operator() (float x_original, bool is_down_sampling) const { + __device__ __forceinline__ int operator()(float x_original, bool is_down_sampling) const { if (is_down_sampling) { return static_cast(_Ceil(x_original)); } @@ -21,7 +21,7 @@ struct NearestPixel_SIMPLE { }; struct NearestPixel_ROUND_PREFER_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { if (x_original == static_cast(x_original) + 0.5f) { return static_cast(_Floor(x_original)); } @@ -30,62 +30,23 @@ struct 
NearestPixel_ROUND_PREFER_FLOOR { }; struct NearestPixel_ROUND_PREFER_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(roundf(x_original)); } }; struct NearestPixel_FLOOR { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Floor(x_original)); } }; struct NearestPixel_CEIL { - __device__ __forceinline__ int operator() (float x_original, bool) const { + __device__ __forceinline__ int operator()(float x_original, bool) const { return static_cast(_Ceil(x_original)); } }; -struct TransformCoordinate_ASYMMETRIC { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return x_resized / x_scale; - } -}; - -struct TransformCoordinate_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return ((x_resized + 0.5f) / x_scale) - 0.5f; - } -}; - -struct TransformCoordinate_PYTORCH_HALF_PIXEL { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float length_resized, float, float, float) const { - return length_resized > 1 ? (x_resized + 0.5f) / x_scale - 0.5f : 0.0f; - } -}; - -struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { - __device__ __forceinline__ float operator() (float x_resized, float x_scale, float, float, float, float) const { - return (x_resized + 0.5f) / x_scale; - } -}; - -struct TransformCoordinate_ALIGN_CORNERS { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float, float) const { - return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); - } -}; - -struct TransformCoordinate_TF_CROP_AND_RESIZE { - __device__ __forceinline__ float operator() (float x_resized, float, float length_resized, float length_original, float roi_start, float roi_end) const { - auto orig = length_resized > 1 - ? roi_start * (length_original - 1) + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) - : 0.5 * (roi_start + roi_end) * (length_original - 1); - return static_cast(orig); - } -}; - #define CASE_TYPE_USING_HINT(enum_type, type, HINT, ...) \ case enum_type: { \ using HINT = type; \ @@ -95,20 +56,24 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { #define CASE_TYPE_COORD(enum_type, type, ...) \ CASE_TYPE_USING_HINT(enum_type, type, coord_t, __VA_ARGS__) -#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) 
\ - [&] { \ - const auto& the_type = TYPE; \ - /* don't use TYPE again in case it is an expensive or side-effect op */ \ - switch (the_type) { \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ - CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ - default: \ - ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ - } \ +#define DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(TYPE, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + switch (the_type) { \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::HALF_PIXEL, TransformCoordinate_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ASYMMETRIC, TransformCoordinate_ASYMMETRIC, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::PYTORCH_HALF_PIXEL, \ + TransformCoordinate_PYTORCH_HALF_PIXEL, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::ALIGN_CORNERS, \ + TransformCoordinate_ALIGN_CORNERS, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_HALF_PIXEL_FOR_NN, \ + TransformCoordinate_TF_HALF_PIXEL_FOR_NN, __VA_ARGS__) \ + CASE_TYPE_COORD(ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE, \ + TransformCoordinate_TF_CROP_AND_RESIZE, __VA_ARGS__) \ + default: \ + ORT_THROW("unknown ResizeCoordinateTransformationMode"); \ + } \ }() #define CASE_TYPE_NEAREST(enum_type, type, ...) 
\ @@ -119,11 +84,11 @@ struct TransformCoordinate_TF_CROP_AND_RESIZE { const auto& the_type = TYPE; \ /* don't use TYPE again in case it is an expensive or side-effect op */ \ switch (the_type) { \ - CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::SIMPLE, NearestPixel_SIMPLE, __VA_ARGS__) \ CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_FLOOR, NearestPixel_ROUND_PREFER_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ - CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::ROUND_PREFER_CEIL, NearestPixel_ROUND_PREFER_CEIL, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::FLOOR, NearestPixel_FLOOR, __VA_ARGS__) \ + CASE_TYPE_NEAREST(ResizeNearestMode::CEIL, NearestPixel_CEIL, __VA_ARGS__) \ default: \ ORT_THROW("unknown ResizeNearestMode"); \ } \ @@ -151,10 +116,12 @@ __global__ void _ResizeNearestMappingKernel2D( // only apply co-ordinate transformation if scale != 1.0 if (scales_height == 1.0f) { - dims_mapping[id].extrapolate_ = 0; + dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales_height, static_cast(output_height), - static_cast(input_height), roi_start_height, roi_end_height); + float orig_coord = transform_coordinate(static_cast(dim), scales_height, + static_cast(output_height), + static_cast(input_height), + roi_start_height, roi_end_height); dims_mapping[id].extrapolate_ = static_cast( extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_height - 1))); dim = calc_nearest_pixel(orig_coord, scales_height < 1); @@ -210,9 +177,12 @@ __global__ void _ResizeNearestMappingKernel( if (scales[axis] == 1.0f) { dims_mapping[id].extrapolate_ = 0; } else { - float orig_coord = transform_coordinate(static_cast(dim), scales[axis], static_cast(output_shape[axis]), + float orig_coord = transform_coordinate(static_cast(dim), scales[axis], + static_cast(output_shape[axis]), static_cast(input_shape[axis]), roi[axis], roi[axis + rank]); - dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && (orig_coord < 0.f || orig_coord > static_cast(input_shape[axis] - 1))); + dims_mapping[id].extrapolate_ = static_cast(extrapolation_enabled && + (orig_coord < 0.f || + orig_coord > static_cast(input_shape[axis] - 1))); dim = calc_nearest_pixel(orig_coord, scales[axis] < 1); if (dim >= input_shape[axis]) dim = input_shape[axis] - 1; if (dim < 0) dim = 0; @@ -293,21 +263,27 @@ __global__ void _ResizeBilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumHW); if (id < output_height) { // y = id - float input_y = scale_height == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? 
static_cast(id) + : transform_coordinate(static_cast(id), scale_height, + static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 0.5f : input_y - y_int; - } else { //x = id - output_height - float input_x = scale_width == 1 ? static_cast(id - output_height) : - transform_coordinate(static_cast(id - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_height + float input_x = scale_width == 1 ? static_cast(id - output_height) + : transform_coordinate(static_cast(id - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), roi_width_start, + roi_width_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_x < 0 || + input_x > static_cast(input_width - 1)))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -371,32 +347,40 @@ __global__ void _ResizeTrilinearCoordinateMapping( LinearMappingInfo* dims_mapping) { CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, SumDHW); if (id < output_depth) { // z = id - float input_z = scale_depth == 1 ? static_cast(id) : - transform_coordinate(static_cast(id), scale_depth, - static_cast(output_depth), static_cast(input_depth), - roi_depth_start, roi_depth_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_z < 0 || input_z > static_cast(input_depth - 1))); + float input_z = scale_depth == 1 ? static_cast(id) + : transform_coordinate(static_cast(id), scale_depth, + static_cast(output_depth), + static_cast(input_depth), + roi_depth_start, roi_depth_end); + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_z < 0 || + input_z > static_cast(input_depth - 1)))); input_z = max(0.0f, min(input_z, static_cast(input_depth - 1))); int z_int = static_cast(input_z); dims_mapping[id].origin_ = z_int; dims_mapping[id].weight_ = (z_int >= input_depth - 1) ? 0.5f : input_z - z_int; } else if (id >= output_depth && id < (output_depth + output_height)) { // y = id - output_depth - float input_y = scale_height == 1 ? static_cast(id - output_depth) : - transform_coordinate(static_cast(id - output_depth), scale_height, - static_cast(output_height), static_cast(input_height), - roi_height_start, roi_height_end); - - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_y < 0 || input_y > static_cast(input_height - 1))); + float input_y = scale_height == 1 ? static_cast(id - output_depth) + : transform_coordinate(static_cast(id - output_depth), + scale_height, static_cast(output_height), + static_cast(input_height), + roi_height_start, roi_height_end); + + dims_mapping[id].extrapolate_ = static_cast((extrapolation_enabled && + (input_y < 0 || + input_y > static_cast(input_height - 1)))); input_y = max(0.0f, min(input_y, static_cast(input_height - 1))); int y_int = static_cast(input_y); dims_mapping[id].origin_ = y_int; dims_mapping[id].weight_ = (y_int >= input_height - 1) ? 
0.5f : input_y - y_int; - } else { //x = id - output_depth - output_height - float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) : - transform_coordinate(static_cast(id - output_depth - output_height), scale_width, - static_cast(output_width), static_cast(input_width), - roi_width_start, roi_width_end); - dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || input_x > static_cast(input_width - 1))); + } else { // x = id - output_depth - output_height + float input_x = scale_width == 1 ? static_cast(id - output_depth - output_height) + : transform_coordinate(static_cast(id - output_depth - output_height), + scale_width, static_cast(output_width), + static_cast(input_width), + roi_width_start, roi_width_end); + dims_mapping[id].extrapolate_ = (int)(extrapolation_enabled && (input_x < 0 || + input_x > static_cast(input_width - 1))); input_x = max(0.0f, min(input_x, static_cast(input_width - 1))); int x_int = static_cast(input_x); dims_mapping[id].origin_ = x_int; @@ -513,21 +497,33 @@ __global__ void _ResizeCubicCoordinateMapping( int max_input_coord = static_cast(is_y_axis ? input_height : input_width); float scale = is_y_axis ? scale_height : scale_width; - float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) : - transform_coordinate( - static_cast(is_y_axis ? id : id - output_height), - scale, - static_cast(is_y_axis ? output_height : output_width), - static_cast(max_input_coord), - (is_y_axis ? roi_height_start : roi_width_start), - (is_y_axis ? roi_height_end : roi_width_end)); + float input_coordinat = scale == 1 ? (is_y_axis ? id : id - output_height) + : transform_coordinate( + static_cast(is_y_axis ? id : id - output_height), + scale, + static_cast(is_y_axis ? output_height : output_width), + static_cast(max_input_coord), + (is_y_axis ? roi_height_start : roi_width_start), + (is_y_axis ? roi_height_end : roi_width_end)); int coord_int = static_cast(_Floor(input_coordinat)); float s_coord = abs(input_coordinat - coord_int); float coeff_sum = 1.0f; - float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * (s_coord + 1) + 8 * cubic_coeff_a) * (s_coord + 1) - 4 * cubic_coeff_a); - float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * s_coord * s_coord + 1); - float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * (1 - s_coord) * (1 - s_coord) + 1); - float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * (2 - s_coord) + 8 * cubic_coeff_a) * (2 - s_coord) - 4 * cubic_coeff_a); + float coeff_0 = static_cast(((cubic_coeff_a * (s_coord + 1) - 5 * cubic_coeff_a) * + (s_coord + 1) + + 8 * cubic_coeff_a) * + (s_coord + 1) - + 4 * cubic_coeff_a); + float coeff_1 = static_cast(((cubic_coeff_a + 2) * s_coord - (cubic_coeff_a + 3)) * + s_coord * s_coord + + 1); + float coeff_2 = static_cast(((cubic_coeff_a + 2) * (1 - s_coord) - (cubic_coeff_a + 3)) * + (1 - s_coord) * (1 - s_coord) + + 1); + float coeff_3 = static_cast(((cubic_coeff_a * (2 - s_coord) - 5 * cubic_coeff_a) * + (2 - s_coord) + + 8 * cubic_coeff_a) * + (2 - s_coord) - + 4 * cubic_coeff_a); if (exclude_outside) { coeff_0 = (coord_int - 1 < 0 || coord_int - 1 >= max_input_coord) ? 0.0 : coeff_0; coeff_1 = (coord_int + 0 < 0 || coord_int + 0 >= max_input_coord) ? 
0.0 : coeff_1; @@ -540,7 +536,8 @@ __global__ void _ResizeCubicCoordinateMapping( dm.coeff1_ = coeff_1 / coeff_sum; dm.coeff2_ = coeff_2 / coeff_sum; dm.coeff3_ = coeff_3 / coeff_sum; - dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || input_coordinat > static_cast(max_input_coord - 1))); + dm.extrapolate_ = (int)(extrapolation_enabled && (input_coordinat < 0 || + input_coordinat > static_cast(max_input_coord - 1))); } template @@ -569,21 +566,30 @@ __global__ void _ResizeBiCubicKernel( int x_int = x_info.origin_; int y_int = y_info.origin_; const T* image = input_data + input_index; - output_data[id] = y_info.coeff0_ * CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff1_ * CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff2_ * CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + - y_info.coeff3_ * CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); + output_data[id] = y_info.coeff0_ * + CubicInterpolationRowwise(image, x_int, y_int - 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff1_ * + CubicInterpolationRowwise(image, x_int, y_int, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff2_ * + CubicInterpolationRowwise(image, x_int, y_int + 1, input_height, input_width, w0, w1, w2, w3) + + y_info.coeff3_ * + CubicInterpolationRowwise(image, x_int, y_int + 2, input_height, input_width, w0, w1, w2, w3); } size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims) { switch (upsample_mode) { case UpsampleMode::NN: - return sizeof(int64_t) * output_dims.size() + sizeof(NearestMappingInfo) * static_cast(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0)); + return sizeof(int64_t) * output_dims.size() + + sizeof(NearestMappingInfo) * + static_cast(std::accumulate(output_dims.begin(), + output_dims.end(), (int64_t)0)); case UpsampleMode::LINEAR: - return sizeof(LinearMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(LinearMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); case UpsampleMode::CUBIC: - return sizeof(CubicMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); + return sizeof(CubicMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); } return 0; } @@ -616,7 +622,8 @@ void ResizeNearestImpl( if (could2d) { int64_t output_height = output_shape[rank - 2]; int64_t output_width = output_shape[rank - 1]; - fast_divmod div_output_image = (rank > 2) ? output_div_pitches[rank - 3] : fast_divmod(static_cast(output_height * output_width)); + fast_divmod div_output_image = (rank > 2) ? 
output_div_pitches[rank - 3] + : fast_divmod(static_cast(output_height * output_width)); int blocksPerDimsMappingGrid = static_cast(ceil((output_height + output_width) / 32.0)); DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(transform_coordinate, [&]() { @@ -694,13 +701,6 @@ void ResizeImpl( ResizeCoordinateTransformationMode coordinate_transform_mode, ResizeNearestMode nearest_mode, void* dims_mapping) { - bool isSame = std::all_of(scales_vals.Data(), scales_vals.Data() + rank, [](float v) { return v == 1.0f; }) && - (coordinate_transform_mode != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); - if (isSame) { - CUDA_CALL_THROW(cudaMemcpyAsync(output_data, input_data, N * sizeof(T), cudaMemcpyDeviceToDevice, stream)); - return; - } - if (upsample_mode == UpsampleMode::NN) { ResizeNearestImpl( stream, rank, input_shape, output_shape, input_strides, output_div_pitches, @@ -761,7 +761,7 @@ void ResizeImpl( } else if (is_3D) { DISPATCH_RESIZE_COORDINATE_TRANSFORMATION_MODE(coordinate_transform_mode, [&]() { _ResizeTrilinearCoordinateMapping<<>>( - input_shape[rank - 3] , input_shape[rank - 2], input_shape[rank - 1], + input_shape[rank - 3], input_shape[rank - 2], input_shape[rank - 1], output_depth, output_height, output_width, scales_vals[rank - 3], scales_vals[rank - 2], scales_vals[rank - 1], roi_vals[rank - 3], roi_vals[rank - 3 + rank], @@ -778,7 +778,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize support 2-D and 3-D dimensions in LINEAR mode."); break; case UpsampleMode::CUBIC: if (is_2D) { @@ -801,7 +801,7 @@ void ResizeImpl( reinterpret_cast(dims_mapping)); return; } - ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); + ORT_THROW("Resize supports only 2-D in CUBIC mode."); case UpsampleMode::NN: ORT_THROW("Only bilinear/trilinear and bicubic modes are supported in Resize"); } @@ -809,7 +809,7 @@ void ResizeImpl( #define SPECIALIZED_IMPL(T) \ template void ResizeImpl( \ - cudaStream_t stream, \ + cudaStream_t stream, \ const UpsampleMode upsample_mode, \ const int rank, \ TArray& input_shape, \ diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.h b/onnxruntime/core/providers/cuda/tensor/resize_impl.h index d459dbff18d3e..ad06eebb9efb1 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.h +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.h @@ -2,15 +2,69 @@ // Licensed under the MIT License. #pragma once + #include + +#include + #include "core/providers/cuda/shared_inc/cuda_utils.h" #include "core/common/common.h" #include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cuda/cuda_common.h" namespace onnxruntime { +template <> +struct AccumulateType { + using type = float; +}; namespace cuda { +struct TransformCoordinate_ASYMMETRIC { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return x_resized / x_scale; + } +}; + +struct TransformCoordinate_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return ((x_resized + 0.5f) / x_scale) - 0.5f; + } +}; + +struct TransformCoordinate_PYTORCH_HALF_PIXEL { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, float length_resized, float, + float, float) const { + return length_resized > 1 ? 
(x_resized + 0.5f) / x_scale - 0.5f : 0.0f; + } +}; + +struct TransformCoordinate_TF_HALF_PIXEL_FOR_NN { + __device__ __host__ __forceinline__ float operator()(float x_resized, float x_scale, + float, float, float, float) const { + return (x_resized + 0.5f) / x_scale; + } +}; + +struct TransformCoordinate_ALIGN_CORNERS { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float, float) const { + return length_resized == 1 ? 0 : x_resized * (length_original - 1) / (length_resized - 1); + } +}; + +struct TransformCoordinate_TF_CROP_AND_RESIZE { + __device__ __host__ __forceinline__ float operator()(float x_resized, float, float length_resized, + float length_original, float roi_start, float roi_end) const { + auto orig = length_resized > 1 + ? roi_start * (length_original - 1) + + (x_resized * (roi_end - roi_start) * (length_original - 1)) / (length_resized - 1) + : 0.5 * (roi_start + roi_end) * (length_original - 1); + return static_cast(orig); + } +}; + size_t CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, const gsl::span& output_dims); @@ -36,5 +90,62 @@ void ResizeImpl( onnxruntime::ResizeNearestMode nearest_mode, void* dims_mapping); +using TempSpaceAllocateFunc = std::function(size_t buffer_size)>; + +template +void ResizeAntiAliasImpl( + cudaStream_t stream, + int rank, + const UpsampleMode upsample_mode, + ResizeCoordinateTransformationMode coordinate_transform_mode, + gsl::span input_shape, + gsl::span output_shape, + int64_t batch_size, int64_t num_channels, + std::tuple inferred_input_dims, + std::tuple inferred_output_dims, + std::tuple inferred_dim_rscales, + const TArray& output_div_pitches, + gsl::span roi_vals, // CPU + const std::optional& extrapolation_value, + bool exclude_outside, + TempSpaceAllocateFunc allocate_temp_space, + const uint8_t* clip8_lookups, + const T* input_data, + T* output_data, + const size_t N); + +/// +/// Compute scaled support value for a given dimension inverse scale +/// +/// Support value from parameters +/// inverse scale value comes from input/attr for +/// +inline float ComputeScaledSupportValue(float support_value, float rscale) { + const float scale = 1.0f / rscale; + float scaled_support = (scale >= 1.0f) ? (support_value * 0.5f) * scale : support_value * 0.5f; + return scaled_support; +} + +/// +/// Compute window size for a given dimension scaled support value. +/// +/// +/// +inline int32_t ComputeWindowSize(float scaled_support) { + SafeInt window_size(ceilf(scaled_support)); + return window_size * 2 + 1; +} + +/// +/// Computes scale buffer size in number of elements for allocation purposes. +/// +/// +/// +/// Number of elements to fit in the buffer +inline SafeInt ComputeWeightedCoeffBufferSize(int64_t output_size, int32_t window_size) { + SafeInt buffer_size(output_size); + return buffer_size * window_size; +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.cc b/onnxruntime/core/providers/cuda/tensor/upsample.cc index ae12ca328bc7c..17533eb3d9a72 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.cc +++ b/onnxruntime/core/providers/cuda/tensor/upsample.cc @@ -2,6 +2,9 @@ // Licensed under the MIT License. 
#include "upsample.h" + +#include + #include "upsample_impl.h" #include "core/providers/cuda/tensor/resize_impl.h" #include "core/providers/cpu/tensor/utils.h" @@ -37,11 +40,23 @@ REGISTER_VERSIONED_TYPED_KERNEL(MLFloat16, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(int32_t, 9, 9); REGISTER_VERSIONED_TYPED_KERNEL(uint8_t, 9, 9); +template +Upsample::Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + if (UpsampleBase::antialias_) { + // Copy the table on DEVICE + const uint8_t* lookup_table = GetLookupTableShared(); + auto alloc = info.GetAllocator(OrtMemTypeDefault); + shared_lookup_table_ondevice_ = IAllocator::MakeUniquePtr(std::move(alloc), kLookupTableSize); + CUDA_CALL_THROW(cudaMemcpyAsync(shared_lookup_table_ondevice_.get(), lookup_table, kLookupTableSize, + cudaMemcpyHostToDevice, nullptr)); + } +} + template Status Upsample::BaseCompute(OpKernelContext* context, - const std::vector& roi, - const std::vector& scales, - const gsl::span& output_dims) const { + gsl::span roi, + gsl::span scales, + gsl::span output_dims) const { const Tensor* X = context->Input(0); auto X_dims = X->Shape().GetDims(); int32_t rank = static_cast(X_dims.size()); @@ -52,7 +67,8 @@ Status Upsample::BaseCompute(OpKernelContext* context, is_resize_ ? "Resize: input tensor cannot be scalar." : "Upsample: input tensor cannot be scalar."); if (rank != static_cast(scales.size())) return Status(ONNXRUNTIME, INVALID_ARGUMENT, - is_resize_ ? "Resize: input tensor's dimension does not match the scales." : "Upsample: input tensor's dimension does not match the scales."); + is_resize_ ? "Resize: input tensor's dimension does not match the scales." + : "Upsample: input tensor's dimension does not match the scales."); if (roi.size() != 2 * X_dims.size()) return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: size of roi array should be 2 * N where N is the rank of input tensor X."); @@ -79,22 +95,194 @@ Status Upsample::BaseCompute(OpKernelContext* context, size_t output_count = Y->Shape().Size(); if (is_resize_) { - TArray input_shape(X_dims); - TArray output_shape(output_dims); - TArray roi_vals(roi); - TArray scales_vals(scales); - - size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); - auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); - void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); - ResizeImpl(Stream(context), mode_, (int)rank, input_shape, output_shape, - input_strides, output_div_pitches, scales_vals, roi_vals, - reinterpret_cast(X->Data()), - reinterpret_cast(Y->MutableData()), - output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), - cubic_coeff_a_, exclude_outside_, - coordinate_transform_mode_, nearest_mode_, - dims_mapping); + const bool is_same = std::all_of(scales.begin(), scales.end(), [](float v) { return v == 1.0f; }) && + (coordinate_transform_mode_ != ResizeCoordinateTransformationMode::TF_CROP_AND_RESIZE); + if (is_same) { + CUDA_CALL_THROW(cudaMemcpyAsync(Y->MutableData(), X->Data(), + output_count * sizeof(T), cudaMemcpyDeviceToDevice, Stream(context))); + return Status::OK(); + } + + if (antialias_) { + TempSpaceAllocateFunc allocate_temp_space = [&](size_t bytes_size) { + return GetScratchBuffer(bytes_size, context->GetComputeStream()); + }; + + std::optional extrapolation_value; + if (use_extrapolation_) + extrapolation_value.emplace(extrapolation_value_); + + switch (mode_) { + case UpsampleMode::LINEAR: { + if (X_dims.size() == 2 || X_dims.size() == 4) { + const bool is_2D 
= X_dims.size() == 2; + + int64_t batch_size = 1; + int64_t num_channels = 1; + + int64_t input_height; + int64_t input_width; + + int64_t output_height; + int64_t output_width; + + float height_scale; + float width_scale; + + if (is_2D) { + input_height = X_dims[0]; + input_width = X_dims[1]; + + output_height = output_dims[0]; + output_width = output_dims[1]; + + height_scale = scales[0]; + width_scale = scales[1]; + } else { + if (scales[0] == 1.0f && scales[1] == 1.0f) { + batch_size = X_dims[Channels::N]; + num_channels = X_dims[Channels::C]; + input_height = X_dims[Channels::H]; + input_width = X_dims[Channels::W]; + + output_height = output_dims[Channels::H]; + output_width = output_dims[Channels::W]; + + height_scale = scales[2]; + width_scale = scales[3]; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NHWC is not supported yet"); + } + } + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + + } else if (X_dims.size() == 3 || X_dims.size() == 5) { + const bool is_3D = X_dims.size() == 3; + + if (!is_3D) { + if (!(scales[0] == 1.0f && scales[1] == 1.0f)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", ": NDHWC is not supported yet"); + } + } + + const int64_t batch_size = is_3D ? 1 : X_dims[0]; + const int64_t num_channels = is_3D ? 1 : X_dims[1]; + const int64_t input_depth = is_3D ? X_dims[0] : X_dims[2]; + const int64_t input_height = is_3D ? X_dims[1] : X_dims[3]; + const int64_t input_width = is_3D ? X_dims[2] : X_dims[4]; + + const int64_t output_depth = is_3D ? output_dims[0] : output_dims[2]; + const int64_t output_height = is_3D ? output_dims[1] : output_dims[3]; + const int64_t output_width = is_3D ? output_dims[2] : output_dims[4]; + + const float depth_scale = is_3D ? scales[0] : scales[2]; + const float height_scale = is_3D ? scales[1] : scales[3]; + const float width_scale = is_3D ? scales[2] : scales[4]; + + ResizeAntiAliasImpl(Stream(context), + rank, + mode_, + coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(input_depth, input_height, input_width), + std::make_tuple(output_depth, output_height, output_width), + std::make_tuple(depth_scale, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Linear' mode only support 2-D inputs or 3-D inputs ('Bilinear', 'Trilinear') " + "or 4-D inputs or 5-D inputs with the corresponding outermost 2 scale values " + "being 1."); + } + } break; + case UpsampleMode::CUBIC: { + if (X_dims.size() != 2 && X_dims.size() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Resize", + ": 'Cubic' mode only support 2-D inputs ('Bicubic') or 4-D inputs " + "with the corresponding outermost 2 scale values being 1."); + } + + const bool is_2D = X_dims.size() == 2; + const bool is_nchw = is_2D ? 
true : (scales[1] == 1.0f && scales[1] == 1.0f); + + ORT_RETURN_IF_NOT(is_nchw, + "Resize 'Cubic' mode only supports NCWH layout " + " with 2-D or 4-D with leading dims equal to 1"); + + const int64_t batch_size = is_2D ? 1 : X_dims[Channels::N]; + const int64_t num_channels = is_2D ? 1 : X_dims[Channels::C]; + const int64_t input_height = is_2D ? X_dims[0] : X_dims[Channels::H]; + const int64_t input_width = is_2D ? X_dims[1] : X_dims[Channels::W]; + + const int64_t output_height = is_2D ? output_dims[0] : output_dims[Channels::H]; + const int64_t output_width = is_2D ? output_dims[1] : output_dims[Channels::W]; + const float height_scale = is_2D ? scales[0] : scales[2]; + const float width_scale = is_2D ? scales[1] : scales[3]; + + ResizeAntiAliasImpl(Stream(context), rank, mode_, coordinate_transform_mode_, + X_dims, output_dims, + batch_size, num_channels, + std::make_tuple(0, input_height, input_width), + std::make_tuple(0, output_height, output_width), + std::make_tuple(0.f, height_scale, width_scale), + output_div_pitches, + roi, + extrapolation_value, + exclude_outside_, + allocate_temp_space, + shared_lookup_table_ondevice_.get(), + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count); + } break; + default: + return Status(ONNXRUNTIME, INVALID_ARGUMENT, "Resize: unexpected mode"); + } + } else { + TArray input_shape(X_dims); + TArray output_shape(output_dims); + TArray roi_vals(roi); + TArray scales_vals(scales); + + size_t temp_buffer_size = CalcResizeBufferSize(mode_, output_dims); + auto dims_mapping_buffer = GetScratchBuffer(temp_buffer_size, context->GetComputeStream()); + void* dims_mapping = reinterpret_cast(dims_mapping_buffer.get()); + ResizeImpl(Stream(context), mode_, rank, input_shape, output_shape, + input_strides, output_div_pitches, scales_vals, roi_vals, + reinterpret_cast(X->Data()), + reinterpret_cast(Y->MutableData()), + output_count, use_extrapolation_, ToCudaType::FromFloat(extrapolation_value_), + cubic_coeff_a_, exclude_outside_, + coordinate_transform_mode_, nearest_mode_, + dims_mapping); + } } else { TArray scales_div(rank); @@ -124,7 +312,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { auto input_dims = X->Shape().GetDims(); TensorShapeVector output_dims(input_dims.size()); - std::vector roi_array(input_dims.size() * 2, 0.0f); + InlinedVector roi_array(input_dims.size() * 2, 0.0f); if (!roi_cached_) { bool use_default_roi = true; if (need_roi_input_) { @@ -147,29 +335,37 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { } } - const std::vector& roi = roi_cached_ ? 
roi_ : roi_array; - std::vector scales_array = scales_; + ComputeROIWithAxes(roi_array, input_dims.size()); + InlinedVector scales_array(input_dims.size()); + // opset < 10 if (OpKernel::Node().InputDefs().size() == 1) { - // Compute output shape from scales and input dims + // Compute output shape from scales attributes and input dims + scales_array = scales_; + ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_, output_dims); } const Tensor* scales = context->Input(scales_input_idx_); const Tensor* sizes = context->Input(sizes_input_idx_); + // This is when scales are obtained and cached from a constant initializer if (scales_cached_) { - ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + ORT_RETURN_IF_NOT(sizes == nullptr, "Only one of scales or sizes must be provided as input."); + scales_array = scales_; + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); - return BaseCompute(context, roi, scales_, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } - scales_array.resize((input_dims.size())); + // Scales and sizes are input to the node if (scales != nullptr && scales->Shape().Size() != 0) { // use scales input data ORT_ENFORCE(sizes == nullptr, "Only one of scales or sizes must be provided as input."); ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, input_dims.size())); + + // Compute output shape from scales and input dims ComputeOutputShape(scales_array, input_dims, output_dims); } else { // When sizes input is available directly populate it into the output_dims array. @@ -179,7 +375,7 @@ Status Upsample::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(ParseScalesDataAndAdjustOutputSize(output_dims, input_dims, scales_array)); } - return BaseCompute(context, roi, scales_array, output_dims); + return BaseCompute(context, roi_array, scales_array, output_dims); } } // namespace cuda diff --git a/onnxruntime/core/providers/cuda/tensor/upsample.h b/onnxruntime/core/providers/cuda/tensor/upsample.h index 7bf2a23ede399..50597e0fba1b9 100644 --- a/onnxruntime/core/providers/cuda/tensor/upsample.h +++ b/onnxruntime/core/providers/cuda/tensor/upsample.h @@ -13,12 +13,14 @@ namespace cuda { template class Upsample : public UpsampleBase, public CudaKernel { public: - Upsample(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { - } + explicit Upsample(const OpKernelInfo& info); Status ComputeInternal(OpKernelContext* context) const override; - Status BaseCompute(OpKernelContext* context, const std::vector& roi, const std::vector& scales, - const gsl::span& output_dims) const; + Status BaseCompute(OpKernelContext* context, gsl::span roi, gsl::span scales, + gsl::span output_dims) const; + + private: + IAllocatorUniquePtr shared_lookup_table_ondevice_; }; } // namespace cuda diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 3fd5423681b81..0265c06b9a938 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1145,11 +1145,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ReduceSumSquare); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int64_t, GatherND); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Dropout); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, int32_t, Resize); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, uint8_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, int32_t, Resize); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, uint8_t, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, If); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 18, Loop); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Flatten); @@ -1304,6 +1304,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, int32_t, Resize); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, uint8_t, Resize); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, Split); // Opset 19 @@ -2081,11 +2086,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2240,6 +2250,16 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, // Opset 19 diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc 
b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index da17135878fe5..7b73ab36b3742 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -24,6 +24,7 @@ #include "core/providers/cpu/tensor/size.h" #include "core/providers/cpu/tensor/scatter_nd.h" #include "core/providers/cpu/tensor/unsqueeze.h" +#include "core/providers/cpu/tensor/upsamplebase.h" #include "core/providers/cpu/tensor/tile.h" #ifndef DISABLE_CONTRIB_OPS @@ -572,6 +573,11 @@ std::unique_ptr> EinsumTypedComputeProcessor template <> std::unique_ptr> EinsumTypedComputeProcessor::Create(OpKernelContext* context, AllocatorPtr allocator, concurrency::ThreadPool* tp, EinsumComputePreprocessor& einsum_compute_preprocessor, void* einsum_cuda_assets) { return g_host_cpu.EinsumTypedComputeProcessor_MLFloat16__Create(context, allocator, tp, einsum_compute_preprocessor, einsum_cuda_assets); } +void UpsampleBase::AdjustOutputSizeAsPolicy(TensorShapeVector& output_dims, gsl::span input_dims, + InlinedVector& scales) const { + g_host_cpu.UpsampleBase__AdjustOutputSizeAsPolicy(this, output_dims, input_dims, scales); +} + #ifndef DISABLE_CONTRIB_OPS namespace contrib { Status embed_layer_norm::CheckInputs(const OpKernelContext* context, bool quantizedVersion) { @@ -648,7 +654,6 @@ Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, c const SessionState& subgraph_session_state) { return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } - } // namespace transformers #ifdef ENABLE_ATEN diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 0c9e2e9fc17a2..09666c8039402 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -288,7 +288,7 @@ Status Resize::Compute(OpKernelContext* ctx) const { // Get scales data const auto* scales = ctx->Input(scales_input_idx_); - std::vector scales_array(X->Shape().GetDims().size()); + InlinedVector scales_array(X->Shape().GetDims().size()); if (scales != nullptr && scales->Shape().Size() != 0) { ORT_RETURN_IF_ERROR(ParseScalesData(scales, scales_array, output_shape.size())); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 10f02349a24d5..1d31f3fdb4eb4 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -11,7 +11,8 @@ namespace test { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.20000028610229492, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] " + << "is 0.20000028610229492, which exceeds threshold"; } OpTester test("Resize", 13); @@ -32,7 +33,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_tf_crop_and_resize) { test.AddInput("X", {H, W}, X); test.AddInput("roi", {4}, roi); - test.AddInput("", {0}, scales); // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them + // opset13 requires either 'sizes' or 'scales' must be provided, but not both of them + 
test.AddInput("", {0}, scales); test.AddInput("sizes", {2}, sizes); std::vector Y = {7.600004f, 7.9f, 8.2f, @@ -188,7 +190,9 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + test.Run( + OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -317,7 +321,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { // The output size is [1,1,2,4].*[1,1,0.6,0.6]=[1,1,1,2] // NNAPI will recaluclate the scales as the output size divided by input size // scales = [1,1,1,2]./[1,1,2,4] = [1,1,0.5,0.5] -// See, https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h +// See:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/internal/reference/reference_ops.h // So the result of the above example will be different than CPU EP // Add the following 2 tests to test with scales valid to NNAPI TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear1) { @@ -475,7 +479,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_pytorch_half_pixel) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << " The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; } OpTester test("Resize", 13); @@ -533,7 +538,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -721,7 +727,8 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_align_corners) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_3DTrilinear_pytorch_half_pixel) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.5000001192092896, which exceeds threshold"; } OpTester test("Resize", 13); @@ -1088,7 +1095,8 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference 
between expected[i] and output[i] is 3, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 3, which exceeds threshold"; } OpTester test("Resize", 12); // tf_half_pixel_for_nn is deprecated since opset 13 @@ -1480,7 +1488,8 @@ TEST(ResizeOpTest, ResizeOpCubicUpSampleTest_tf_half_pixel_for_nn) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1505,7 +1514,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 1.6666665077209473, which exceeds threshold "; } OpTester test("Resize", 10); @@ -1530,7 +1540,8 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_2DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1565,7 +1576,8 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_Ver10) { TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_2DBilinear_Ver10) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; + GTEST_SKIP() << "Skipping because of the following error: " + << "The difference between expected[i] and output[i] is 0.5, which exceeds threshold"; } OpTester test("Resize", 10); @@ -1676,7 +1688,8 @@ TEST(UpsampleOpTest, ResizeOpNearestNoScaleTest_Ver10) { TEST(ResizeOpTest, ResizeOp_MissingRoiAndMissingScalesOptionalInputs) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect."; + GTEST_SKIP() << "Skipping because of the following error: " + << "MLOperatorAuthorImpl.cpp(1876): The parameter is incorrect."; } OpTester test("Resize", 13); @@ -1827,7 +1840,8 @@ template void TestAntialiasing(std::map attributes, std::vector input_shape, std::vector input_data, - std::vector output_shape_or_scale, std::vector output_data) { + std::vector output_shape_or_scale, std::vector output_data, + gsl::span excluded_ep = {}) { auto parse_attr = [](const std::string& str, auto typed_v) { using Tdata = decltype(typed_v); 
std::vector vect; @@ -1891,13 +1905,22 @@ void TestAntialiasing(std::map attributes, } test.AddOutput("Y", output_shape, output_data); - // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + + std::unordered_set excluded_eps; + std::transform(excluded_ep.begin(), excluded_ep.end(), + std::inserter(excluded_eps, excluded_eps.end()), [](std::string_view ep) { + return std::string(ep); + }); + // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accuracy issue. + excluded_eps.insert(kTensorrtExecutionProvider); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_eps); } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias " + << "is slightly different and doesn't match in all cases."; } std::vector X(16); std::iota(X.begin(), X.end(), 1.f); @@ -1939,7 +1962,8 @@ TEST(ResizeOpTest, Antialias_Bilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y); + InlinedVector excluded_eps = {kCudaExecutionProvider}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 1, 4, 4}, X, {1, 1, 3, 3}, Y, excluded_eps); } { std::vector X(16); @@ -1982,17 +2006,21 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear) { 33.5f, 73.5f, 113.5f, 35.074074f, 75.07407f, 115.07407f, 36.590908f, 76.59091f, 116.59091f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y); + + // Nchw is not supported by CUDA Resize implementation + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 5, 8, 3}, X, {1, 4, 5, 3}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; { std::vector X(16); std::iota(X.begin(), X.end(), uint8_t(0)); std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } { std::vector X(16); @@ -2000,7 +2028,7 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } { std::vector X(16); @@ -2008,13 +2036,14 @@ TEST(ResizeOpTest, Antialias_NhwcBilinear_dtype) { std::vector Y = {1, 3, 4, 6, 8, 9, 11, 13, 14}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {1, 4, 4, 1}, X, {1, 3, 3, 1}, Y, excluded_eps); } } TEST(ResizeOpTest, Antialias_Trilinear_No_ExcludeOutside) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all 
cases."; + GTEST_SKIP() << "Skipping because dml implementation of " + << "antialias is slightly different and doesn't match in all cases."; } std::vector X(16 * 4); std::iota(X.begin(), X.end(), 0.f); @@ -2038,13 +2067,17 @@ TEST(ResizeOpTest, Antialias_Trilinear_ExcludeOutside) { TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias" + << " is slightly different and doesn't match in all cases."; } + + InlinedVector excluded_eps = {kCudaExecutionProvider}; std::vector X(16 * 4 * 4); std::iota(X.begin(), X.end(), 0.f); { std::vector Y = X; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 4}, Y, + excluded_eps); } { std::vector Y = {0.625f, 2.375f, 4.625f, 6.375f, 8.625f, 10.375f, 12.625f, @@ -2066,7 +2099,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { 224.625f, 226.375f, 228.625f, 230.375f, 232.625f, 234.375f, 236.625f, 238.375f, 240.625f, 242.375f, 244.625f, 246.375f, 248.625f, 250.375f, 252.625f, 254.375f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 4, 2}, Y, + excluded_eps); } { std::vector Y = {2.5f, 3.5f, 4.5f, 5.5f, 9.5f, 10.5f, 11.5f, 12.5f, 18.5f, @@ -2084,7 +2118,8 @@ TEST(ResizeOpTest, Antialias_Trilinear_Scale_Is_11s_and_1s1) { 217.5f, 218.5f, 219.5f, 220.5f, 226.5f, 227.5f, 228.5f, 229.5f, 233.5f, 234.5f, 235.5f, 236.5f, 242.5f, 243.5f, 244.5f, 245.5f, 249.5f, 250.5f, 251.5f, 252.5f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y); + TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "0"}}, {4, 1, 4, 4, 4}, X, {4, 1, 4, 2, 4}, Y, + excluded_eps); } } @@ -2124,12 +2159,15 @@ TEST(ResizeOpTest, Antialias_NHWCBicubic_ExcludeOutside) { 19.576872f, 43.57687f, 21.126253f, 45.126255f, 22.606192f, 46.606194f, 19.878183f, 43.87818f, 21.358122f, 45.35812f, 22.907503f, 46.907505f, 24.387442f, 48.387444f}; - TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y); + + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; + TestAntialiasing({{"mode", "cubic"}, {"exclude_outside", "0"}}, {1, 4, 6, 2}, X, {1, 8, 4, 2}, Y, excluded_eps); } TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly different and doesn't match in all cases."; + GTEST_SKIP() << "Skipping because dml implementation of antialias" + << "is slightly different and doesn't match in all cases."; } std::vector X(256); std::iota(X.begin(), X.end(), 0.0f); @@ -2145,9 +2183,40 @@ TEST(ResizeOpTest, Antialias_Linear_AlignCorners) { 187.08333f, 195.91667f, 198.41667f, 205.91667f, 208.41667f, 217.25f, 219.75f, 227.25f, 229.75f, 238.58333f, 241.08333f, 248.58333f, 251.08333f}; + InlinedVector excluded_eps = {kCudaExecutionProvider, kRocmExecutionProvider}; TestAntialiasing( {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, - {4, 1, 4, 
4, 4}, X, {4, 1, 3, 2, 2}, Y); + {4, 1, 4, 4, 4}, X, {4, 1, 3, 2, 2}, Y, excluded_eps); +} + +TEST(ResizeOpTest, Antialias_Linear_AlignCorners_3D) { + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because dml implementation of antialias is slightly " + << "different and doesn't match in all cases."; + } + std::vector X(256); + std::iota(X.begin(), X.end(), 0.0f); + std::vector Y{ + 1.25f, 3.75f, 11.25f, 13.75f, + 17.25f, 19.75f, 27.25f, 29.75f, + 33.25f, 35.75f, 43.25f, 45.75f, + 49.25f, 51.75f, 59.25f, 61.75f, + 65.25f, 67.75f, 75.25f, 77.75f, + 81.25f, 83.75f, 91.25f, 93.75f, + 97.25f, 99.75f, 107.25f, 109.75f, + 113.25f, 115.75f, 123.25f, 125.75f, + 129.25f, 131.75f, 139.25f, 141.75f, + 145.25f, 147.75f, 155.25f, 157.75f, + 161.25f, 163.75f, 171.25f, 173.75f, + 177.25f, 179.75f, 187.25f, 189.75f, + 193.25f, 195.75f, 203.25f, 205.75f, + 209.25f, 211.75f, 219.25f, 221.75f, + 225.25f, 227.75f, 235.25f, 237.75f, + 241.25f, 243.75f, 251.25f, 253.75f}; + + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "0"}, {"coordinate_transformation_mode", "align_corners"}}, + {16, 4, 4}, X, {16, 2, 2}, Y); } TEST(ResizeOpTest, Antialias_Bicubic_ExcludeOutside) { @@ -2166,19 +2235,23 @@ TEST(ResizeOpTest, Antialias_Bicubic_Dtype) { std::vector X(36); std::iota(X.begin(), X.end(), uint8_t(0)); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y); } { std::vector X(36); std::iota(X.begin(), X.end(), int8_t(0)); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + InlinedVector excluded_eps = {kCudaExecutionProvider}; + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y, excluded_eps); } { std::vector X(36); std::iota(X.begin(), X.end(), 0); std::vector Y = {4, 6, 7, 16, 18, 19, 28, 30, 31}; - TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, X, {1, 1, 3, 3}, Y); + TestAntialiasing({{"mode", "cubic"}, {"cubic_coeff_a", "-0.5f"}, {"exclude_outside", "1"}}, {1, 1, 6, 6}, + X, {1, 1, 3, 3}, Y); } } @@ -2189,8 +2262,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Scale) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, - std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, + {1, 1, 4, 4, 4}, X, + std::vector{3 / 4.0f, 3 / 4.0f, 3 / 4.0f}, Y); } TEST(ResizeOpTest, Antialias_Axes_and_Size) { @@ -2199,8 +2274,10 @@ TEST(ResizeOpTest, Antialias_Axes_and_Size) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, 
{"output_shape", "{1,1,3,3,3}"}}, {1, 1, 4, 4, 4}, X, - {3, 3, 3}, Y); + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}}, + {1, 1, 4, 4, 4}, X, + {3, 3, 3}, Y); } TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { @@ -2209,9 +2286,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoLarger) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_larger"}}, - {1, 1, 4, 4, 4}, X, - {3, 4, 5}, Y); + // clang-format off + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, + {"policy", "not_larger"}}, + {1, 1, 4, 4, 4}, X, + {3, 4, 5}, Y); + // clang-format on } TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { @@ -2220,9 +2301,13 @@ TEST(ResizeOpTest, Antialias_Axes_and_PolicyNoSmaller) { std::vector Y = {6.3f, 7.5f, 8.7f, 11.1f, 12.3f, 13.5f, 15.9f, 17.1f, 18.3f, 25.5f, 26.7f, 27.9f, 30.3f, 31.5f, 32.7f, 35.1f, 36.3f, 37.5f, 44.7f, 45.9f, 47.1f, 49.5f, 50.7f, 51.9f, 54.3f, 55.5f, 56.7f}; - TestAntialiasing({{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, {"policy", "not_smaller"}}, - {1, 1, 4, 4, 4}, X, - {1, 2, 3}, Y); + // clang-format off + TestAntialiasing( + {{"mode", "linear"}, {"exclude_outside", "1"}, {"axes", "{2,3,4}"}, {"output_shape", "{1,1,3,3,3}"}, + {"policy", "not_smaller"}}, + {1, 1, 4, 4, 4}, X, + {1, 2, 3}, Y); + // clang-format on } TEST(ResizeOpTest, Antialias_Use_Extrapolation) { From 2a857d9a86ca3049829256df3347521069ccd6b4 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Fri, 1 Mar 2024 10:23:29 +1000 Subject: [PATCH 177/207] Add ML Program support for more operators (#19527) ### Description Add support for: - Clip/Relu/Relu6 - Add/Mul/Div/Sub/Pow - GlobalAveragePool/GlobalMaxPool/AveragePool/MaxPool - Reshape - Gemm/MatMul Fix some build issues/warnings from changes. Fix a couple of potential issues with the Resize op as well (noticed due to change to reject inputs with empty data at a higher level). 
### Motivation and Context Enable mobilenetv2 with ML Program --- cmake/onnxruntime_providers_coreml.cmake | 2 +- .../providers/coreml/builders/coreml_spec.h | 7 +- .../core/providers/coreml/builders/helper.cc | 14 +- .../coreml/builders/impl/base_op_builder.cc | 13 +- .../coreml/builders/impl/base_op_builder.h | 6 +- .../coreml/builders/impl/binary_op_builder.cc | 113 +++--- .../coreml/builders/impl/builder_utils.cc | 68 ++++ .../coreml/builders/impl/builder_utils.h | 17 +- .../coreml/builders/impl/clip_op_builder.cc | 187 ++++++--- .../coreml/builders/impl/conv_op_builder.cc | 94 +---- .../coreml/builders/impl/gemm_op_builder.cc | 332 +++++++++++----- .../coreml/builders/impl/pool_op_builder.cc | 218 +++++++---- .../builders/impl/reshape_op_builder.cc | 70 ++-- .../coreml/builders/impl/resize_op_builder.cc | 16 +- .../coreml/builders/impl/slice_op_builder.cc | 2 +- .../builders/impl/softmax_op_builder.cc | 4 +- .../coreml/builders/model_builder.cc | 366 +++++++++++++----- .../providers/coreml/builders/model_builder.h | 63 +-- .../coreml/coreml_execution_provider.cc | 82 ++-- .../providers/coreml/dump_mlprogram_model.py | 27 ++ .../core/providers/coreml/model/host_utils.h | 6 + .../core/providers/coreml/model/host_utils.mm | 10 + .../core/providers/coreml/model/model.h | 19 +- .../core/providers/coreml/model/model.mm | 13 + .../core/providers/coreml/model/model_stub.cc | 4 + .../providers/cpu/tensor/reshape_helper.h | 6 +- .../test/perftest/command_args_parser.cc | 25 +- onnxruntime/test/perftest/ort_test_session.cc | 30 +- .../providers/coreml/coreml_basic_test.cc | 20 + .../test/providers/cpu/math/clip_test.cc | 27 +- .../test/providers/cpu/math/gemm_test.cc | 37 +- .../providers/cpu/nn/batch_norm_op_test.cc | 37 ++ .../providers/cpu/tensor/resize_op_test.cc | 4 +- 33 files changed, 1344 insertions(+), 595 deletions(-) create mode 100644 onnxruntime/core/providers/coreml/dump_mlprogram_model.py diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index c9f35e5337f9b..8f3b1828e1c61 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -111,7 +111,7 @@ if(_enable_ML_PROGRAM) file(GLOB onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" - "${coremltools_SOURCE_DIR}/modelpackage/src/Utils/JsonMap.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/utils/JsonMap.?pp" ) set(coremltools_srcs diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index c9adba9e579d0..9448f1167990e 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -17,14 +17,19 @@ #ifdef HAS_SHORTEN_64_TO_32 #pragma GCC diagnostic ignored "-Wshorten-64-to-32" #endif +#elif defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4244) // conversion from long to int #endif // Model.pb.h is generated in the build output directory from the CoreML protobuf files in -// onnxruntime/core/providers/coreml/coremltools/mlmodel/format +// /_deps/coremltools-src/mlmodel/format #include "coreml_proto/Model.pb.h" #if defined(__GNUC__) #pragma GCC diagnostic pop +#elif defined(_MSC_VER) +#pragma warning(pop) #endif namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index 
bc3ba4432e66d..b8ebbd05a2a20 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -85,9 +85,15 @@ bool IsInputSupported(const Node& node, const NodeArg& input, } if (dim == 0) { - LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name - << ", shape: " << Shape2String(shape); - return false; + if (node.OpType() == "Resize" && &input == node.InputDefs()[1]) { + // one special case. Resize 'roi' input was originally a required input but is rarely used. + // ROI is not supported in the CoreML implementation so we will ignore the value, but is often added + // (at least in the unit tests) as an initializer with shape {0}. + } else { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; + } } } @@ -125,7 +131,7 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe bool CheckIsConstantInitializer(const NodeArg& node_arg, const GraphViewer& graph_viewer, const logging::Logger& logger, std::string_view input_description) { - if (graph_viewer.GetConstantInitializer(node_arg.Name(), true) == nullptr) { + if (graph_viewer.GetConstantInitializer(node_arg.Name()) == nullptr) { LOGS(logger, VERBOSE) << input_description << " (NodeArg name: '" << node_arg.Name() << "') is not a constant initializer tensor"; return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 2570e6d88ae0d..83a572f4b60fa 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -83,9 +83,14 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputPar } /* static */ -bool BaseOpBuilder::IsInput0Supported(const Node& node, const OpBuilderInputParams& /*input_params*/, - const logging::Logger& logger) { - const auto& input = *node.InputDefs()[0]; +bool BaseOpBuilder::IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) { + if (idx >= node.InputDefs().size()) { + LOGS(logger, VERBOSE) << "Input index [" << idx << "] is out of range"; + return false; + } + + const auto& input = *node.InputDefs()[idx]; int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; @@ -102,7 +107,7 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInpu const logging::Logger& logger) const { // We only check the type of input 0 by default // specific op builder can override this - return IsInput0Supported(node, input_params, logger); + return IsInputFloat(node, 0, input_params, logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index 06c4dd94ea30d..63f0b813d654c 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -28,9 +28,9 @@ class BaseOpBuilder : public IOpBuilder { void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: - // check if the first input's data type is supported. 
- static bool IsInput0Supported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger); + // currently we only support float + static bool IsInputFloat(const Node& node, size_t idx, const OpBuilderInputParams& input_params, + const logging::Logger& logger); private: virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 6074fba1433d9..fb8e07633621f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -5,6 +5,7 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -19,6 +20,8 @@ class BinaryOpBuilder : public BaseOpBuilder { bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; namespace { @@ -57,38 +60,72 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = model_builder.CreateNNLayer(node); - - if (op_type == "Add") { - // original mutable_add() has limited broadcasting support - // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support - if (CheckIfBothInputShapesMatch(node, logger)) { - layer->mutable_add(); +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#module-coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_binary + std::string_view coreml_op_type; + if (op_type == "Add") { + coreml_op_type = "add"; + } else if (op_type == "Mul") { + coreml_op_type = "mul"; + } else if (op_type == "Sub") { + coreml_op_type = "sub"; + } else if (op_type == "Div") { + // we only support fp32 currently. 
when we add support for integers we need to check the type and use + // "floor_div" or "real_div" accordingly + coreml_op_type = "real_div"; + } else if (op_type == "Pow") { + coreml_op_type = "pow"; } else { - layer->mutable_addbroadcastable(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); } - } else if (op_type == "Mul") { - if (CheckIfBothInputShapesMatch(node, logger)) { - layer->mutable_multiply(); + + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + AddOperationInput(*op, "x", input_defs[0]->Name()); + AddOperationInput(*op, "y", input_defs[1]->Name()); + AddOperationOutput(*op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(op)); + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + if (op_type == "Add") { + // original mutable_add() has limited broadcasting support + // updated to use CoreML::AddBroadcastableLayerParams which has more general broadcasting support + if (CheckIfBothInputShapesMatch(node, logger)) { + layer->mutable_add(); + } else { + layer->mutable_addbroadcastable(); + } + } else if (op_type == "Mul") { + if (CheckIfBothInputShapesMatch(node, logger)) { + layer->mutable_multiply(); + } else { + layer->mutable_multiplybroadcastable(); + } + } else if (op_type == "Sub") { + layer->mutable_subtractbroadcastable(); + } else if (op_type == "Div") { + layer->mutable_dividebroadcastable(); + } else if (op_type == "Pow") { + layer->mutable_powbroadcastable(); } else { - layer->mutable_multiplybroadcastable(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "BinaryOpBuilder::AddToModelBuilderImpl, unexpected op: ", op_type); } - } else if (op_type == "Sub") { - layer->mutable_subtractbroadcastable(); - } else if (op_type == "Div") { - layer->mutable_dividebroadcastable(); - } else if (op_type == "Pow") { - layer->mutable_powbroadcastable(); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BinaryOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_input()->Add() = input_defs[1]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_input()->Add() = input_defs[1]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -99,25 +136,11 @@ int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - if (node.OpType() != "Pow") { - return IsInput0Supported(node, input_params, logger); - } - - const auto& input_1 = *node.InputDefs()[0]; - const auto& input_2 = *node.InputDefs()[1]; - - // Pow we only support both inputs as fp32 for now - int32_t input_type_1; - int32_t input_type_2; - if (!GetType(input_1, input_type_1, logger) || - !GetType(input_2, input_type_2, logger)) { - return false; - } - - if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { - LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type" - << ", Input type 1: " << input_type_1 - << ", Input type 2: " << input_type_2; + // Add/Sub/Mul/Div spec says inputs 
must be of the same type. + // Pow spec says inputs can be different types. + // We only support float for all of these inputs. + if (!IsInputFloat(node, 0, input_params, logger) || + ((node.OpType() == "Pow") && !IsInputFloat(node, 1, input_params, logger))) { return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index 710f596b2a562..cbea969904ed5 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -7,6 +7,7 @@ #include "core/framework/tensorprotoutils.h" #include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" @@ -132,6 +133,7 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::spansize(); + size_t num_dims = num_pads / 2; + std::vector reordered_pads(num_pads, 0); + for (size_t i = 0; i < num_pads; ++i) { + auto cur_dim = i % num_dims; + if (i < num_dims) { // start values + reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; + } else { // end values + reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; + } + } + + AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); + + break; + } + + // fall through if explicit pads were not provided as the default value for `pads` is all zeros, + // which is the same as 'valid' padding. + [[fallthrough]]; + } + case AutoPadType::VALID: + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + + break; + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: { + const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); + AddOperationInput(op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); + + // despite what the spec says, a 'pad' input seems to be required. + // https://github.com/apple/coremltools/issues/2127 + // Provide the default value as that's what coremltools does for conv/avg_pool/max_pool. 
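      // For 2 spatial dims this passes pad = {0, 0, 0, 0}; the values are not used because pad_type is
      // "same"/"same_lower", but the input still needs to be present for the model to compile.
      // (For the "custom"/NOTSET case above, the reordering converts ONNX pads
      //  {x1_begin, x2_begin, x1_end, x2_end}, e.g. {1, 2, 3, 4},
      //  into the CoreML layout {x1_begin, x1_end, x2_begin, x2_end}, i.e. {1, 3, 2, 4}.)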
+ std::vector ignored_pads(num_spatial_dims * 2, 0); + AddOperationInput(op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); + + break; + } + } +} +#endif // defined(COREML_ENABLE_MLPROGRAM) } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 8126f0c126914..2804589065631 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -11,13 +11,15 @@ #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" - #include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { class NodeArg; namespace coreml { +class ModelBuilder; + // Try to see if we can map explicit padding to auto padding for Conv/Pool // Since usually use auto padding is more efficient Status HandleAutoPad(const std::vector input_shape, @@ -45,6 +47,7 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +#if defined(COREML_ENABLE_MLPROGRAM) // // MLProgram utils // @@ -130,5 +133,17 @@ void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, /// Operation to update. /// NodeArg with details of output to add. void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); + +/// +/// Add pad_type and pad values. +/// +/// Operator to update +/// ModelBuilder to add constants with. +/// Operator type. +/// Node attribute helper. +/// Number of spatial dims in input. Generally rank - 2 (ignore N and C dims). +void AddPadTypeAndPads(COREML_SPEC::MILSpec::Operation& op, ModelBuilder& model_builder, std::string_view op_type, + const NodeAttrHelper& helper, int num_spatial_dims); +#endif // defined(COREML_ENABLE_MLPROGRAM) } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index 9aca172abec98..41f4041ef1181 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" @@ -17,11 +18,31 @@ class ClipOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + bool skip = true; + + if (model_builder.CreateMLProgram()) { + float min, max; + ORT_IGNORE_RETURN_VALUE(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, model_builder.Logger())); + + bool has_min = min != std::numeric_limits::lowest(); + bool has_max = max != std::numeric_limits::max(); + if (has_min && has_max && min == 0.f && max == 6.f) { + // relu6 - skip both + } else if (has_min && min == 0.f && !has_max) { + // relu - skip both + } else { + // clip - we will use both + skip = false; + } + } + // Both min and max values will be injected into the layer, no need to add to the model - if (node.SinceVersion() >= 11) { + if (skip && node.SinceVersion() >= 11) { if (node.InputDefs().size() > 1) model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); @@ -35,72 +56,126 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const logging::Logger& logger) const { const auto& node_name = node.Name(); const auto& input_name = node.InputDefs()[0]->Name(); - const auto& output_name = node.OutputDefs()[0]->Name(); + const auto& output = *node.OutputDefs()[0]; + const auto& output_name = output.Name(); float min, max; ORT_RETURN_IF_NOT(GetClipMinMax(model_builder.GetGraphViewer(), node, min, max, logger), "GetClipMinMax failed"); bool has_min = min != std::numeric_limits::lowest(); bool has_max = max != std::numeric_limits::max(); - if (!has_min && !has_max) { - // Clip without min/max is an identity node - // In CoreML we don't have identity, use ActivationLinear instead - std::unique_ptr layer = model_builder.CreateNNLayer(node); - layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - - model_builder.AddLayer(std::move(layer)); - } else { - // The implementation of clip(min, max) is done by - // 1. Clipping at min -> max(input, min) is handled by - // min_output = threshold(input, min) - // 2. Clipping at max -> min(min_output, max) is handled by - // output = -1 * (threshold(-min_output, -max)) - - // Now we have at least one or min or max is not default value - // Clipping at max will need take the output of clipping at min, or the node input, if min value is default - // If max value is default, the output of clipping at min will be the output of the node - std::string min_output_name = output_name; - if (has_max) { - min_output_name = has_min - ? model_builder.GetUniqueName(node_name + "min_output") - : input_name; +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::unique_ptr op; + if (!has_min && !has_max) { + // Clip without min/max is an identity node. 
+ op = model_builder.CreateOperation(node, "identity"); + Operation& identity_op = *op; + AddOperationInput(identity_op, "x", input_name); + } else { + if (has_min && has_max && min == 0.f && max == 6.f) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu6 + op = model_builder.CreateOperation(node, "relu6"); + Operation& relu6_op = *op; + AddOperationInput(relu6_op, "x", input_name); + } else if (has_min && min == 0.f && !has_max) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.activation.relu + op = model_builder.CreateOperation(node, "relu"); + Operation& relu_op = *op; + AddOperationInput(relu_op, "x", input_name); + } else { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.elementwise_unary.clip + op = model_builder.CreateOperation(node, "clip"); + + Operation& clip_op = *op; + AddOperationInput(clip_op, "x", input_name); + + // if min and max were attributes we need to add initializers. otherwise we use the existing inputs + const bool min_max_attribs = node.SinceVersion() < 11; + std::string_view min_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "min", min) + : node.InputDefs()[1]->Name(); + + AddOperationInput(clip_op, "alpha", min_name); + + if (has_max) { + std::string_view max_name = min_max_attribs ? model_builder.AddScalarConstant(clip_op.type(), "max", max) + : node.InputDefs()[2]->Name(); + AddOperationInput(clip_op, "beta", max_name); + } + } } - // Handle clipping at min first - if (has_min) { - std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); - if (min == 0.0f) { // If min is 0. then this min will be handled by relu - min_layer->mutable_activation()->mutable_relu(); - } else { // otherwise, min will be handled by unary->threshold - min_layer->mutable_unary()->set_alpha(min); - min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + AddOperationOutput(*op, output); + model_builder.AddOperation(std::move(op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + // TODO: CoreML has a Clip layer for NeuralNetwork. Added in CoreML 4. We could potentially use that if available + // to simplify. + // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#cliplayerparams + + if (!has_min && !has_max) { + // Clip without min/max is an identity node + // In CoreML we don't have identity, use ActivationLinear instead + std::unique_ptr layer = model_builder.CreateNNLayer(node); + layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + + model_builder.AddLayer(std::move(layer)); + } else { + // The implementation of clip(min, max) is done by + // 1. Clipping at min -> max(input, min) is handled by + // min_output = threshold(input, min) + // 2. 
Clipping at max -> min(min_output, max) is handled by + // output = -1 * (threshold(-min_output, -max)) + + // Now we have at least one or min or max is not default value + // Clipping at max will need take the output of clipping at min, or the node input, if min value is default + // If max value is default, the output of clipping at min will be the output of the node + std::string min_output_name = output_name; + if (has_max) { + min_output_name = has_min + ? model_builder.GetUniqueName(node_name + "min_output") + : input_name; } - *min_layer->mutable_input()->Add() = input_name; - *min_layer->mutable_output()->Add() = min_output_name; - model_builder.AddLayer(std::move(min_layer)); - } - - // Clipping at max is handled by -1 * (threshold (-min_output, -max)) - if (has_max) { - const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); - { // Add threshold layer, which is actually max( -1 * min_output, -max) - auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); - threshold_layer->mutable_unary()->set_alpha(-max); - threshold_layer->mutable_unary()->set_scale(-1.0f); - threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); - *threshold_layer->mutable_input()->Add() = min_output_name; - *threshold_layer->mutable_output()->Add() = threshold_output_name; - model_builder.AddLayer(std::move(threshold_layer)); + // Handle clipping at min first + if (has_min) { + std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); + if (min == 0.0f) { // If min is 0. then this min will be handled by relu + min_layer->mutable_activation()->mutable_relu(); + } else { // otherwise, min will be handled by unary->threshold + min_layer->mutable_unary()->set_alpha(min); + min_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + } + + *min_layer->mutable_input()->Add() = input_name; + *min_layer->mutable_output()->Add() = min_output_name; + model_builder.AddLayer(std::move(min_layer)); } - { // Add linear activation layer -1 * threshold_output - auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); - linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); - *linear_layer->mutable_input()->Add() = threshold_output_name; - *linear_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(linear_layer)); + + // Clipping at max is handled by -1 * (threshold (-min_output, -max)) + if (has_max) { + const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); + { // Add threshold layer, which is actually max( -1 * min_output, -max) + auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); + threshold_layer->mutable_unary()->set_alpha(-max); + threshold_layer->mutable_unary()->set_scale(-1.0f); + threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); + *threshold_layer->mutable_input()->Add() = min_output_name; + *threshold_layer->mutable_output()->Add() = threshold_output_name; + model_builder.AddLayer(std::move(threshold_layer)); + } + { // Add linear activation layer -1 * threshold_output + auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); + linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); + *linear_layer->mutable_input()->Add() = threshold_output_name; + *linear_layer->mutable_output()->Add() = output_name; + 
model_builder.AddLayer(std::move(linear_layer)); + } } } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc index 05e43dbbd16af..38125957bf481 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc @@ -67,99 +67,25 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N AddOperationInput(*conv_op, "bias", input_defs[2]->Name()); } - // ONNX attributes. Add as inputs if specified/required - auto strides = helper.GetInt64s("strides"); - auto dilations = helper.GetInt64s("dilations"); - auto groups = helper.GetInt64("group"); - // we know this input has a valid shape due to the check in IsOpSupportedImpl. ignore N and C dims. const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; const auto& op_type = conv_op->type(); - if (strides) { - AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", *strides)); - } else { - // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) - static const auto default_value = std::vector(num_spatial_dims, 1); - AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", default_value)); - } + // Spec says strides and dilations are optional, but reality is they're required for at least the iOS15 target + // (CoreML5). + const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + auto dilations = helper.Get("dilations", std::vector(num_spatial_dims, 1)); + auto groups = helper.GetInt64("group"); - if (dilations) { - AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", *dilations)); - } else { - // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) - static const auto default_value = std::vector(num_spatial_dims, 1); - AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", default_value)); - } + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", strides)); + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", dilations)); if (groups) { AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); } - AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); - - // pad type (string) - // valid - no pads (ONNX auto_pad VALID) - // custom - pads input (ONNX NOTSET) - // same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])` (assuming == ONNX SAME_UPPER) - // same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER) - // - // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value - // can be used. TBD if that provides any performance benefit with ML Program though as CoreML could - // potentially do that for us. - switch (auto_pad_type) { - case AutoPadType::NOTSET: { - // use `pads` attribute. - auto onnx_pads = helper.GetInt64s("pads"); // 'pads' must be provided if auto_pad is NOTSET - if (onnx_pads) { - AddOperationInput(*conv_op, "pad_type", - model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom"))); - - // need to re-order from x1_start, x2_start..., x1_end, x2_end... to - // x1_start, x1_end, x2_start, x2_end,... 
- size_t num_pads = onnx_pads->size(); - size_t num_dims = num_pads / 2; - std::vector reordered_pads(num_pads, 0); - for (size_t i = 0; i < num_pads; ++i) { - auto cur_dim = i % num_dims; - if (i < num_dims) { // start values - reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; - } else { // end values - reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; - } - } - - AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); - - break; - } - - // in theory the pads may not be provided and in that case the default is no padding. - // as that is the same as 'valid', fall through - [[fallthrough]]; - } - case AutoPadType::VALID: - AddOperationInput(*conv_op, "pad_type", - model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); - - break; - case AutoPadType::SAME_UPPER: - case AutoPadType::SAME_LOWER: { - const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); - AddOperationInput(*conv_op, "pad_type", - model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); - - // despite what the spec says, a 'pad' input seems to be required. - // https://github.com/apple/coremltools/issues/2127 - // provide the default value. passing in an empty vector also works. TBD what's better. - std::vector ignored_pads(num_spatial_dims * 2, 0); - AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); - - break; - } - } + AddPadTypeAndPads(*conv_op, model_builder, op_type, helper, num_spatial_dims); - // set output AddOperationOutput(*conv_op, *node.OutputDefs()[0]); model_builder.AddOperation(std::move(conv_op)); @@ -297,7 +223,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name, true); + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name); #if defined(COREML_ENABLE_MLPROGRAM) if (input_params.create_mlprogram) { @@ -324,7 +250,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name(), true)) { + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 48f77354d7c30..8daf64dc4a457 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -22,18 +22,51 @@ class GemmOpBuilder : public BaseOpBuilder { Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; - bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op = node.OpType(); 
const auto& input_defs(node.InputDefs()); - // We have already embedded the weights (matrix B and C(if any)) into the coreml layer - // No need to copy them later to reduce memory consumption - model_builder.AddInitializerToSkip(input_defs[1]->Name()); - if (op == "Gemm" && input_defs.size() > 2) { - model_builder.AddInitializerToSkip(input_defs[2]->Name()); + const bool is_gemm = op == "Gemm"; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + // we have to transpose the weight input of Gemm if transB is false, and potentially override the bias shape + if (is_gemm) { + NodeAttrHelper helper(node); + const auto transB = helper.Get("transB", 0); + if (transB == 0) { + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + } + + if (input_defs.size() > 2) { + // ONNX spec requires B to be 2D and we required it to be a constant initializer so reading N this way is safe + // B is {K, N] by default. or {N, K} if transB is true + int N_dim = transB ? 0 : 1; + int64_t N = input_defs[1]->Shape()->dim().at(N_dim).dim_value(); + + const auto& bias_name = input_defs[2]->Name(); + const auto& bias = *model_builder.GetConstantInitializer(bias_name); + if (bias.dims_size() != 1 || bias.dims(0) != N) { + // we have to override the shape/duplicate data to convert {}, {1} or {1, N} to 1D {N} + // when adding the Gemm operation so skip adding the original initializer + model_builder.AddInitializerToSkip(bias_name); + } + } + } + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + // We have already embedded the weights (matrix B and C(if any)) into the coreml layer + // No need to copy them later to reduce memory consumption + model_builder.AddInitializerToSkip(input_defs[1]->Name()); + if (is_gemm && input_defs.size() > 2) { + model_builder.AddInitializerToSkip(input_defs[2]->Name()); + } } } @@ -57,54 +90,152 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te } Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& /* logger */) const { + const logging::Logger& logger) const { std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - const auto& b_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - const auto& b_shape = b_tensor.dims(); - - auto* coreml_inner_product = layer->mutable_innerproduct(); - - // The coreml innerproduct weight (matrix B) is stored transposed - // - for MatMul and Gemm (transB = 0), the coreml weight is B' - // - for Gemm (transB = 1), the coreml weight is B - if (op_type == "MatMul") { - coreml_inner_product->set_inputchannels(b_shape[0]); - coreml_inner_product->set_outputchannels(b_shape[1]); - // Add weight (b of MatMul) - std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed)); - CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); - } else { // Gemm - NodeAttrHelper helper(node); - const auto transB = helper.Get("transB", 0); - if (transB == 0) { - coreml_inner_product->set_inputchannels(b_shape[0]); - coreml_inner_product->set_outputchannels(b_shape[1]); + const auto& a = *input_defs[0]; + const auto& b = *input_defs[1]; + const auto* b_initializer = model_builder.GetConstantInitializer(b.Name()); // MLProgram MatMul may not be constant + + const bool is_matmul = op_type == "MatMul"; + const bool is_gemm = op_type == "Gemm"; + + NodeAttrHelper helper(node); + const 
auto transB = is_gemm ? helper.Get("transB", 0) : 0; + + std::vector b_shape; + ORT_IGNORE_RETURN_VALUE(GetShape(b, b_shape, logger)); + int64_t b0 = -1, b1 = -1; + + // ML Program MatMul supports N-D input + if (model_builder.CreateMLProgram() && is_matmul) { + if (b_shape.size() == 1) { + // B is treated as {b_shape[0], 1} according to the numpy rules. + b0 = b_shape[0]; + b1 = 1; + } else { + // last 2 dims are used + b0 = b_shape[b_shape.size() - 2]; + b1 = b_shape[b_shape.size() - 1]; + } + } else { + // we only support 2D input + b0 = b_shape[0]; + b1 = b_shape[1]; + } + + // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true + const auto K = transB ? b1 : b0; + const auto N = transB ? b0 : b1; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + if (is_gemm) { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.linear + auto gemm_op = model_builder.CreateOperation(node, "linear"); + AddOperationInput(*gemm_op, "x", a.Name()); + + // CoreML takes weight input as {N, K} which is the reverse of ONNX. + // if transB is true the input weight is {N, K} so can be added directly. + if (transB) { + AddOperationInput(*gemm_op, "weight", b.Name()); + } else { + // transpose from {K, N} to {N, K} + std::vector weight_nk; + std::vector weight_nk_shape = {N, K}; + ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, weight_nk)); + + AddOperationInput(*gemm_op, "weight", + model_builder.AddConstant(gemm_op->type(), b.Name() + "_t", weight_nk, weight_nk_shape)); + } + + if (input_defs.size() == 3) { + const auto& bias_arg = *input_defs[2]; + const auto& bias = *model_builder.GetConstantInitializer(bias_arg.Name()); + + // CoreML linear op requires bias to be 1D tensor of size N + if (bias.dims_size() == 1 && bias.dims().at(0) == N) { + // can use existing initializer + AddOperationInput(*gemm_op, "bias", bias_arg.Name()); + } else { + Initializer unpacked_tensor(bias); + auto bias_data = unpacked_tensor.DataAsSpan(); + std::string_view bias_data_name; + if (bias_data.size() == 1) { + // expand scalar to N + std::vector expanded_bias_data(N, bias_data[0]); + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", expanded_bias_data); + } else { + // can use data as-is but need to adjust shape (inferred by AddConstant as {bias_data.size()}) + bias_data_name = model_builder.AddConstant(gemm_op->type(), "bias", bias_data); + } + + AddOperationInput(*gemm_op, "bias", bias_data_name); + } + } + + AddOperationOutput(*gemm_op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(gemm_op)); + } else { + // CoreML implementation is the same as ONNX MatMul. + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.linear.matmul + auto matmul_op = model_builder.CreateOperation(node, "matmul"); + AddOperationInput(*matmul_op, "x", a.Name()); + AddOperationInput(*matmul_op, "y", b.Name()); + + // once again the spec lies and says transpose_y and transpose_x are optional... 
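      // Pass an explicit false for both so the MIL matmul matches ONNX MatMul semantics (neither input
      // transposed); the same bool constant is reused for the two arguments.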
+ auto false_value_name = model_builder.AddScalarConstant(matmul_op->type(), "false", false); + AddOperationInput(*matmul_op, "transpose_x", false_value_name); + AddOperationInput(*matmul_op, "transpose_y", false_value_name); + + AddOperationOutput(*matmul_op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(matmul_op)); + } + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + auto* coreml_inner_product = layer->mutable_innerproduct(); + + *layer->mutable_input()->Add() = a.Name(); + + coreml_inner_product->set_inputchannels(K); + coreml_inner_product->set_outputchannels(N); + + // CoreML takes weight input as {N, K} which is the reverse of ONNX. + // if Gemm's transB is true the input weight is {N, K} and can be added directly. + if (transB) { + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), *b_initializer)); + } else { std::vector b_transposed; - ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(b_tensor, b_transposed)); + ORT_RETURN_IF_ERROR(GetTensorFloatDataTransposed(*b_initializer, b_transposed)); CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_transposed); - } else { - coreml_inner_product->set_inputchannels(b_shape[1]); - coreml_inner_product->set_outputchannels(b_shape[0]); - // Add weight (b of MatMul) - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_weights(), b_tensor)); } - // Add bias if present - if (input_defs.size() > 2) { + if (is_gemm && input_defs.size() > 2) { + // Add bias coreml_inner_product->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_tensor)); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + + // if scalar, or single value expand to 1D tensor of size N + // IsOpSupportedImpl enforces it's scalar, {1}, {N}, or {1, N}. 
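      // e.g. with N == 3, a scalar or {1} bias of 0.5f becomes {0.5f, 0.5f, 0.5f};
      // a bias that already has N values ({N} or {1, N}) is written out unchanged.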
+ Initializer unpacked_tensor(bias_tensor); + auto bias_data = unpacked_tensor.DataAsSpan(); + if (bias_data.size() == 1 && N > 1) { + std::vector expanded_bias_data(N, bias_data[0]); + CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), expanded_bias_data); + } else { + CreateCoreMLWeight(*coreml_inner_product->mutable_bias(), bias_data); + } } - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } @@ -112,98 +243,105 @@ bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); + const bool is_matmul = op_type == "MatMul"; + const bool is_gemm = op_type == "Gemm"; + size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, input_defs[b_idx]->Name())) { - LOGS(logger, VERBOSE) << "B of Gemm/Matmul must be an initializer tensor"; + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) { return false; } - std::vector a_shape; - { - if (!GetShape(*input_defs[a_idx], a_shape, logger)) - return false; - - if (a_shape.size() != 2) { - LOGS(logger, VERBOSE) << "A must be 2D"; - return false; - } + std::vector b_shape; + if (!GetShape(*input_defs[b_idx], b_shape, logger)) { + return false; + } - // TODO is it ok if the shape is dynamic and empty? - if (Product(a_shape) == 0) { - LOGS(logger, VERBOSE) << "A must be non-empty"; + if (!input_params.graph_viewer.GetConstantInitializer(input_defs[b_idx]->Name())) { + if (input_params.create_mlprogram && is_matmul) { + // ML Program MatMul allows non-constant B input + } else { + LOGS(logger, VERBOSE) << op_type << " B input must be a constant initializer"; return false; } } - std::vector b_shape; - { - if (!GetShape(*input_defs[b_idx], b_shape, logger)) - return false; - - if (b_shape.size() != 2) { - LOGS(logger, VERBOSE) << "B must be 2D"; - return false; - } + if (is_matmul) { + if (input_params.create_mlprogram) { + // ML Program matmul op has numpy semantics the same as the ONNX spec so we can use directly + } else { + // we could potentially support 1D and 3D if required. beyond 3D the dims that merge diverge. + // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/onnx/_operators.py#L1607 + // https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/backend/nn/op_mapping.py#L1374 + // https://apple.github.io/coremltools/mlmodel/Format/NeuralNetwork.html#innerproductlayerparams + if (a_shape.size() != 2 || b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "a and b inputs must be 2D. 
"; + return false; + } - if (Product(b_shape) == 0) { - LOGS(logger, VERBOSE) << "B must be non-empty"; - return false; + if (input_defs.size() > 2) { + LOGS(logger, VERBOSE) << "MatMul with C input is not supported"; + return false; + } } } - if (op_type == "Gemm") { + if (is_gemm) { + // A and B are 2D due to the ONNX spec NodeAttrHelper helper(node); const auto transA = helper.Get("transA", 0); const auto transB = helper.Get("transB", 0); const auto alpha = helper.Get("alpha", 1.0f); const auto beta = helper.Get("beta", 1.0f); + + // TODO: We can support transA, alpha and beta by using multiple layers/operations if needed. if (!(transA == 0 && alpha == 1.f && beta == 1.f)) { - LOGS(logger, VERBOSE) << "Only transA == 0, alpha == 1.0 " - << "and beta == 1.0 is supported." + LOGS(logger, VERBOSE) << "Only support for transA == 0, alpha == 1.0 " + << "and beta == 1.0 is currently implemented." << " transA " << transA << " alpha " << alpha << " beta " << beta; return false; } - // C of Gemm - // For now we only support {n} or {1,n} tensor if (input_defs.size() == 3) { - if (!Contains(initializers, input_defs[c_idx]->Name())) { - LOGS(logger, VERBOSE) << "C of Gemm must be an initializer tensor"; + if (!input_params.graph_viewer.GetConstantInitializer(input_defs[c_idx]->Name())) { + LOGS(logger, VERBOSE) << "C of Gemm must be a constant initializer"; return false; } std::vector c_shape; - if (!GetShape(*input_defs[c_idx], c_shape, logger)) + if (!GetShape(*input_defs[c_idx], c_shape, logger)) { return false; + } - size_t c_dim = c_shape.size(); + // B is {K, N} in ONNX spec by default, or {N, K} in Gemm if transB is true + const auto N = transB ? b_shape[0] : b_shape[1]; - if (c_dim == 0) { - LOGS(logger, VERBOSE) << "C of Gemm cannot be a scalar"; - return false; - } + size_t c_rank = c_shape.size(); - if (c_dim != 1) { - // If C is a (2+)d tensor, it must have the format {1, 1, ..., 1, n} - // where every except the last dimension should be 1 - for (size_t i = 0; i < c_dim - 1; ++i) { - if (c_shape[i] != 1) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector or a tensor with only last dimension != 1"; - return false; + // allowed: scalar, or 1D where the value is 1 or N, 2D with shape {1, N} + bool c_valid = false; + switch (c_rank) { + case 0: + c_valid = true; + break; + case 1: + if (c_shape[0] == 1 || c_shape[0] == N) { + c_valid = true; } - } + break; + case 2: + if (c_shape[0] == 1 && c_shape[1] == N) { + c_valid = true; + } + break; } - auto c_size = c_shape[c_dim - 1]; - if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) { - LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape[" - << (transB == 0 ? "1" : "0") << "]" - << " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]" - << " c_size: " << c_size; + if (!c_valid) { + LOGS(logger, VERBOSE) << "Shape of C Gemm input must be {}, {1}, {N}, or {1, N}. 
N:" << N << " C shape:" + << Shape2String(c_shape); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index 01aced739b36d..17910ba6fd486 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -19,104 +19,176 @@ class PoolOpBuilder : public BaseOpBuilder { bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + + bool SupportsMLProgram() const override { return true; } }; Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - - auto* coreml_pool = layer->mutable_pooling(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - bool is_global_pooling = false; - if (op_type == "GlobalAveragePool") { - is_global_pooling = true; - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); - } else if (op_type == "GlobalMaxPool") { - is_global_pooling = true; - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); - } else if (op_type == "AveragePool") { - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); - } else if (op_type == "MaxPool") { - coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unknown op: ", op_type); - } +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; + + std::string_view coreml_op_type; + bool is_global = false; + bool is_avg_pool = false; + if (op_type == "GlobalAveragePool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_mean + coreml_op_type = "reduce_mean"; + is_global = true; + } else if (op_type == "GlobalMaxPool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.reduction.reduce_max + coreml_op_type = "reduce_max"; + is_global = true; + } else if (op_type == "AveragePool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.avg_pool + coreml_op_type = "avg_pool"; + is_avg_pool = true; + } else if (op_type == "MaxPool") { + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.pool.max_pool + coreml_op_type = "max_pool"; + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type); + } - if (is_global_pooling) { - coreml_pool->set_globalpooling(true); - coreml_pool->mutable_valid(); - } else { // AveragePool or MaxPool - NodeAttrHelper helper(node); - const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); - const auto strides = helper.Get("strides", std::vector{1, 1}); - const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - - coreml_pool->add_kernelsize(kernel_shape[0]); - coreml_pool->add_kernelsize(kernel_shape[1]); - coreml_pool->add_stride(strides[0]); - coreml_pool->add_stride(strides[1]); - 
coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0); - coreml_pool->set_globalpooling(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1], - onnx_pads, strides, {1, 1} /* dilations */, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_pool->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + std::unique_ptr op = model_builder.CreateOperation(node, coreml_op_type); + + AddOperationInput(*op, "x", input_defs[0]->Name()); + + if (is_global) { + // keep N and C dims, reduce the rest with keepdims=True. equivalent to the ONNX Global*Pool ops. + std::vector axes{2, 3}; // we only support 4D input currently. + AddOperationInput(*op, "axes", model_builder.AddConstant(op->type(), "axes", axes)); + AddOperationInput(*op, "keep_dims", model_builder.AddScalarConstant(op->type(), "keep_dims", true)); + } else { + NodeAttrHelper helper(node); + constexpr int num_spatial_dims = 2; // we only support 4D. -2 for N and C dims. + + AddPadTypeAndPads(*op, model_builder, op->type(), helper, num_spatial_dims); + + const auto kernel_shape = helper.GetInt64s("kernel_shape"); // required + AddOperationInput(*op, "kernel_sizes", model_builder.AddConstant(op->type(), "kernel_sizes", *kernel_shape)); + + // in theory all these values are optional according to the CoreML spec but simpler to just provide default + // values as the actual model compilation tends to require them. 
+ const auto strides = helper.Get("strides", std::vector(num_spatial_dims, 1)); + const bool ceil_mode = helper.Get("ceil_mode", int64_t(0)); // convert int64_t to bool + + AddOperationInput(*op, "strides", model_builder.AddConstant(op->type(), "strides", strides)); + AddOperationInput(*op, "ceil_mode", model_builder.AddScalarConstant(op->type(), "ceil_mode", ceil_mode)); + + if (is_avg_pool) { + const bool count_exclude_pad = helper.Get("count_include_pad", int64_t(0)) == 0; + AddOperationInput(*op, "exclude_padding_from_average", + model_builder.AddScalarConstant(op->type(), "count_exclude_pad", count_exclude_pad)); } + } + + AddOperationOutput(*op, *node.OutputDefs()[0]); + model_builder.AddOperation(std::move(op)); + + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto* coreml_pool = layer->mutable_pooling(); + + bool is_global_pooling = false; + if (op_type == "GlobalAveragePool") { + is_global_pooling = true; + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); + } else if (op_type == "GlobalMaxPool") { + is_global_pooling = true; + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); + } else if (op_type == "AveragePool") { + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_AVERAGE); + } else if (op_type == "MaxPool") { + coreml_pool->set_type(COREML_SPEC::PoolingLayerParams_PoolingType_MAX); } else { - auto* padding_type = coreml_pool->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - width_border->set_endedgesize(onnx_pads[3]); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "PoolOpBuilder, unexpected op: ", op_type); + } + + if (is_global_pooling) { + coreml_pool->set_globalpooling(true); + coreml_pool->mutable_valid(); + } else { // AveragePool or MaxPool + NodeAttrHelper helper(node); + const auto kernel_shape = helper.Get("kernel_shape", std::vector{0, 0}); + const auto strides = helper.Get("strides", std::vector{1, 1}); + const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + + coreml_pool->add_kernelsize(kernel_shape[0]); + coreml_pool->add_kernelsize(kernel_shape[1]); + coreml_pool->add_stride(strides[0]); + coreml_pool->add_stride(strides[1]); + coreml_pool->set_avgpoolexcludepadding(helper.Get("count_include_pad", 0) == 0); + coreml_pool->set_globalpooling(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, kernel_shape[0], kernel_shape[1], + onnx_pads, strides, {1, 1} /* dilations */, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_pool->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is 
SAME_UPPER + padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_pool->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } } } - } - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + *layer->mutable_input()->Add() = input_defs[0]->Name(); + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } - model_builder.AddLayer(std::move(layer)); return Status::OK(); } -bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, +bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) + if (!GetShape(*input_defs[0], input_shape, logger)) { return false; + } + // TODO: ML Program supports 3D and 5D. Add if we have a use case for that. const auto input_size = input_shape.size(); if (input_size != 4) { - LOGS(logger, VERBOSE) - << op_type << " only supports rank-4 tensor, input [" - << input_defs[0]->Name() << "] has actual dim count " << input_size; + LOGS(logger, VERBOSE) << op_type << " only supports rank-4 tensor, input [" + << input_defs[0]->Name() << "] has actual dim count " << input_size; return false; } if (op_type == "AveragePool" || op_type == "MaxPool") { NodeAttrHelper helper(node); + const auto storage_order = helper.Get("storage_order", 0); if (storage_order == 1) { LOGS(logger, VERBOSE) << "storage_order == 1 is not supported"; @@ -128,12 +200,14 @@ bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - // TODO, add support of the ceil_mode by adjusting the padding - // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode - // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644 - if (helper.Get("ceil_mode", 0) == 1) { - LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling"; - return false; + if (!input_params.create_mlprogram) { + // TODO, add support of the ceil_mode by adjusting the padding + // See https://stackoverflow.com/questions/59906456/in-pytorchs-maxpool2d-is-padding-added-depending-on-ceil-mode + // and https://github.com/apple/coremltools/blob/1931758aae383c83daddfc56f11a24a9d2bf4b87/coremltools/converters/mil/frontend/torch/ops.py#L621-L644 + if (helper.Get("ceil_mode", 0) == 1) { + LOGS(logger, VERBOSE) << "ceil_mode == 1 is not supported for pooling"; + return false; + } } if (helper.Get("dilations", std::vector{1, 1}) != diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index 
7ae1746be3122..27d24d9c21893 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -1,11 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" -#include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" @@ -26,34 +25,56 @@ class ReshapeOpBuilder : public BaseOpBuilder { // Reshape opset 4- uses attributes for new shape which we do not support for now int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } + + bool SupportsMLProgram() const override { return true; } }; void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // Skip the second input which is the new shape as we always have to create a new version as the CoreML rules + // are different from ONNX. model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = model_builder.CreateNNLayer(node); - const auto& input_defs = node.InputDefs(); - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name()); - const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() - ? 
reinterpret_cast(target_shape_tensor.raw_data().data()) - : target_shape_tensor.int64_data().data(); - - const auto size = target_shape_tensor.dims()[0]; - TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; std::vector input_shape; - ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - ReshapeHelper helper(TensorShape(input_shape), target_shape); - *layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; - *layer->mutable_input()->Add() = input_defs[0]->Name(); - *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + ORT_RETURN_IF_NOT(GetStaticShape(*input_defs[0], input_shape, logger), "Cannot get shape of data"); + + const auto& data_name = input_defs[0]->Name(); + const auto& new_shape_name = input_defs[1]->Name(); + Initializer unpacked_tensor(*model_builder.GetConstantInitializer(new_shape_name)); + TensorShapeVector new_shape = ToShapeVector(unpacked_tensor.DataAsSpan()); + + // ReshapeHelper applies the ONNX rules to create the concrete output shape + ReshapeHelper helper(TensorShape(input_shape), new_shape); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - model_builder.AddLayer(std::move(layer)); + // https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html#coremltools.converters.mil.mil.ops.defs.iOS15.tensor_transformation.reshape + std::unique_ptr reshape_op = model_builder.CreateOperation(node, "reshape"); + + AddOperationInput(*reshape_op, "x", data_name); + AddOperationInput(*reshape_op, "shape", + model_builder.AddConstant(reshape_op->type(), "shape", ToConstSpan(new_shape))); + + AddOperationOutput(*reshape_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(reshape_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + *layer->mutable_reshapestatic()->mutable_targetshape() = {new_shape.cbegin(), new_shape.cend()}; + *layer->mutable_input()->Add() = data_name; + *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); + + model_builder.AddLayer(std::move(layer)); + } return Status::OK(); } @@ -61,14 +82,15 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& new_shape_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (!Contains(initializers, new_shape_name)) { + const auto* new_shape_tensor = input_params.graph_viewer.GetConstantInitializer(new_shape_name); + if (!new_shape_tensor) { + // ONNX has different rules around how -1 and 0 values are used/combined, and + // we can't check if those can be translated to CoreML if the shape is unknown. 
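+    // e.g. (illustrative) for an input of shape {2, 3, 4}, an ONNX new shape of {0, -1} resolves to {2, 12}:
+    // 0 copies the corresponding input dim (unless allowzero=1) and the single -1 is inferred from the
+    // remaining element count.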
LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; return false; } - const auto& new_shape_tensor = *initializers.at(new_shape_name); - Initializer unpacked_tensor(new_shape_tensor); + Initializer unpacked_tensor(*new_shape_tensor); auto new_shape = unpacked_tensor.DataAsSpan(); if (new_shape.empty()) { LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; @@ -84,7 +106,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP return false; } - // CoreML reshape doesn't support new shape with more than 5 dimensions + // CoreML reshape doesn't support new shape with more than 5 dimensions. if (new_shape.size() > 5) { LOGS(logger, VERBOSE) << "Reshape does not support new shape with rank greater than 5. Input shape: " << Shape2String(input_shape) << ", new shape: " << Shape2String(new_shape); @@ -93,7 +115,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputP // CoreML reshape does not support 0 as dimension NodeAttrHelper helper(node); - const bool allow_zero = helper.Get("allowzero ", 0) == 1; + const bool allow_zero = helper.Get("allowzero", 0) == 1; if (allow_zero) { if (std::find(new_shape.begin(), new_shape.end(), int64_t{0}) != new_shape.end()) { LOGS(logger, VERBOSE) << "Reshape does not support new shape with 0 as dimension when allowzero is enabled. " diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 35dcde41a6bcf..6c2fcc2ace856 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -98,7 +98,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); - if (input_defs.size() == 3) { // use scales + if (input_defs.size() >= 3 && input_defs[2]->Exists()) { // use scales std::vector scales; ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); coreml_upsample->add_scalingfactor(static_cast(scales[2])); @@ -182,20 +182,24 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPa return false; } + bool using_scales = input_defs.size() >= 3 && input_defs[2]->Exists(); // scales - if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) { - LOGS(logger, VERBOSE) << "Input scales of Resize must be known"; + if (using_scales && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name())) { + LOGS(logger, VERBOSE) << "scales input of Resize must be a constant initializer"; return false; } // sizes - if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) { - LOGS(logger, VERBOSE) << "Input sizes of Resize must be known"; + if (!using_scales && + (input_defs.size() < 4 || + !input_defs[3]->Exists() || + !input_params.graph_viewer.GetConstantInitializer(input_defs[3]->Name()))) { + LOGS(logger, VERBOSE) << "sizes input of Resize must be a constant initializer"; return false; } // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (input_defs.size() == 3) { // we are using scales + if (using_scales) { std::vector scales; if (!GetResizeScales(initializers, node, scales, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc 
b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index b716af738e1b1..39bfbfe5bba1f 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -54,7 +54,7 @@ Status PrepareSliceComputeMetadataFromConstantInitializers(const Node& slice_nod return Status::OK(); } - const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name(), true); + const auto* tensor_proto = graph_viewer.GetConstantInitializer(input_defs[input_idx]->Name()); ORT_RETURN_IF_NOT(tensor_proto, "Failed to get constant initializer."); Initializer unpacked_tensor(*tensor_proto, graph_viewer.ModelPath()); const auto data_type = unpacked_tensor.data_type(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index 266396a0fe90e..d6584124c6aba 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -52,7 +52,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, target_shape.push_back(size_to_dimension); target_shape.push_back(size_from_dimension); - const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); + const auto reshape1_output_name = model_builder.GetUniqueName(node, "reshape1_output"); { // Add reshape layer auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; @@ -60,7 +60,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, *reshape_layer->mutable_output()->Add() = reshape1_output_name; model_builder.AddLayer(std::move(reshape_layer)); } - const auto softmax_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "softmax_output")); + const auto softmax_output_name = model_builder.GetUniqueName(node, "softmax_output"); { auto* coreml_softmaxnd = layer->mutable_softmaxnd(); coreml_softmaxnd->set_axis(-1); diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index daab36f7b933d..eb4723a3b9746 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -144,14 +144,18 @@ void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_prot break; } case ONNX_NAMESPACE::TensorProto_DataType_INT64: { - // from: int64_data/raw, to: longints - if (has_raw_data) { - CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); - - } else { - tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); - } - break; + // enable when this is proven to not be the case + ORT_THROW( + "INT64 is unexpected as CoreML uses 32-bit int for indices. 
" + "Most likely an initializer that should have been skipped was not."); + //// from: int64_data/raw, to: longints + // if (has_raw_data) { + // CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + + //} else { + // tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); + //} + // break; } case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { // from: int32_data/raw, to: bytes @@ -186,18 +190,22 @@ void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_prot break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { - // from: uint64_data/raw, to: longints - if (has_raw_data) { - CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); - } else { - // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this - // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each - // individual value. - tensor_value.mutable_longints()->mutable_values()->CopyFrom( - reinterpret_cast&>(tensor_proto.uint64_data())); - } - - break; + // enable when this is proven to not be the case + ORT_THROW( + "UINT64 is unexpected as CoreML uses 32-bit int for indices. " + "Most likely an initializer that should have been skipped was not."); + //// from: uint64_data/raw, to: longints + // if (has_raw_data) { + // CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + // } else { + // // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this + // // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each + // // individual value. + // tensor_value.mutable_longints()->mutable_values()->CopyFrom( + // reinterpret_cast&>(tensor_proto.uint64_data())); + // } + + // break; } case ONNX_NAMESPACE::TensorProto_DataType_BOOL: { // from: int32_data/raw, to: bools @@ -392,23 +400,28 @@ std::string GetModelOutputPath(bool create_ml_program) { } // namespace ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags) + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names) : graph_viewer_(graph_viewer), logger_(logger), coreml_version_(coreml_version), coreml_flags_(coreml_flags), create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), model_output_path_(GetModelOutputPath(create_ml_program_)), + onnx_input_names_(std::move(onnx_input_names)), + onnx_output_names_(std::move(onnx_output_names)), coreml_model_(std::make_unique()) { if (create_ml_program_) { #if defined(COREML_ENABLE_MLPROGRAM) coreml_model_->set_specificationversion(CoreMLSpecVersion()); MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram(); - MILSpec::Function& main = (*mlprogram.mutable_functions())["main"]; + mlprogram.set_version(1); + mlprogram_main_fn_ = &(*mlprogram.mutable_functions())["main"]; const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion()); - *main.mutable_opset() = coreml_opset; - mlprogram_main_ = &(*main.mutable_block_specializations())[coreml_opset]; + *mlprogram_main_fn_->mutable_opset() = coreml_opset; + mlprogram_main_block_ = &(*mlprogram_main_fn_->mutable_block_specializations())[coreml_opset]; // create the ModelPackage. this creates the output directory. 
mlpackage_ = std::make_unique(model_output_path_, /* create */ true); @@ -426,6 +439,8 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge weights_file_writer_ = std::make_unique(weights_info->path() + "/weight.bin"); #else // should never happen due to handling in coreml_execution_provider.cc + // throw here so all other code in this class can assume create_ml_program_ is only ever true in a build + // where ML Program support is enabled. ORT_THROW("ML Program is not enabled in this build"); #endif } else { @@ -435,6 +450,28 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge neural_network->set_arrayinputshapemapping( CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); } + + // populate names. + const auto& initializers = graph_viewer_.GetAllInitializedTensors(); + const auto& inputs = graph_viewer_.GetInputs(); + // rough guess to try and avoid reallocs. most nodes produce one output but some have more so allow for that. + // also need to convert attributes to constants so allow for that + unique_names_.reserve(initializers.size() + inputs.size() + size_t(graph_viewer_.NumberOfNodes() * 1.5)); + for (const auto& pair : initializers) { + unique_names_.insert(pair.first); + } + + for (const auto* input : inputs) { + unique_names_.insert(input->Name()); + } + + for (const auto& node : graph_viewer_.Nodes()) { + for (const auto& def : node.OutputDefs()) { + if (def->Exists()) { + unique_names_.insert(def->Name()); + } + } + } } ModelBuilder::~ModelBuilder() = default; @@ -455,11 +492,94 @@ void ModelBuilder::AddLayer(std::unique_ptr layer) { neural_network->mutable_layers()->AddAllocated(layer.release()); } -#if defined(COREML_ENABLE_MLPROGRAM) - /* * ML Program related helpers */ +#if defined(COREML_ENABLE_MLPROGRAM) +const std::string& ModelBuilder::GetSafeName(const std::string& name) { + // Check the name is valid according to the MILSpec rules + // `Identifiers, generally used for names and keys, must match the regular expression [A-Za-z\_][A-Za-z0-9\_@]*.` + // + // There is a secondary list of reserved words that the coremltools python uses, but it's not clear if those are + // required here, or if we will ever hit a model that uses one of them. Due to that, skip checking them for now as + // it adds cost and code complexity + // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L151C1-L175C10 + // static InlinedHashSet reserved_names = + // {"any", "bool", "program", "func", "tensor", "list", "dict", "tuple", "true", "false", + // "string", "bf16", "fp16", "fp32", "fp64", "int8", "int16", "int32", "int64", + // "uint8", "uint16", "uint32", "uint64"}; + + // handle empty name. shouldn't happen but code below assumes name is not empty + if (name.empty()) { + return name; + } + + // We don't need '@' or '\' even though they're allowed. Optimize for a good name that does not need to be changed. + + // has been sanitized and changed already + const auto entry = values_to_rename_.find(name); + if (entry != values_to_rename_.end()) { + return entry->second; + } + + // Replace anything but a good char with '_'. 
If first char is 0-9 we prefix with '_'; + bool changed = false; + std::string result = name; + + if (std::isdigit(result[0])) { + changed = true; + result = '_' + name; + } + + for (char& c : result) { + if (!std::isalnum(c) && c != '_') { + changed = true; + c = '_'; + } + } + + if (!changed) { + return name; // return original as the return value is a reference that must remain valid + } + + return (values_to_rename_[name] = GetUniqueName(result)); +} + +void ModelBuilder::SanitizeNames() { + // ML Model level inputs/outputs + auto* desc = coreml_model_->mutable_description(); + for (auto& input : *desc->mutable_input()) { + input.set_name(GetSafeName(input.name())); + } + + for (auto& output : *desc->mutable_output()) { + output.set_name(GetSafeName(output.name())); + } + + // main function inputs/outputs. + for (auto& input : *mlprogram_main_fn_->mutable_inputs()) { + input.set_name(GetSafeName(input.name())); + } + + // outputs from block with operations for current coreml version + for (auto& output : *mlprogram_main_block_->mutable_outputs()) { + output = GetSafeName(output); + } + + // iterate operations changing input/output/node names + for (auto& op : *mlprogram_main_block_->mutable_operations()) { + for (auto& input : *op.mutable_inputs()) { + for (auto& arg : *input.second.mutable_arguments()) { + arg.set_name(GetSafeName(arg.name())); + } + } + + for (auto& output : *op.mutable_outputs()) { + output.set_name(GetSafeName(output.name())); + } + } +} + std::unique_ptr ModelBuilder::CreateOperation(const Node& node, std::string_view op_type, std::string_view suffix) { @@ -472,14 +592,9 @@ std::unique_ptr ModelBuilder::CreateOperation(c return op; } -void ModelBuilder::AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer) { - MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(initializer, *weights_file_writer_); - AddConstantOperation(name, std::move(coreml_tensor)); -} - -void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { +const std::string& ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic - MILSpec::Operation& const_op = *mlprogram_main_->mutable_operations()->Add(); + MILSpec::Operation& const_op = *mlprogram_main_block_->mutable_operations()->Add(); const_op.set_type("const"); MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add(); @@ -487,58 +602,63 @@ void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& *output.mutable_type() = coreml_tensor.type(); auto& attr_map = *const_op.mutable_attributes(); - attr_map["name"] = CreateScalarTensorValue(std::string(name)); + // the operation name doesn't really matter as it isn't used elsewhere, so sanitize name now + attr_map["name"] = CreateScalarTensorValue(GetSafeName(output.name())); attr_map["val"] = std::move(coreml_tensor); + + return output.name(); } // Add operation to the Block for the main function in the ML Program void ModelBuilder::AddOperation(std::unique_ptr operation) { - mlprogram_main_->mutable_operations()->AddAllocated(operation.release()); + mlprogram_main_block_->mutable_operations()->AddAllocated(operation.release()); } -std::string ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, - MILSpec::Value&& input_value) { +const std::string& ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, + 
std::string_view value_type, + MILSpec::Value&& input_value) { auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type)); - AddConstantOperation(unique_value_name, std::move(input_value)); - return unique_value_name; + return AddConstantOperation(unique_value_name, std::move(input_value)); } template -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { // add specialization below static_assert(false_for_T, "Missing specialization for value type"); - return ""; // unreachable + + return "ModelBuilder::AddConstant error"; // unreachable } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); // CoreML uses int32 return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } template <> -std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, - gsl::span value, - std::optional> shape) { +std::string_view ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { auto input_value = CreateTensorValue(value, shape); return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } @@ -581,11 +701,13 @@ Status ModelBuilder::RegisterInitializers() { continue; } - if (create_ml_program_) { #if defined(COREML_ENABLE_MLPROGRAM) - AddConstant(name, tensor); + if (create_ml_program_) { + MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(tensor, *weights_file_writer_); + ORT_IGNORE_RETURN_VALUE(AddConstantOperation(name, std::move(coreml_tensor))); + } else #endif - } else { + { std::unique_ptr layer = std::make_unique(); layer->set_name(GetUniqueName("initializer_" + name)); @@ -616,32 +738,33 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (is_input) { // input should not be an initializer - if (Contains(GetInitializerTensors(), name)) + if (Contains(GetInitializerTensors(), name)) { return Status::OK(); + } // This input will not be used - if (Contains(skipped_inputs_, name)) + if (Contains(skipped_inputs_, name)) { return Status::OK(); + } } auto* model_description = 
coreml_model_->mutable_description(); - auto& input_output = is_input - ? *model_description->mutable_input()->Add() - : *model_description->mutable_output()->Add(); + auto& input_output = is_input ? *model_description->mutable_input()->Add() + : *model_description->mutable_output()->Add(); input_output.set_name(name); + auto* multi_array = input_output.mutable_type()->mutable_multiarraytype(); std::vector shape; - ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), - "Unable to get shape for ", input_output_type, ": ", name); + ORT_RETURN_IF_NOT(GetShape(node_arg, shape, logger_), "Unable to get shape for ", input_output_type, ": ", name); if (shape.empty()) { - // If we have an empty shape, this is a scalar input, - // Since all the input output of CoreML EP is MultiArray, we will make the scalar input output as a {1} MultiArray + // If we have an empty shape, this is a scalar + // Since all the input/output of CoreML EP is MultiArray, we will make the scalar input/output a {1} MultiArray shape.push_back(1); - // we need to change the shapes of these scalar outputs back to {} when CoreML EP returns these values to ORT + // we need to change the shapes of scalar outputs back to {} when CoreML EP returns values to ORT if (!is_input) { AddScalarOutput(name); } @@ -713,13 +836,20 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i #if defined(COREML_ENABLE_MLPROGRAM) if (create_ml_program_) { - MILSpec::Function& main = (*coreml_model_->mutable_mlprogram()->mutable_functions())["main"]; if (is_input) { - // the model inputs need to be wired up as args to the 'main' function - main.mutable_inputs()->Add(CreateNamedTensorValueType(node_arg)); + // the model inputs need to be wired up as args to the 'main' function. + auto tensor_value_type = CreateNamedTensorValueType(node_arg); + tensor_value_type.set_name(name); + if (node_arg.Shape()->dim_size() == 0) { + // update shape from {} to {1} (same change we made at the model input level above). + tensor_value_type.mutable_type()->mutable_tensortype()->set_rank(1); + tensor_value_type.mutable_type()->mutable_tensortype()->add_dimensions()->mutable_constant()->set_size(1); + } + + mlprogram_main_fn_->mutable_inputs()->Add(std::move(tensor_value_type)); } else { // the model outputs need to be set as outputs of the Block for the 'main' function - *mlprogram_main_->mutable_outputs()->Add() = node_arg.Name(); + *mlprogram_main_block_->mutable_outputs()->Add() = name; } } #endif // defined(COREML_ENABLE_MLPROGRAM) @@ -744,7 +874,7 @@ Status ModelBuilder::ProcessNodes() { // This shouldn't happen as this is called from CoreMLExecutionProvider::Compile and should only be processing // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node.Name(), "], type [", node.OpType(), "] is not supported"); + "Node [", node.Name(), "], type [", node.OpType(), "] was not able to be processed"); } } @@ -767,6 +897,12 @@ Status ModelBuilder::CreateModel() { ORT_RETURN_IF_ERROR(ProcessNodes()); ORT_RETURN_IF_ERROR(RegisterModelOutputs()); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + SanitizeNames(); + } +#endif + return Status::OK(); } @@ -795,7 +931,7 @@ Status ModelBuilder::SaveModel() { #if defined(COREML_ENABLE_MLPROGRAM) // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program // related types as well. 
- mlprogram_main_ = nullptr; + mlprogram_main_block_ = nullptr; mlpackage_.reset(); weights_file_writer_.reset(); #endif @@ -804,11 +940,51 @@ Status ModelBuilder::SaveModel() { } Status ModelBuilder::LoadModel(std::unique_ptr& model) { - model = std::make_unique(model_output_path_, - std::move(input_output_info_), - std::move(scalar_outputs_), - std::move(int64_outputs_), - logger_, coreml_flags_); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + // we need to provide the sanitized names for model inputs/outputs so that info is captured. + // the input/output matching when we execute the model from the CoreML EP is based on order, so the change + // to the names doesn't matter for that. + auto get_sanitized_names = [this](std::vector&& names) -> std::vector { + std::vector output(std::move(names)); + + for (std::string& name : output) { + name = GetSafeName(name); + } + + return output; + }; + + // also need to update the keys in input_output_info_ + auto get_sanitized_io_info = [this](std::unordered_map&& info) { + std::unordered_map output; + output.reserve(info.size()); + + for (auto entry = info.begin(), end = info.end(); entry != end; ++entry) { + output.emplace(GetSafeName(entry->first), std::move(entry->second)); + } + + return output; + }; + + model = std::make_unique(model_output_path_, + get_sanitized_names(std::move(onnx_input_names_)), + get_sanitized_names(std::move(onnx_output_names_)), + get_sanitized_io_info(std::move(input_output_info_)), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + } else +#endif + { + model = std::make_unique(model_output_path_, + std::move(onnx_input_names_), + std::move(onnx_output_names_), + std::move(input_output_info_), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + } return model->LoadModel(); // load using CoreML API, including compilation } @@ -816,8 +992,11 @@ Status ModelBuilder::LoadModel(std::unique_ptr& model) { // static Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names, std::unique_ptr& model) { - ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags); + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags, + std::move(onnx_input_names), std::move(onnx_output_names)); ORT_RETURN_IF_ERROR(builder.CreateModel()); ORT_RETURN_IF_ERROR(builder.SaveModel()); @@ -847,20 +1026,31 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) { skipped_inputs_.insert(input_name); } -std::string ModelBuilder::GetUniqueName(std::string_view base_name) { +const std::string& ModelBuilder::GetUniqueName(const std::string& base_name) { + if (unique_names_.find(base_name) == unique_names_.end()) { + return *unique_names_.insert(base_name).first; + } + std::string unique_name; - do { - std::ostringstream os; - os << base_name << "_token_" << name_token_++; - unique_name = os.str(); - } while (Contains(unique_names_, unique_name)); + std::string suffix; + + // supports up to 1000 unique names without having to grow in the loop + unique_name.reserve(base_name.size() + 5); + unique_name = base_name; + + while (Contains(unique_names_, unique_name)) { + // assign followed by += to avoid creating temporary strings. 
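+      // e.g. if "reshape_output" is already taken, this produces something like "reshape_output__3";
+      // the exact number comes from the shared name_token_ counter (illustrative value only).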
+ unique_name = base_name; + unique_name += "__"; + unique_name += std::to_string(name_token_++); + } - return unique_name; + return *unique_names_.insert(unique_name).first; } -std::string ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { +const std::string& ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { if (node.Name().empty()) { - return GetUniqueName(MakeString("Node_", node.Index(), "_", node.OpType(), suffix)); + return GetUniqueName(MakeString(node.OpType(), "_", node.Index(), suffix)); } else { return GetUniqueName(node.Name() + std::string(suffix)); } diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index 961ba647257b5..8f85ab2c09e7c 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -25,17 +25,20 @@ namespace onnxruntime { namespace coreml { class IOpBuilder; -class Model; class ModelBuilder { private: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - int32_t coreml_version, uint32_t coreml_flags); + int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names); public: // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, int32_t coreml_version, uint32_t coreml_flags, + std::vector&& onnx_input_names, + std::vector&& onnx_output_names, std::unique_ptr& model); ~ModelBuilder(); @@ -101,8 +104,8 @@ class ModelBuilder { /// /// Unique name generated for value. template - std::string AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, - std::optional> shape = std::nullopt) { + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt) { static_assert(std::is_same_v || std::is_same_v || std::is_same_v || @@ -113,8 +116,8 @@ class ModelBuilder { } template - std::string AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, - std::optional> shape = std::nullopt) { + std::string_view AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, + std::optional> shape = std::nullopt) { return AddConstant(op_type, value_type, AsSpan(value), shape); } @@ -122,17 +125,10 @@ class ModelBuilder { /// Add a scalar value as a 'const' operation. See AddConstant for details. 
/// template - std::string AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { + std::string_view AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { return AddConstant(op_type, value_type, AsSpan({value}), AsSpan({})); } - /// - /// Add an existing a constant ONNX initializer to the ML Program as a 'const' operation - /// - /// Initializer name - /// Initializer data - void AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer); - // add the operation to the main function void AddOperation(std::unique_ptr operation); #endif @@ -149,18 +145,26 @@ class ModelBuilder { // be added to CoreML model, since CoreML does not like input unused void AddInputToSkip(const std::string& input_name); - std::string GetUniqueName(std::string_view base_name); - std::string GetUniqueName(const Node& node, std::string_view suffix); + const std::string& GetUniqueName(const std::string& base_name); + const std::string& GetUniqueName(const Node& node, std::string_view suffix); + + const logging::Logger& Logger() const { return logger_; } private: #if defined(COREML_ENABLE_MLPROGRAM) template - std::string AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, - std::optional> shape = std::nullopt); - - void AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); - std::string AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, - COREML_SPEC::MILSpec::Value&& input_value); + std::string_view AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt); + + // apply the CoreML naming rules and fix any invalid names. + const std::string& GetSafeName(const std::string& name); + // sanitize all the names in the ML Model + void SanitizeNames(); + + // add Value as a const operation. return value name in case sanitization changed it + const std::string& AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); + const std::string& AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + COREML_SPEC::MILSpec::Value&& input_value); #endif // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk. @@ -193,6 +197,9 @@ class ModelBuilder { const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel + std::vector onnx_input_names_; + std::vector onnx_output_names_; + std::unique_ptr coreml_model_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; @@ -208,9 +215,19 @@ class ModelBuilder { // mlprogram_main_ is the main block of the CoreML ML Program. // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML'] // entry we create. - COREML_SPEC::MILSpec::Block* mlprogram_main_{nullptr}; + COREML_SPEC::MILSpec::Function* mlprogram_main_fn_{nullptr}; // Function that contains a Block with the operations + COREML_SPEC::MILSpec::Block* mlprogram_main_block_{nullptr}; // Block that all the operations are added to std::unique_ptr mlpackage_; std::unique_ptr weights_file_writer_; + + // Values must start with [a-zA-A_] + // Additionally they can't be in a list of reserved words. 
+ // If we need to sanitize an initializer name we do so during PreprocessInitializers and apply the change during + // RegisterInitializers. + // We also check inputs in AddOperation and apply the change there. + // This means an op builder author doesn't need to be aware of the renaming. + // https://github.com/apple/coremltools/blob/8b37641f243b1a3e81452feea311c6e30dcc9287/coremltools/converters/mil/mil/passes/defs/preprocess.py#L146-L149 + std::unordered_map values_to_rename_; #endif }; diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 8e718da07703c..0ba715cc7c6d9 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -114,28 +114,27 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector& node_compute_funcs) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; - const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); std::unique_ptr coreml_model; - ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, - coreml_model)); - { - const auto& input_defs = fused_node.InputDefs(); - std::vector onnx_input_names(input_defs.size()); - for (size_t i = 0, end = input_defs.size(); i < end; ++i) { - onnx_input_names[i] = input_defs[i]->Name(); - } - coreml_model->SetOnnxInputs(std::move(onnx_input_names)); - } + auto get_names = [](const ConstPointerContainer>& args) -> std::vector { + std::vector names; + names.reserve(args.size()); - { - const auto& output_defs = fused_node.OutputDefs(); - std::vector onnx_output_names(output_defs.size()); - for (size_t i = 0, end = output_defs.size(); i < end; ++i) { - onnx_output_names[i] = output_defs[i]->Name(); - } - coreml_model->SetOnnxOutputs(std::move(onnx_output_names)); + for (const NodeArg* def : args) { + names.push_back(def->Name()); + } + + return names; + }; + + std::vector onnx_input_names = get_names(fused_node.InputDefs()); + std::vector onnx_output_names = get_names(fused_node.OutputDefs()); + + const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + std::move(onnx_input_names), std::move(onnx_output_names), + coreml_model)); } coreml_models_.emplace(fused_node.Name(), std::move(coreml_model)); @@ -153,13 +152,14 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector(state); - const auto& model_inputs = model->GetOnnxInputs(); - const auto& model_outputs = model->GetOnnxOutputs(); + + // input/output names used by the CoreML model in the order that matches the fused_node InputDefs/OutputDefs + const auto& model_inputs = model->GetOrderedInputs(); + const auto& model_outputs = model->GetOrderedOutputs(); ORT_RETURN_IF_NOT(model_inputs.size() <= num_inputs, "Inconsistent input sizes"); ORT_RETURN_IF_NOT(model_outputs.size() == num_outputs, "Inconsistent output sizes"); @@ -182,28 +182,25 @@ common::Status CoreMLExecutionProvider::Compile(const std::vectorshape; - ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape), - "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape), - ") but the runtime shape (", coreml::Shape2String(shape), - ") has zero elements. 
This is not supported by the CoreML EP."); - } + const auto& inferred_shape = input_info->shape; + ORT_RETURN_IF(!coreml::IsStaticShape(inferred_shape) && coreml::DoesShapeSpecifyZeroElements(shape), + "Input (", input_name, ") has a dynamic shape (", coreml::Shape2String(inferred_shape), + ") but the runtime shape (", coreml::Shape2String(shape), + ") has zero elements. This is not supported by the CoreML EP."); // If we have an empty shape, this is a scalar input, // Since all the input output of CoreML EP is MultiArray, we will make the scalar input as a {1} MultiArray - if (shape.empty()) + if (shape.empty()) { shape.push_back(1); + } // CoreML MLMultiArray API expect input to be non-const // https://developer.apple.com/documentation/coreml/mlmultiarray/2881219-initwithdatapointer?language=objc void* inputBuffer = const_cast(input_tensor.GetTensorRawData()); - inputs.emplace( - input_name, - coreml::OnnxTensorData{ - coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape}, - inputBuffer, - }); + inputs.emplace(input_name, coreml::OnnxTensorData{ + coreml::OnnxTensorInfo{tensor_info.GetElementType(), shape}, + inputBuffer, + }); } // From this point we will need to take the exclusive lock on the model until the Predict is @@ -215,14 +212,13 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector static_shape) -> void* { + [&ctx, &model_outputs](const std::string& name, + int32_t requested_onnx_tensor_element_type, + gsl::span static_shape) -> void* { const auto model_output_it = std::find(model_outputs.begin(), model_outputs.end(), name); ORT_ENFORCE(model_output_it != model_outputs.end(), "Failed to find CoreML model output name: ", name); - const auto output_idx = gsl::narrow_cast(std::distance(model_outputs.begin(), model_output_it)); + const auto output_idx = gsl::narrow_cast(std::distance(model_outputs.begin(), model_output_it)); auto output_tensor = ctx.GetOutput(output_idx, static_shape.data(), static_shape.size()); const auto type_and_shape_info = output_tensor.GetTensorTypeAndShapeInfo(); @@ -243,13 +239,15 @@ common::Status CoreMLExecutionProvider::Compile(const std::vectorIsScalarOutput(output_name)) + if (model->IsScalarOutput(output_name)) { output_shape.clear(); + } // Since CoreML EP only accepts int32 output type and onnx requires int64 output, // We are going to set the model output (from int32) ->int64 - if (model->IsInt64Output(output_name)) + if (model->IsInt64Output(output_name)) { output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; + } outputs.emplace(output_name, coreml::OnnxTensorInfo{output_type, output_shape}); } diff --git a/onnxruntime/core/providers/coreml/dump_mlprogram_model.py b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py new file mode 100644 index 0000000000000..a3ceee70684dc --- /dev/null +++ b/onnxruntime/core/providers/coreml/dump_mlprogram_model.py @@ -0,0 +1,27 @@ +import sys + +import coremltools as ct + +if len(sys.argv) < 2: + print(f"Usage: {sys.argv[0]} ") + print("If generated by onnxruntime this will be /Data/com.microsoft.onnxruntime/model.mlmodel") + sys.exit(-1) + +model_path = sys.argv[1] +m = ct.models.MLModel(model_path) + +spec = m.get_spec() +print(spec) + +# Example code if you want to filter output or do more advanced things +# main = spec.mlProgram.functions["main"] +# block = main.block_specializations[main.opset] +# print(f"{len(block.operations)} operators") +# for op in block.operations: +# if op.type == 'const': +# if op.attributes["name"].immediateValue.tensor.strings.values[0] == 
"conv_0_pad_type_0": +# print(f"Conv pad_type={op.attributes['val'].immediateValue.tensor.strings.values}") +# +# if op.type == 'conv': +# #print(op) +# pass diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index 4f9a014c4d885..a9991ccb945ce 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -67,6 +67,12 @@ int CoreMLVersion(); // Get a temporary macOS/iOS temp file path std::string GetTemporaryFilePath(); +#if !defined(NDEBUG) && defined(__APPLE__) +// Override location the model is written to so that a) it's easily found and b) it is not automatically deleted +// when the EP exits. Use to debug the model that is generated. +// See onnxruntime/core/providers/coreml/dump_mlprogram_model.py for a script to dump the ML Program. +constexpr const char* kOverrideModelOutputDirectoryEnvVar = "ORT_COREML_EP_MODEL_DIR"; +#endif } // namespace util } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 0ae0cf8f0d207..5487ea35388f5 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/platform/env.h" #include "core/providers/coreml/model/host_utils.h" #import @@ -31,6 +32,15 @@ int32_t CoreMLVersion() { std::string GetTemporaryFilePath() { // Get temporary directory for user. NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; + +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + NSString* ns_path_override = [NSString stringWithUTF8String:path_override.c_str()]; + temporary_directory_url = [NSURL fileURLWithPath:ns_path_override isDirectory:YES]; + } +#endif + // Generate a Unique file name to use. NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString]; diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index b940c4b768aec..e3cd43d786fc3 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -35,6 +35,8 @@ using GetOutputTensorMutableRawDataFn = std::function&& model_input_names, + std::vector&& model_output_names, std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, @@ -60,12 +62,11 @@ class Model { // Mutex for exclusive lock to this model object OrtMutex& GetMutex() { return mutex_; } - // Input and output names in the onnx model's order - const std::vector& GetOnnxInputs() const { return onnx_inputs_; } - void SetOnnxInputs(std::vector&& inputs) { onnx_inputs_ = std::move(inputs); } - - const std::vector& GetOnnxOutputs() const { return onnx_outputs_; } - void SetOnnxOutputs(std::vector&& outputs) { onnx_outputs_ = std::move(outputs); } + // Input and output names in the ORT fused node's order. + // Names may have been adjusted from the originals due to CoreML naming rules. + // We do inputs/outputs based on order at the ONNX level so this doesn't matter. 
+ const std::vector& GetOrderedInputs() const { return model_input_names_; } + const std::vector& GetOrderedOutputs() const { return model_output_names_; } const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const { const auto info_it = input_output_info_.find(name); @@ -80,13 +81,13 @@ class Model { private: std::unique_ptr execution_; + std::vector model_input_names_; // input names in the order of the ORT fused node's inputs + std::vector model_output_names_; // output names in the order of the ORT fused node's outputs + std::unordered_map input_output_info_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; - std::vector onnx_inputs_; - std::vector onnx_outputs_; - OrtMutex mutex_; }; diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index d5cd70bff9479..1434043e064f4 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -19,6 +19,7 @@ #include "core/common/narrow.h" #include "core/common/span_utils.h" #include "core/graph/onnx_protobuf.h" +#include "core/platform/env.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" @@ -287,6 +288,14 @@ - (void)cleanup { compiled_model_path_ = nil; } +#if !defined(NDEBUG) + std::string path_override = Env::Default().GetEnvironmentVar(util::kOverrideModelOutputDirectoryEnvVar); + if (!path_override.empty()) { + // don't cleanup + coreml_model_path_ = nil; + } +#endif + if (coreml_model_path_ != nil) { error = nil; [[NSFileManager defaultManager] removeItemAtPath:coreml_model_path_ error:&error]; @@ -487,12 +496,16 @@ Status Predict(const std::unordered_map& inputs, } Model::Model(const std::string& path, + std::vector&& model_input_names, + std::vector&& model_output_names, std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& logger, uint32_t coreml_flags) : execution_(std::make_unique(path, logger, coreml_flags)), + model_input_names_(std::move(model_input_names)), + model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), scalar_outputs_(std::move(scalar_outputs)), int64_outputs_(std::move(int64_outputs)) { diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc index 087c9f8c05d5f..c6f2e7401ea1e 100644 --- a/onnxruntime/core/providers/coreml/model/model_stub.cc +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -9,12 +9,16 @@ namespace coreml { class Execution {}; Model::Model(const std::string& /*path*/, + std::vector&& model_input_names, + std::vector&& model_output_names, std::unordered_map&& input_output_info, std::unordered_set&& scalar_outputs, std::unordered_set&& int64_outputs, const logging::Logger& /*logger*/, uint32_t /*coreml_flags*/) : execution_(std::make_unique()), + model_input_names_(std::move(model_input_names)), + model_output_names_(std::move(model_output_names)), input_output_info_(std::move(input_output_info)), scalar_outputs_(std::move(scalar_outputs)), int64_outputs_(std::move(int64_outputs)) { diff --git a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h index 5961686674424..d7ceda16e61ea 100644 --- a/onnxruntime/core/providers/cpu/tensor/reshape_helper.h +++ 
b/onnxruntime/core/providers/cpu/tensor/reshape_helper.h @@ -37,12 +37,14 @@ class ReshapeHelper { if (unknown_dim != -1) { // calculate unknown dimension ORT_ENFORCE(size != 0 && (input_shape_size % size) == 0, - "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape)); + "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, + ", requested shape:", TensorShape(requested_shape)); requested_shape[unknown_dim] = input_shape_size / size; } else { // check if the output shape is valid. ORT_ENFORCE(input_shape_size == size, - "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, ", requested shape:", TensorShape(requested_shape)); + "The input tensor cannot be reshaped to the requested shape. Input shape:", input_shape, + ", requested shape:", TensorShape(requested_shape)); } } }; diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7d4111e3b9c39..729ad34368453 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -64,17 +64,22 @@ namespace perftest { "\t Refer to onnxruntime_session_options_config_keys.h for valid keys and values. \n" "\t [Example] -C \"session.disable_cpu_ep_fallback|1 ep.context_enable|1\" \n" "\t-i: Specify EP specific runtime options as key value pairs. Different runtime options available are: \n" + "\t [Usage]: -e -i '| |'\n" + "\n" "\t [DML only] [performance_preference]: DML device performance preference, options: 'default', 'minimum_power', 'high_performance', \n" "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" "\t [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" + "\n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" "\t [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" "\t [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" + "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + "\n" "\t [QNN only] [backend_path]: QNN backend path. e.g '/folderpath/libQnnHtp.so', '/folderpath/libQnnCpu.so'.\n" "\t [QNN only] [profiling_level]: QNN profiling level, options: 'basic', 'detailed', default 'off'.\n" "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" @@ -89,9 +94,8 @@ namespace perftest { "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). 
\n" "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" - "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" - "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" + "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" + "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" "\t [TensorRT only] [trt_min_subgraph_size]: Minimum size of TensorRT subgraphs.\n" "\t [TensorRT only] [trt_max_workspace_size]: Set TensorRT maximum workspace size in byte.\n" @@ -108,20 +112,23 @@ namespace perftest { "\t [TensorRT only] [trt_force_sequential_engine_build]: Force TensorRT engines to be built sequentially.\n" "\t [TensorRT only] [trt_context_memory_sharing_enable]: Enable TensorRT context memory sharing between subgraphs.\n" "\t [TensorRT only] [trt_layer_norm_fp32_fallback]: Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow.\n" - "\t [Usage]: -e -i '| |'\n\n" - "\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n" + "\t [Example] [For TensorRT EP] -e tensorrt -i 'trt_fp16_enable|true trt_int8_enable|true trt_int8_calibration_table_name|calibration.flatbuffers trt_int8_use_native_calibration_table|false trt_force_sequential_engine_build|false'\n" + "\n" "\t [NNAPI only] [NNAPI_FLAG_USE_FP16]: Use fp16 relaxation in NNAPI EP..\n" "\t [NNAPI only] [NNAPI_FLAG_USE_NCHW]: Use the NCHW layout in NNAPI EP.\n" "\t [NNAPI only] [NNAPI_FLAG_CPU_DISABLED]: Prevent NNAPI from using CPU devices.\n" "\t [NNAPI only] [NNAPI_FLAG_CPU_ONLY]: Using CPU only in NNAPI EP.\n" - "\t [Usage]: -e -i ' '\n\n" - "\t [Example] [For NNAPI EP] -e nnapi -i \" NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED \"\n" + "\t [Example] [For NNAPI EP] -e nnapi -i \"NNAPI_FLAG_USE_FP16 NNAPI_FLAG_USE_NCHW NNAPI_FLAG_CPU_DISABLED\"\n" + "\n" + "\t [CoreML only] [COREML_FLAG_CREATE_MLPROGRAM]: Create an ML Program model instead of Neural Network.\n" + "\t [Example] [For CoreML EP] -e coreml -i \"COREML_FLAG_CREATE_MLPROGRAM\"\n" + "\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" "\t [SNPE only] [priority]: execution priority, options: 'low', 'normal'. \n" "\t [SNPE only] [buffer_type]: options: 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. default: ITENSOR'. \n" "\t [SNPE only] [enable_init_cache]: enable SNPE init caching feature, set to 1 to enabled it. Disabled by default. 
\n" - "\t [Usage]: -e -i '| |' \n\n" - "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n" + "\t [Example] [For SNPE EP] -e snpe -i \"runtime|CPU priority|low\" \n\n" + "\n" "\t-T [Set intra op thread affinities]: Specify intra op thread affinity string\n" "\t [Example]: -T 1,2;3,4;5,6 or -T 1-2;3-4;5-6 \n" "\t\t Use semicolon to separate configuration between threads.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 1934314b8ce43..9679ca6159464 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -468,7 +468,10 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); nnapi_flags |= NNAPI_FLAG_CPU_ONLY; } else if (key.empty()) { } else { - ORT_THROW("[ERROR] [NNAPI] wrong key type entered. Choose from the following runtime key options that are available for NNAPI. ['NNAPI_FLAG_USE_FP16', 'NNAPI_FLAG_USE_NCHW', 'NNAPI_FLAG_CPU_DISABLED', 'NNAPI_FLAG_CPU_ONLY'] \n"); + ORT_THROW( + "[ERROR] [NNAPI] wrong key type entered. Choose from the following runtime key options " + "that are available for NNAPI. " + "['NNAPI_FLAG_USE_FP16', 'NNAPI_FLAG_USE_NCHW', 'NNAPI_FLAG_CPU_DISABLED', 'NNAPI_FLAG_CPU_ONLY'] \n"); } } Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_Nnapi(session_options, nnapi_flags)); @@ -476,10 +479,31 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ORT_THROW("NNAPI is not supported in this build\n"); #endif } else if (provider_name_ == onnxruntime::kCoreMLExecutionProvider) { +#ifdef __APPLE__ #ifdef USE_COREML - Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, 0)); + uint32_t coreml_flags = 0; + std::string ov_string = performance_test_config.run_config.ep_runtime_config_string; + std::istringstream ss(ov_string); + + std::string key; + while (ss >> key) { + if (key == "COREML_FLAG_CREATE_MLPROGRAM") { + coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + std::cout << "Enabling ML Program.\n"; + } else if (key.empty()) { + } else { + ORT_THROW( + "[ERROR] [CoreML] wrong key type entered. Choose from the following runtime key options " + "that are available for CoreML. ['COREML_FLAG_CREATE_MLPROGRAM'] \n"); + } + } + // COREML_FLAG_CREATE_MLPROGRAM + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(session_options, coreml_flags)); +#else + ORT_THROW("CoreML is not supported in this build\n"); +#endif #else - ORT_THROW("COREML is not supported in this build\n"); + ORT_THROW("COREML is not supported on this platform.\n"); #endif } else if (provider_name_ == onnxruntime::kDmlExecutionProvider) { #ifdef USE_DML diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 7b6f1b9244be9..94817158017bd 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -192,5 +192,25 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { #endif } +// Test that we fix invalid names in model inputs, initializers and outputs. 
+// Names in CoreML cannot start with [0-9] or contain anything but "[a-z][A-Z][0-9]_" +TEST(CoreMLExecutionProviderTest, TestNameSanitization) { + OpTester test("Clip", 11); + + std::vector dims{3, 3}; + test.AddInput("0", dims, + {-1.0f, 0.0f, 1.0f, + -6.0f, 0.0f, 6.0f, + -5.4f, 2.0f, 6.0f}); + test.AddInput("1.min", {}, {-5}, true); // add as initializers + test.AddInput("2/max", {}, {5}, true); + test.AddOutput("3", dims, + {-1.0f, 0.0f, 1.0f, + -5.0f, 0.0f, 5.0f, + -5.0f, 2.0f, 5.0f}); + + // TensorRT does not support Clip opset 11 yet. + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/math/clip_test.cc b/onnxruntime/test/providers/cpu/math/clip_test.cc index efb46e86d04e4..b5d5f84df950a 100644 --- a/onnxruntime/test/providers/cpu/math/clip_test.cc +++ b/onnxruntime/test/providers/cpu/math/clip_test.cc @@ -182,7 +182,7 @@ TEST(MathOpTest, Clip) { run_test(true); } -// Use clip between [0, 6] as Relu6 (for some EPs, such as NNAPI) +// Use clip between [0, 6] as Relu6 to test optimized path in some EPs, such as NNAPI and CoreML TEST(MathOpTest, Clip_Relu6) { // To test NNAPI EP, we need the min/max to be in initializers auto run_test = [](bool min_max_are_initializer) { @@ -208,6 +208,31 @@ TEST(MathOpTest, Clip_Relu6) { run_test(true); } +// Use clip between [0, inf] as Relu to test optimized path in some EPs, such as CoreML +TEST(MathOpTest, Clip_Relu) { + // To test NNAPI EP, we need the min/max to be in initializers + auto run_test = [](bool min_max_are_initializer) { + OpTester test("Clip", 11); + + std::vector dims{3, 3}; + test.AddInput("X", dims, + {-1.0f, 0.0f, 1.0f, + -6.0f, 3.5f, 6.0f, + -5.4f, 2.0f, 8.0f}); + test.AddInput("min", {}, {0.0f}, min_max_are_initializer); + test.AddOutput("Y", dims, + {0.0f, 0.0f, 1.0f, + 0.0f, 3.5f, 6.0f, + 0.0f, 2.0f, 8.0f}); + + // TensorRT does not support Clip opset 11 yet. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + }; + + run_test(false); + run_test(true); +} + // Use clip between [-1, 1] as Relu1 (for some EPs, such as NNAPI) TEST(MathOpTest, Clip_Relu1) { // To test NNAPI EP, we need the min/max to be in initializers diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index bf089e083d67e..428925e154497 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -281,24 +281,31 @@ using GemmOpTypedTestsTypes = ::testing::Types; TYPED_TEST_SUITE(GemmOpTypedTests, GemmOpTypedTestsTypes); TYPED_TEST(GemmOpTypedTests, TestGemmScalarBroadcast) { - OpTester test("Gemm"); + auto run_test = [](bool b_is_initializer, bool c_is_initializer) { + OpTester test("Gemm"); - test.AddAttribute("transA", (int64_t)0); - test.AddAttribute("transB", (int64_t)0); - test.AddAttribute("alpha", 1.0f); - test.AddAttribute("beta", 1.0f); + test.AddAttribute("transA", (int64_t)0); + test.AddAttribute("transB", (int64_t)0); + test.AddAttribute("alpha", 1.0f); + test.AddAttribute("beta", 1.0f); - test.AddInput("A", {2, 4}, - {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), - static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); - test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f))); - test.AddInput("C", {1}, std::vector{static_cast(1.0f)}); - test.AddOutput("Y", {2, 3}, - {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), - static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); - test.Config(run_with_tunable_op) - .RunWithConfig(); + test.AddInput("A", {2, 4}, + {static_cast(1.0f), static_cast(2.0f), static_cast(3.0f), static_cast(4.0f), + static_cast(-1.0f), static_cast(-2.0f), static_cast(-3.0f), static_cast(-4.0f)}); + test.AddInput("B", {4, 3}, std::vector(12, static_cast(1.0f)), b_is_initializer); + test.AddInput("C", {1}, std::vector{static_cast(1.0f)}, c_is_initializer); + test.AddOutput("Y", {2, 3}, + {static_cast(11.0f), static_cast(11.0f), static_cast(11.0f), + static_cast(-9.0f), static_cast(-9.0f), static_cast(-9.0f)}); + test.Config(run_with_tunable_op) + .RunWithConfig(); + }; + + run_test(false, false); + // CoreML EP requires weight and bias to be initializers + run_test(true, true); } + TYPED_TEST(GemmOpTypedTests, TestGemm2DBroadcast_2) { OpTester test("Gemm"); diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index ee18cf2cea6cb..cbb4531a50b7c 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -75,6 +75,43 @@ TEST(BatchNormTest, PositiveTestCase) { input_data_map.insert({"mean", mean}); input_data_map.insert({"var", var}); + InputShapesMap input_shapes_map; + vector input_shape{1, 1, 7, 7}; + input_shapes_map.insert({"X", input_shape}); + input_shapes_map.insert({"scale", {1}}); + input_shapes_map.insert({"B", {1}}); + input_shapes_map.insert({"mean", {1}}); + input_shapes_map.insert({"var", {1}}); + + auto expected_output = {1.01359f, 0.703983f, 0.641631f, 1.08571f, 0.939167f, 0.762469f, 0.682729f, 0.762401f, 0.787021f, + 1.06744f, 0.604378f, 0.957476f, 0.667302f, 0.901764f, 1.07566f, 1.01117f, 0.928324f, 0.897667f, + 0.705842f, 0.660885f, 0.977291f, 0.878918f, 0.818345f, 1.06608f, 0.839057f, 1.04796f, 0.621471f, + 0.781831f, 0.760527f, 0.835665f, 1.05825f, 
0.611442f, 0.781873f, 1.08437f, 0.907454f, 0.926173f, + 1.03375f, 0.707961f, 0.968646f, 0.621757f, 0.973095f, 0.700301f, 0.916723f, 0.807602f, 0.692598f, + 0.621972f, 0.707334f, 0.63723f, 0.63062f}; + float epsilon = 1e-05f; + TestBatchNorm(input_data_map, input_shapes_map, epsilon, expected_output, input_shape); +} + +TEST(BatchNormTest, PositiveTestCase_5D) { + // This input was taken from the SpatialBN_1.pb, SpatialBN_1_input.pb and SpatialBN_1_output.pb files. + vector X{0.329876f, -0.287158f, -0.411425f, 0.473621f, 0.18156f, -0.170596f, -0.329516f, -0.170733f, -0.121664f, 0.4372f, + -0.485668f, 0.218049f, -0.360263f, 0.107016f, 0.45358f, 0.325056f, 0.15995f, 0.098852f, -0.283453f, -0.373051f, + 0.257542f, 0.0614853f, -0.0592363f, 0.434488f, -0.0179583f, 0.398374f, -0.451602f, -0.132009f, -0.174468f, + -0.0247169f, 0.418897f, -0.47159f, -0.131925f, 0.470943f, 0.118357f, 0.155664f, 0.370062f, -0.279229f, 0.240311f, + -0.451034f, 0.249178f, -0.294496f, 0.13683f, -0.0806475f, -0.309849f, -0.450604f, -0.28048f, -0.420197f, -0.433369f}; + vector scale{0.589433f}; + vector B{-0.384622f}; + vector mean{-2.45673f}; + vector var{1.37998f}; + + InputDataMap input_data_map; + input_data_map.insert({"X", X}); + input_data_map.insert({"scale", scale}); + input_data_map.insert({"B", B}); + input_data_map.insert({"mean", mean}); + input_data_map.insert({"var", var}); + InputShapesMap input_shapes_map; vector input_shape{1, 1, 7, 7, 1}; input_shapes_map.insert({"X", input_shape}); diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 1d31f3fdb4eb4..5addb5dd9ce46 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -572,8 +572,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kDmlExecutionProvider}); } -TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers +TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric_scales) { + // To test CoreML/NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; From acbfc29f272b5578145e7600bc42342e116ffbc2 Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 1 Mar 2024 10:57:14 +0800 Subject: [PATCH 178/207] Follow up fix for Gelu impl (#19693) ### Follow up fix for Gelu impl There are two minor comments in https://github.com/microsoft/onnxruntime/pull/19560. Fix them in this pull request. ### Motivation and Context --- docs/ORTModule_Training_Guidelines.md | 2 +- onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc | 8 +++----- onnxruntime/contrib_ops/cuda/bert/fast_gelu.h | 4 +++- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 91057d3dfb120..f50b18b736936 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -293,7 +293,7 @@ A classical usage of disabling the deep copy: when the deep copy before module e export ORTMODULE_MEMORY_OPT_LEVEL=0 ``` -### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT +#### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT - **Feature Area**: *ORTMODULE/Optimizations* - **Description**: By default, the memory-efficient gradient management is turned off. 
The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes. diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc index e8974a29476b6..8b8e4e267f895 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.cc @@ -8,8 +8,7 @@ #include "contrib_ops/cpu/bert/bias_gelu_helper.h" #ifdef USE_ROCM #include "contrib_ops/rocm/bert/elementwise.h" -#endif -#ifdef USE_CUDA +#else #include "contrib_ops/cuda/bert/transformer_common.h" #endif @@ -36,7 +35,7 @@ using namespace ONNX_NAMESPACE; template FastGelu::FastGelu(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info) { -#ifdef USE_CUDA +#ifndef USE_ROCM const TransformerOptions* options = TransformerOptions::GetInstance(); use_half2_ = !options->DisableHalf2(); #endif @@ -63,8 +62,7 @@ Status FastGelu::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(input->Data()), static_cast(input_length), (nullptr != bias) ? reinterpret_cast(bias->Data()) : nullptr, static_cast(bias_length), reinterpret_cast(output->MutableData())); -#endif -#ifdef USE_CUDA +#else return LaunchFastGeluKernel(GetDeviceProp(), Stream(context), static_cast(input_length), diff --git a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h index d563556593e6e..26f3bd5a03928 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h +++ b/onnxruntime/contrib_ops/cuda/bert/fast_gelu.h @@ -18,7 +18,9 @@ class FastGelu final : public CudaKernel { Status ComputeInternal(OpKernelContext* ctx) const override; private: - bool use_half2_; // Only applicable to CUDA kernel (not ROCM). +#ifndef USE_ROCM + bool use_half2_; +#endif }; } // namespace cuda From ed550b5fe5aa41e182db84d2b2f2fb768121fd7a Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 29 Feb 2024 20:36:29 -0800 Subject: [PATCH 179/207] Change webgpu CI pipeline to use a preinstalled chrome (#19729) ### Description Change webgpu CI pipeline to use a preinstalled chrome. Hopefully it can increase the stability. Now the chrome got from puppeteer often failed to start. 
--- .../github/azure-pipelines/templates/win-web-ci.yml | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 043da233cc674..b882d6fb167fd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -31,6 +31,7 @@ jobs: variables: webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de' runCodesignValidationInjection: false + CHROME_BIN: 'C:\Program Files\Google\Chrome\Application\chrome.exe' timeoutInMinutes: 60 workspace: clean: all @@ -95,18 +96,6 @@ jobs: targetFolder: $(Build.SourcesDirectory)\js\web\lib\wasm\binding flattenFolders: true displayName: 'Binplace js files' - - script: | - npm i -g puppeteer - workingDirectory: '$(Build.SourcesDirectory)' - displayName: 'Use puppeteer to prepare Chrome for tests' - - script: | - FOR /F "tokens=* USEBACKQ" %%F IN (`where /r %HOMEDRIVE%%HOMEPATH%\.cache\puppeteer chrome.exe`) DO ( - SET var=%%F - ECHO found chrome.exe: %%F - ) - ECHO ##vso[task.setvariable variable=CHROME_BIN;]%var% - workingDirectory: '$(Build.SourcesDirectory)' - displayName: 'Set CHROME_BIN' - script: | npm ci workingDirectory: '$(Build.SourcesDirectory)\js' From 5672cdebdf5648815fcc3a001dc00e610a9f9b51 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 1 Mar 2024 11:01:58 -0800 Subject: [PATCH 180/207] Update google benchmark to 1.8.3. (#19734) Update google benchmark to 1.8.3. Update deps_update_and_upload.py script to make it easier to use. --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 2 +- cmake/deps_update_and_upload.py | 135 ++++++++++++------ .../templates/download-deps.yml | 4 +- 4 files changed, 98 insertions(+), 45 deletions(-) diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index efd901787fdb7..cfad59be6b4c0 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -116,7 +116,7 @@ "component": { "type": "git", "git": { - "commitHash": "361e8d1cfe0c6c36d30b39f1b61302ece5507320", + "commitHash": "344117638c8ff7e239044fd0fa7085839fc03021", "repositoryUrl": "https://github.com/google/benchmark.git" }, "comments": "google_benchmark" diff --git a/cmake/deps.txt b/cmake/deps.txt index cb431f8c77397..9cba25b00157d 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -26,7 +26,7 @@ eigen;https://gitlab.com/libeigen/eigen/-/archive/e7248b26a1ed53fa030c5c459f7ea0 flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;ba0a75fd12dbef8f6557a74e611b7a3d0c5fe7bf fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 -google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 +google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.3.zip;bf9870756ee3f8d2d3b346b24ee3600a41c74d3d google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 
googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034 googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73 diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py index d357284d91225..63df3f6f03869 100644 --- a/cmake/deps_update_and_upload.py +++ b/cmake/deps_update_and_upload.py @@ -1,56 +1,109 @@ -# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. -# Before running the script, increase the version number found at: +# If deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# +# Before running the script, find the latest version number at: # https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Increment it to obtain a new version number to use. +# # Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. -# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload -# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +# E.g.: +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 +# # check contents of C:/temp/onnxruntime_deps +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --no-download --do-upload +# +# Next, update the version number in tools/ci_build/github/azure-pipelines/templates/download-deps.yml. + +import argparse +import contextlib +import pathlib import re import subprocess -import os -import argparse import tempfile +script_dir = pathlib.Path(__file__).parent + parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") parser.add_argument( - "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" + "--root-path", + type=pathlib.Path, + help="Target root path for downloaded files. If not provided, a temporary directory is used.", +) +parser.add_argument( + "--version", + type=str, + help="Package version to publish", +) +parser.add_argument( + "--do-upload", + action="store_true", + dest="upload", + help="Upload the package to Azure Artifacts", +) +parser.add_argument( + "--no-download", + action="store_false", + dest="download", + help="Skip downloading the dependency files. " + "Use with '--do-upload' and '--root-path' to upload the package from existing dependency files.", ) -parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") -parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") args = parser.parse_args() -with open("cmake/deps.txt") as file: +if args.upload: + assert args.version is not None, "'--version' must be specified if uploading." + +if args.upload != args.download: + assert args.root_path is not None, "'--root-path' must be specified if only downloading or uploading." 
+ +deps_path = script_dir / "deps.txt" +with open(deps_path) as file: text = file.read() lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] -root_path = args.root_path - -for line in lines: - url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) - filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) - full_path = os.path.join(root_path, filename) - subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) # noqa: PLW1510 - -package_name = "onnxruntime_build_dependencies" -version = args.version - -# Check if the user is logged in to Azure -result = subprocess.run("az account show", shell=True, capture_output=True, text=True) # noqa: PLW1510 -if "No subscriptions found" in result.stderr: - # Prompt the user to log in to Azure - print("You are not logged in to Azure. Please log in to continue.") - subprocess.run("az login", shell=True) # noqa: PLW1510 - -# Publish the package to Azure Artifacts if --no-upload is not specified - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) - -cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' -if args.do_upload: - subprocess.run(cmd, shell=True) # noqa: PLW1510 -else: - print("would have run: " + cmd) +with contextlib.ExitStack() as context_stack: + if args.root_path is not None: + root_path = args.root_path.resolve() + root_path.mkdir(parents=True, exist_ok=True) + else: + temp_dir_name = context_stack.enter_context(tempfile.TemporaryDirectory()) + root_path = pathlib.Path(temp_dir_name) + + if args.download: + print(f"Downloading dependencies to directory: {root_path}") + + dep_pattern = re.compile(r"^[^;]+;https://([^;]+);.*$") + + for line in lines: + match = dep_pattern.fullmatch(line) + if match is None: + continue + + dep_path = match[1] + url = f"https://{dep_path}" + full_path = root_path / dep_path + + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", str(full_path), url], check=True) + + package_name = "onnxruntime_build_dependencies" + version = args.version if args.version is not None else "VERSION_PLACEHOLDER" + + if args.upload: + # Check if the user is logged in to Azure + result = subprocess.run("az account show", shell=True, capture_output=True, text=True, check=False) + if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. 
Please log in to continue.") + subprocess.run("az login", shell=True, check=True) + + # Publish the package to Azure Artifacts if --do-upload is specified + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) + + cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' + if args.upload: + subprocess.run(cmd, shell=True, check=True) + else: + print("would have run: " + cmd) diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 95e34cd863915..01be343795a56 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.133 + version: 1.0.134 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.133 + version: 1.0.134 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. From 22176a5fa8fe97efe05a63c1e7bb89b0e54cd201 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 1 Mar 2024 13:44:29 -0800 Subject: [PATCH 181/207] disable gemm f16 on CPU (#19744) ### Description Temporarily disable fp16 gemm on CPU because it usually needs a following Cast which offsets the gain. Need more fp16 operators implementation and performance tuning. Also fix a fusion error of LayerNormalization. ### Motivation and Context --- .vscode/settings.json | 5 ++++- .../core/optimizer/layer_norm_fusion.cc | 14 +++++++++++++ .../providers/cpu/cpu_execution_provider.cc | 21 ------------------- .../test/providers/cpu/math/gemm_test.cc | 2 +- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 3e2b1f31dd6cf..98d23090fd474 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -21,5 +21,8 @@ "cpplint.filters": [ "-build/include_subdir", "-runtime/references" - ] + ], + "files.associations": { + "span": "cpp" + } } diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index b6ad4fde6c1f7..ce696154adb6d 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -447,6 +447,13 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* x_input = has_leading_cast ? 
graph.GetNode(p_reduce_mean_input_node->Index())->MutableInputDefs()[0] : reduce_mean_node.MutableInputDefs()[0]; + + // CPU doesn't support fp16 + if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider && + x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + InlinedVector layer_norm_input_defs{x_input, scale, bias}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("LayerNormalization"), "LayerNormalization", @@ -689,6 +696,13 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr NodeArg* x_input = has_leading_cast ? graph.GetNode(p_pow_input_node->Index())->MutableInputDefs()[0] : pow_node.MutableInputDefs()[0]; + + // CPU doesn't support fp16 + if (reduce_mean_node.GetExecutionProviderType() == kCpuExecutionProvider && + x_input->TypeAsProto()->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { + continue; + } + InlinedVector layer_norm_input_defs{x_input, scale}; Node& layer_norm_node = graph.AddNode(graph.GenerateNodeName("SimplifiedLayerNormalization"), "SimplifiedLayerNormalization", diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 37e7e42150413..7e0f919deb0a7 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -143,9 +143,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Aco class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, Atan); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 7, 8, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, Hardmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, float, LogSoftmax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, 10, double, LogSoftmax); @@ -335,9 +332,6 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOn class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Flatten); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, MatMul); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, int32_t, MatMul); @@ -497,9 +491,6 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Sp class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, ScatterND); class 
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint8_t, BitShift); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, uint32_t, BitShift); @@ -606,9 +597,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, string, Expand); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, Gemm); -#ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, Gemm); -#endif class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, MatMul); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, MatMul); @@ -2617,15 +2605,6 @@ Status RegisterFp16Kernels(KernelRegistry& kernel_registry) { MLFloat16, LeakyRelu)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/math/gemm_test.cc b/onnxruntime/test/providers/cpu/math/gemm_test.cc index 428925e154497..1a542fb67418e 100644 --- a/onnxruntime/test/providers/cpu/math/gemm_test.cc +++ b/onnxruntime/test/providers/cpu/math/gemm_test.cc @@ -277,7 +277,7 @@ class GemmOpTypedTests : public ::testing::Test { // On CPUs without fp16 instructions the tests will output a warning: // "registered execution providers CPUExecutionProvider were unable to run the model" // , then they will still pass. -using GemmOpTypedTestsTypes = ::testing::Types; +using GemmOpTypedTestsTypes = ::testing::Types; TYPED_TEST_SUITE(GemmOpTypedTests, GemmOpTypedTestsTypes); TYPED_TEST(GemmOpTypedTests, TestGemmScalarBroadcast) { From f06164ef8b8de42dd67ca2137f6996cdc87a3f72 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Fri, 1 Mar 2024 14:50:06 -0800 Subject: [PATCH 182/207] [js/web] transfer input buffer back to caller thread (#19677) ### Description When using proxy worker, input buffers should be transferred back to the caller thread after `run()` call is done. 
Fixes #19488 --- js/web/lib/wasm/proxy-worker/main.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index 6cbd38c76ccc8..3ce37a2d6b652 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -103,7 +103,7 @@ self.onmessage = (ev: MessageEvent): void => { } else { postMessage( {type, out: outputs} as OrtWasmMessage, - extractTransferableBuffers(outputs as SerializableTensorMetadata[])); + extractTransferableBuffers([...inputs, ...outputs] as SerializableTensorMetadata[])); } }, err => { From a0521f899e9d495d57ae044bd4a1fe4d17155782 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 1 Mar 2024 16:23:20 -0800 Subject: [PATCH 183/207] Enable CPUINFO for all Windows build (#19655) ### Description It was disabled in PR #9065. And the reason was: " api-ms-win-core-kernel32-legacy-*.dll wasn't available in Windows 8 and was added in Windows 10, so cpuinfo breaks our Windows 8 support. I'm disabling it again." We no longer support Windows 8. Therefore we can add CPUINFO back. ### Motivation and Context To make the code simpler. If in any case the library doesn't work as expected, we can submit a PR to their code base and fix it. --- .../external/onnxruntime_external_deps.cmake | 9 +- cmake/onnxruntime_common.cmake | 5 -- onnxruntime/core/common/cpuid_info.cc | 82 ++++++++----------- onnxruntime/core/common/cpuid_info.h | 19 ++--- 4 files changed, 42 insertions(+), 73 deletions(-) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 09d57164b4ee1..cb75b0b8751bb 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -256,14 +256,7 @@ if (onnxruntime_ENABLE_CPUINFO) set(CPUINFO_SUPPORTED TRUE) endif() if (WIN32) - # Exclude Windows ARM build and Windows Store - if (${onnxruntime_target_platform} MATCHES "^(ARM.*|arm.*)$" ) - message(WARNING "Cpuinfo not included for compilation problems with Windows ARM.") - set(CPUINFO_SUPPORTED FALSE) - elseif (WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) - message(WARNING "Cpuinfo not included non-Desktop builds") - set(CPUINFO_SUPPORTED FALSE) - endif() + set(CPUINFO_SUPPORTED TRUE) elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index 6b8c2560b1714..fb56e3f3445d4 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -201,10 +201,6 @@ endif() if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) - if((WIN32 AND NOT CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib) OR ((ARM64 OR ARM) AND MSVC)) - # msvc compiler report syntax error with cpuinfo arm source files - # and cpuinfo does not have code for getting arm uarch info under windows - else() # Link cpuinfo if supported # Using it mainly in ARM with Android. # Its functionality in detecting x86 cpu features are lacking, so is support for Windows. 
@@ -212,7 +208,6 @@ if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) endif() - endif() endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 711fd595e90fd..be881f6bc4bc2 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -52,6 +52,13 @@ #if defined(CPUINFO_SUPPORTED) #include +#if defined(CPUIDINFO_ARCH_ARM) +namespace onnxruntime { +// The following function is declared in "core/common/cpuid_uarch.h" but we cannot include the whole header file because +// some of its symbols are conflict with +void decodeMIDR(uint32_t midr, uint32_t uarch[1]); +} // namespace onnxruntime +#endif #else #include "core/common/cpuid_uarch.h" #endif // CPUINFO_SUPPORTED @@ -142,11 +149,6 @@ void CPUIDInfo::ArmLinuxInit() { // Pytorch CPUINFO only works on ARM linux or android // Assuming no hyper-threading, no NUMA groups #ifdef CPUINFO_SUPPORTED - pytorch_cpuinfo_init_ = cpuinfo_initialize(); - if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; - return; - } is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); @@ -239,52 +241,24 @@ void CPUIDInfo::ArmWindowsInit() { lastUarch = uarch; } } - - switch (lastUarch) { - case cpuinfo_uarch_cortex_a55: - case cpuinfo_uarch_cortex_a55r0: - case cpuinfo_uarch_cortex_a76: - case cpuinfo_uarch_neoverse_n1: - case cpuinfo_uarch_cortex_a77: - case cpuinfo_uarch_exynos_m4: - case cpuinfo_uarch_exynos_m5: - has_fp16_ = true; - break; - default: - break; - } - if (!has_fp16_) { - /* - * Detecting fp16 support. Different cores should have the same instruction set. 
- * So we just check the first ID_AA64PFR0_EL1 - * Op0(0b11), Op1(0b000), CRn(0b0000), CRm(0b0100), Op2(0b000), - */ - uint64_t ID_AA64PFR0_EL1; - unsigned long valsize = sizeof(uint64_t); - auto retCode = ::RegGetValueA( - HKEY_LOCAL_MACHINE, - "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", - "CP 4020", RRF_RT_REG_QWORD, nullptr, - &ID_AA64PFR0_EL1, &valsize); - if (retCode == ERROR_SUCCESS) { - // AdvSIMD, bits [23:20] - auto advSimd = ID_AA64PFR0_EL1 >> 20; - if ((advSimd & 0xfULL) == 1) { - has_fp16_ = true; - } - } - } #endif /* Application Family or OneCore Family */ has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); #else has_arm_neon_dot_ = false; #endif - has_fp16_ |= has_arm_neon_dot_; - /* TODO: implement them when hw+sw is available for testing these features */ - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; - has_arm_neon_bf16_ = false; + + if (pytorch_cpuinfo_init_) { + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + } else { + has_fp16_ = false; + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; + } } #endif /* (arm or arm64) and windows */ @@ -304,5 +278,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { return 0xFFFFFFFF; // don't know how to get core index #endif } - +CPUIDInfo::CPUIDInfo() { +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) +#if CPUINFO_SUPPORTED + pytorch_cpuinfo_init_ = cpuinfo_initialize(); + if (!pytorch_cpuinfo_init_) { + LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; + } +#endif +#ifdef __linux__ + ArmLinuxInit(); +#elif defined(_WIN32) + ArmWindowsInit(); +#endif /* (arm or arm64) and windows */ +#endif +} } // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 2f8041e39f680..a3936b4bd11a6 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -93,17 +93,7 @@ class CPUIDInfo { } private: - CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) -#ifdef __linux__ - ArmLinuxInit(); -#elif defined(_WIN32) - ArmWindowsInit(); -#endif /* (arm or arm64) and windows */ -#endif - } + CPUIDInfo(); bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -131,11 +121,13 @@ class CPUIDInfo { #ifdef CPUIDINFO_ARCH_X86 void X86Init(); - #elif defined(CPUIDINFO_ARCH_ARM) + // Now the following var is only used in ARM build, but later on we may expand the usage. + bool pytorch_cpuinfo_init_{false}; +#endif + #ifdef __linux__ - bool pytorch_cpuinfo_init_{false}; void ArmLinuxInit(); #elif defined(_WIN32) @@ -143,7 +135,6 @@ class CPUIDInfo { void ArmWindowsInit(); #endif /* (arm or arm64) and windows */ -#endif }; } // namespace onnxruntime From de3158e78d09992e4b5085c15da44108d9c6fa83 Mon Sep 17 00:00:00 2001 From: zesongw Date: Sat, 2 Mar 2024 08:55:50 +0800 Subject: [PATCH 184/207] [WebNN EP] Add constraints for MatMul (#19713) ### Description Add constraints to MatMul: - The input must be at least 2D. - CPU backend: The input rank must be the same. - CPU backend: The input shape except for the last two axes must be the same. ### Motivation and Context Prevent regression for some models.
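To make the new checks easier to follow, here is a small illustrative sketch in plain Python of the shape rules listed above. It is not the EP's actual C++ implementation; the helper name and the example shapes are made up for illustration.

```python
# Illustrative sketch of the MatMul shape rules described above (not the EP's C++ code).
def webnn_matmul_supported(a_shape, b_shape, cpu_backend):
    # Both inputs must be non-empty and at least 2-D.
    if len(a_shape) < 2 or len(b_shape) < 2:
        return False
    if 0 in a_shape or 0 in b_shape:
        return False
    if cpu_backend:
        # CPU backend only: ranks must match and the leading (batch) dims must be identical,
        # i.e. no broadcasting over the axes before the last two.
        if len(a_shape) != len(b_shape) or a_shape[:-2] != b_shape[:-2]:
            return False
    return True

assert webnn_matmul_supported((2, 3, 4), (2, 4, 5), cpu_backend=True)      # same rank, same batch dim
assert not webnn_matmul_supported((2, 3, 4), (4, 5), cpu_backend=True)     # rank mismatch rejected on CPU
assert not webnn_matmul_supported((1, 3, 4), (2, 4, 5), cpu_backend=True)  # would need batch broadcasting
assert webnn_matmul_supported((1, 3, 4), (2, 4, 5), cpu_backend=False)     # non-CPU backend only needs rank >= 2
```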
--- .../webnn/builders/impl/gemm_op_builder.cc | 73 +++++++++++-------- 1 file changed, 43 insertions(+), 30 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index d5f84f853f7de..455e0e5f16a42 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -91,44 +91,33 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, + const WebnnDeviceType device_type, const logging::Logger& logger) const { (void)initializers; const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C - if (op_type == "Gemm") { - std::vector a_shape; - { - if (!GetShape(*input_defs[a_idx], a_shape, logger)) - return false; - - if (a_shape.size() != 2) { - LOGS(logger, VERBOSE) << "A must be 2D"; - return false; - } - - if (Product(a_shape) == 0) { - LOGS(logger, VERBOSE) << "A must be non-empty"; - return false; - } - } - - std::vector b_shape; - { - if (!GetShape(*input_defs[b_idx], b_shape, logger)) - return false; + std::vector a_shape; + if (!GetShape(*input_defs[a_idx], a_shape, logger)) + return false; + if (Product(a_shape) == 0) { + LOGS(logger, VERBOSE) << "A must be non-empty"; + return false; + } - if (b_shape.size() != 2) { - LOGS(logger, VERBOSE) << "B must be 2D"; - return false; - } + std::vector b_shape; + if (!GetShape(*input_defs[b_idx], b_shape, logger)) + return false; + if (Product(b_shape) == 0) { + LOGS(logger, VERBOSE) << "B must be non-empty"; + return false; + } - if (Product(b_shape) == 0) { - LOGS(logger, VERBOSE) << "B must be non-empty"; - return false; - } + if (op_type == "Gemm") { + if (a_shape.size() != 2 || b_shape.size() != 2) { + LOGS(logger, VERBOSE) << "A and B must be 2D for Gemm"; + return false; } // C of Gemm. @@ -162,6 +151,30 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } + if (op_type == "MatMul") { + if (a_shape.size() < 2 || b_shape.size() < 2) { + LOGS(logger, VERBOSE) << "Inputs of MatMul must be at least 2D"; + return false; + } + + // WebNN CPU backend has two more constraints. + // https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc;l=1177 + // TODO: Remove this workaround when Chromium enables broadcast for MatMul on WebNN CPU backend. + if (device_type == WebnnDeviceType::CPU) { + if (a_shape.size() != b_shape.size()) { + LOGS(logger, VERBOSE) << "The rank of two inputs for WebNN CPU backend MatMul must be the same."; + return false; + } + + for (size_t i = 0; i < a_shape.size() - 2; i++) { + if (a_shape[i] != b_shape[i]) { + LOGS(logger, VERBOSE) << "WebNN CPU backend can't support broadcasting for MatMul."; + return false; + } + } + } + } + return true; } From 2d79052ec38b831f3254b20e0f6a42b3f98eabc7 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Fri, 1 Mar 2024 18:39:51 -0800 Subject: [PATCH 185/207] [QNN Quant] Add preprocessing option to transpose graph inputs/outputs to channel-last (#19731) ### Description Adds the optional parameters `inputs_to_make_channel_last` and `outputs_to_make_channel_last` to the `qnn_preprocess_model()` function. 
```python """ inputs_to_make_channel_last: List of graph input names to transpose to be "channel-last". For example, if "input0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change input0's shape to (N, D1, D2, ..., Dn, C) and add a transpose node after it. Original: input0 (N, C, D1, D2, ..., Dn) --> Updated: input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> This can potentially improve inference latency for QDQ models running on QNN EP because the additional transpose node may allow other transpose nodes inserted during ORT layout transformation to cancel out. outputs_to_make_channel_last: List of graph output names to transpose to be "channel-last". For example, if "output0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change output0's shape to (N, D1, D2, ..., Dn, C) and add a transpose node before it. Original: --> output0 (N, C, D1, D2, ..., Dn) Updated: --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C) This can potentially improve inference latency for QDQ models running on QNN EP because the additional transpose node may allow other transpose nodes inserted during ORT layout transformation to cancel out. """ ``` **NOTE: If you use these options with the quantization scripts, you'll have to make sure your data_reader feeds in transposed input data. It won't happen automatically.** ### Motivation and Context Native QNN operators use the channel-last data layout, but ONNX uses channel-first. To bridge the gap, ORT's layout transformer inserts transposes around layout-sensitive nodes and updates their domain to indicate that they now operate on channel-last data. The transpose optimizer is able to remove most of these inserted transposes, but not all transposes can always be removed (i.e., some could remain at the graph's inputs and outputs). We've found that these extra transpose nodes can significantly degrade inference latency on QNN EP. One workaround (provided by this PR) is to add _additional_ transpose nodes at the graph inputs or outputs. These additional nodes can often help the ORT transpose optimizer cancel out any remaining transpose nodes, which significantly improves latency. Additionally, it may make more sense for some kinds of inputs to just be in channel-last form (e.g., images), avoiding the need to pre-transpose the input data before inference.
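As a usage sketch (the model path and tensor names below are hypothetical, but the function and keyword arguments are the ones added by this PR), the options are passed straight to `qnn_preprocess_model()`, and any quantization `data_reader` must then feed channel-last data for the transposed inputs:

```python
from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model

# "model.onnx" and the tensor names below are placeholders for illustration.
modified = qnn_preprocess_model(
    "model.onnx",                                   # original channel-first (e.g. NCHW) model
    "model.qnn_pp.onnx",                            # preprocessed model with channel-last I/O
    inputs_to_make_channel_last=["pixel_values"],   # these graph inputs become (N, D1, ..., Dn, C)
    outputs_to_make_channel_last=["feature_map"],   # these graph outputs become (N, D1, ..., Dn, C)
)

# The caller is now responsible for feeding "pixel_values" as channel-last data;
# the inserted Transpose restores the channel-first layout inside the graph.
```

The shape examples below show what the inserted transposes look like at one transposed input and one transposed output.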
Example at the input: ``` Original: input0 (N, C, D1, D2, ..., Dn) --> Updated: input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> ``` Example at the output: ``` Original: --> output0 (N, C, D1, D2, ..., Dn) Updated: --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C) ``` --- .../execution_providers/qnn/preprocess.py | 198 ++++++++++++++++++ .../quantization/test_qnn_preprocess_model.py | 93 ++++++++ 2 files changed, 291 insertions(+) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index b0dab81830c8b..e584a65574520 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -24,6 +24,8 @@ def qnn_preprocess_model( external_data_location: str | None = None, external_data_size_threshold: int = 1024, external_data_convert_attribute: bool = False, + inputs_to_make_channel_last: list[str] | None = None, + outputs_to_make_channel_last: list[str] | None = None, ) -> bool: """ If necessary, this method creates a new "pre-processed" model in preparation for @@ -52,6 +54,32 @@ def qnn_preprocess_model( external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false. If true, convert all tensors to external data. If false, convert only non-attribute tensors to external data. + inputs_to_make_channel_last: List of graph input names to transpose to be "channel-last". For example, + if "input0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change input0's + shape to (N, D1, D2, ..., Dn, C) and add a transpose node after it. + + Original: + input0 (N, C, D1, D2, ..., Dn) --> + + Updated: + input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> + + This can potentially improve inference latency for QDQ models running on QNN EP because the + additional transpose node may allow other transpose nodes inserted during ORT layout transformation + to cancel out. + outputs_to_make_channel_last: List of graph output names to transpose to be "channel-last". For example, + if "output0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change output0's + shape to (N, D1, D2, ..., Dn, C) and add a transpose node before it. + + Original: + --> output0 (N, C, D1, D2, ..., Dn) + + Updated: + --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C) + + This can potentially improve inference latency for QDQ models running on QNN EP because the + additional transpose node may allow other transpose nodes inserted during ORT layout transformation + to cancel out. """ modified = False model = onnx.load_model(model_input) @@ -83,6 +111,19 @@ def qnn_preprocess_model( if fusion_layernorm.apply(): modified = True + # Optionally, transpose inputs and/or outputs to make them "channel-last". 
+ if inputs_to_make_channel_last or outputs_to_make_channel_last: + transpose_node_prefix = "Transpose_channel_" + transpose_node_suffix: int = onnx_model.get_largest_node_name_suffix(transpose_node_prefix) + 1 + update_io_to_channel_last( + onnx_model.model, + inputs_to_make_channel_last, + outputs_to_make_channel_last, + transpose_node_name_prefix=transpose_node_prefix, + transpose_node_name_start_suffix=transpose_node_suffix, + ) + modified = True + # Make sure all nodes have a name. unnamed_node_prefix = "qnn_preproc_node_" available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1 @@ -107,3 +148,160 @@ def qnn_preprocess_model( ) return modified + + +class InputOutputNameMap: + def __init__( + self, + orig_tensor_names: set[str], + orig_graph_inputs: dict[str, onnx.ValueInfoProto], + orig_graph_outputs: dict[str, onnx.ValueInfoProto], + ): + self.orig_tensor_names = orig_tensor_names + self.orig_graph_inputs = orig_graph_inputs + self.orig_graph_outputs = orig_graph_outputs + self.updated_io_names = {} + self.new_value_infos = [] + + def get_new_name(self, orig_name: str): + if orig_name in self.updated_io_names: + return self.updated_io_names[orig_name] + + # Make a new tensor name that is unique among all tensors in the graph. + prefix: str = f"{orig_name}_channel_first_" + suffix: int = -1 + for tensor_name in self.orig_tensor_names: + if tensor_name.startswith(prefix) and tensor_name[len(prefix) :].isdigit(): + index = int(tensor_name[len(prefix) :]) + suffix = max(suffix, index) + + suffix += 1 # This is the first available suffix. + new_name = f"{prefix}{suffix!s}" + + # Add new value_info objects for these new tensors. + orig_value_info = self.orig_graph_inputs.get(orig_name) or self.orig_graph_outputs[orig_name] + value_info_proto = onnx.ValueInfoProto() + value_info_proto.CopyFrom(orig_value_info) + value_info_proto.name = new_name + self.new_value_infos.append(value_info_proto) + + self.updated_io_names[orig_name] = new_name + return self.updated_io_names[orig_name] + + +def update_io_to_channel_last( + model: onnx.ModelProto, + inputs_to_update: list[str] | None, + outputs_to_update: list[str] | None, + transpose_node_name_prefix: str = "Transpose_channel_", + transpose_node_name_start_suffix: int = 0, +): + inputs_to_update = set(inputs_to_update or []) + outputs_to_update = set(outputs_to_update or []) + + if not inputs_to_update and not outputs_to_update: + return + + graph = model.graph + orig_graph_inputs = {ginput.name: ginput for ginput in graph.input} + orig_graph_outputs = {goutput.name: goutput for goutput in graph.output} + + # Check that the user passed in actual input and output names. + for input_name in inputs_to_update: + if input_name not in orig_graph_inputs: + raise ValueError(f"{input_name} is not a graph input") + + for output_name in outputs_to_update: + if output_name not in orig_graph_outputs: + raise ValueError(f"{output_name} is not a graph output") + + orig_tensor_names = set() + orig_tensor_names.update(set(orig_graph_inputs)) + orig_tensor_names.update(set(orig_graph_outputs)) + orig_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name) + + # Maps original input (or output) name to its updated name used within the graph. + io_map = InputOutputNameMap(orig_tensor_names, orig_graph_inputs, orig_graph_outputs) + + # Update each node's inputs/outputs to use the transposed versions. 
+ for node in graph.node: + for i in range(len(node.input)): + if node.input[i] and node.input[i] in inputs_to_update: + node.input[i] = io_map.get_new_name(node.input[i]) + elif node.input[i] and node.input[i] in outputs_to_update: + node.input[i] = io_map.get_new_name(node.input[i]) + + for i in range(len(node.output)): + if node.output[i] in outputs_to_update: + node.output[i] = io_map.get_new_name(node.output[i]) + + # Update graph inputs to channel-last and a Transpose (to channel-first) after each. + for g_input_name in inputs_to_update: + g_input = orig_graph_inputs[g_input_name] + + if not g_input.type.HasField("tensor_type") or not g_input.type.tensor_type.HasField("shape"): + raise ValueError(f"Expected input {g_input.name} to have a tensor_type with a shape") + + input_shape = g_input.type.tensor_type.shape + input_rank = len(input_shape.dim) + + if input_rank < 3: + raise ValueError(f"Expected input {g_input.name} to be of rank >= 3") + + channel_dim = onnx.TensorShapeProto.Dimension() + channel_dim.CopyFrom(input_shape.dim[1]) + for i in range(1, input_rank - 1): + input_shape.dim[i].CopyFrom(input_shape.dim[i + 1]) + input_shape.dim[input_rank - 1].CopyFrom(channel_dim) + + transpose_perm = list(range(input_rank)) + for i in range(input_rank): + transpose_perm[i] = i if i < 1 else i - 1 + transpose_perm[1] = input_rank - 1 + + transpose_node = onnx.helper.make_node( + "Transpose", + name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}", + inputs=[g_input.name], + outputs=[io_map.get_new_name(g_input.name)], + perm=transpose_perm, + ) + transpose_node_name_start_suffix += 1 + + graph.node.extend([transpose_node]) + + # Update graph outputs to channel-last and a Transpose (from channel-first) before each. + for g_output_name in outputs_to_update: + g_output = orig_graph_outputs[g_output_name] + if not g_output.type.HasField("tensor_type") or not g_output.type.tensor_type.HasField("shape"): + raise ValueError(f"Expected output {g_output.name} to have a tensor_type with a shape") + + output_shape = g_output.type.tensor_type.shape + output_rank = len(output_shape.dim) + + if output_rank < 3: + raise ValueError(f"Expected output {g_output.name} to be of rank >= 3") + + channel_dim = onnx.TensorShapeProto.Dimension() + channel_dim.CopyFrom(output_shape.dim[1]) + for i in range(1, output_rank - 1): + output_shape.dim[i].CopyFrom(output_shape.dim[i + 1]) + output_shape.dim[output_rank - 1].CopyFrom(channel_dim) + + transpose_perm = list(range(output_rank)) + for i in range(output_rank): + transpose_perm[i] = i if i == 0 else i + 1 + transpose_perm[output_rank - 1] = 1 + + transpose_node = onnx.helper.make_node( + "Transpose", + name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}", + inputs=[io_map.get_new_name(g_output.name)], + outputs=[g_output.name], + perm=transpose_perm, + ) + transpose_node_name_start_suffix += 1 + + graph.node.extend([transpose_node]) + + graph.value_info.extend(io_map.new_value_infos) diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py index 9b67fd41caac3..6503b3223b828 100644 --- a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py +++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py @@ -12,6 +12,7 @@ import numpy as np import onnx +import onnxruntime from onnxruntime.quantization.execution_providers.qnn import qnn_preprocess_model from onnxruntime.quantization.quant_utils import 
model_has_external_data, ms_domain @@ -165,6 +166,98 @@ def test_external_data(self): for node in fused_model.graph.node: self.assertIn(node.op_type, expected_op_types) + def build_multi_input_output_model(self, shape): + """ + Returns the following model. + +----------> [X] + | + [A] ---> Add ---> Abs -+-> Mul ---> [Y] + ^ ^ + | | + [B] ------+-----------------+ + """ + input_a = onnx.helper.make_tensor_value_info("A", onnx.TensorProto.FLOAT, shape) + input_b = onnx.helper.make_tensor_value_info("B", onnx.TensorProto.FLOAT, shape) + output_x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, shape) + output_y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, shape) + + add_node = onnx.helper.make_node("Add", ["A", "B"], ["add_out"], name="add_node") + abs_node = onnx.helper.make_node("Abs", ["add_out"], ["X"], name="abs_node") + mul_node = onnx.helper.make_node("Mul", ["X", "B"], ["Y"], name="mul_node") + + graph = onnx.helper.make_graph( + [add_node, abs_node, mul_node], + "multi_io_graph", + [input_a, input_b], + [output_x, output_y], + ) + opset_imports = [ + onnx.helper.make_opsetid("", 18), + ] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + return onnx.shape_inference.infer_shapes(model) + + def test_make_io_channel_last(self): + """ + Test making a model's inputs and outputs channel-last. + """ + model = self.build_multi_input_output_model((1, 2, 3, 4)) + onnx.save_model(model, "model.onnx") + modified = qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + inputs_to_make_channel_last=["A", "B"], + outputs_to_make_channel_last=["X", "Y"], + ) + + self.assertTrue(modified) + + preproc_model = onnx.load_model("model.qnn_pp.onnx") + self.assertEqual(len(preproc_model.graph.node), 7) + + num_transposes = sum(1 for node in preproc_model.graph.node if node.op_type == "Transpose") + self.assertEqual(num_transposes, 4) + + # Check that the outputs of the new model are the same, but transposed. + input_a = np.arange(0.0, 24.0, 1.0, dtype=np.float32).reshape((1, 2, 3, 4)) + input_a_t = input_a.transpose(0, 2, 3, 1) + input_b = np.arange(1.0, 25.0, 1.0, dtype=np.float32).reshape((1, 2, 3, 4)) + input_b_t = input_b.transpose(0, 2, 3, 1) + + orig_session = onnxruntime.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"]) + orig_results = orig_session.run(None, {"A": input_a, "B": input_b}) + + new_session = onnxruntime.InferenceSession( + preproc_model.SerializeToString(), providers=["CPUExecutionProvider"] + ) + new_results = new_session.run(None, {"A": input_a_t, "B": input_b_t}) + + self.assertEqual(len(orig_results), len(new_results)) + for idx, orig_output in enumerate(orig_results): + transposed_output = new_results[idx] + np.testing.assert_allclose( + orig_output, + transposed_output.transpose(0, 3, 1, 2), + err_msg=f"Channel-last model output {idx} differs", + ) + + def test_make_io_channel_last_rank_error(self): + """ + Test making a model's inputs and outputs channel-last with a rank < 3 (error). 
+ """ + model = self.build_multi_input_output_model((1, 2)) + onnx.save_model(model, "model.onnx") + + with self.assertRaises(ValueError) as context: + qnn_preprocess_model( + "model.onnx", + "model.qnn_pp.onnx", + inputs_to_make_channel_last=["A", "B"], + outputs_to_make_channel_last=["X", "Y"], + ) + + self.assertIn("to be of rank >= 3", str(context.exception)) + if __name__ == "__main__": unittest.main() From 9460597b2103d8d07e88272b9f4e19700d71d632 Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Sat, 2 Mar 2024 11:33:47 +0800 Subject: [PATCH 186/207] Update copying API header files (#19736) ### Description Make the Linux logic consistent with Windows. ### Motivation and Context onnxruntime_lite_custom_op.h is included in the Windows zip package but not in the Linux zip package: https://github.com/microsoft/onnxruntime/blob/acbfc29f272b5578145e7600bc42342e116ffbc2/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml#L67 Co-authored-by: Your Name --- tools/ci_build/github/linux/copy_strip_binary.sh | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index 42973a8fcb5b8..65d6d97ebf0a8 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -44,17 +44,10 @@ elif [[ $LIB_NAME == *.so.* ]] then ln -s $LIB_NAME $BINARY_DIR/$ARTIFACT_NAME/lib/libonnxruntime.so fi -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_*.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/framework/provider_options.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include -cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/orttraining/orttraining/training_api/include/onnxruntime_training_*.h $BINARY_DIR/$ARTIFACT_NAME/include if [[ -f "$BINARY_DIR/$BUILD_CONFIG/libonnxruntime_providers_cuda.so" ]]; then # copy headers for context context used in custom ops From 9acaf534a62050705d9b892a57ef0e8409fa62ec Mon Sep 17 00:00:00 2001 From: ironman Date: Mon, 4 Mar 2024 23:29:58 +0800 Subject: [PATCH 187/207] Benchmark - Updating llama-2 requirement files (#19716) ### Description ### Motivation and Context --- .../tools/transformers/models/llama/requirements-cuda.txt | 1 +
.../python/tools/transformers/models/llama/requirements.txt | 3 ++- .../python/tools/transformers/models/whisper/requirements.txt | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index acd9c23aa42d0..307afbc122901 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -2,3 +2,4 @@ # Please manually install torch>=2.2.0 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ onnxruntime-gpu>=1.16.2 +py3nvml \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index 8b57279295e35..e991c2f27a1a3 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,6 +1,7 @@ optimum>=1.14.1 -transformers>=4.33.2 +transformers>=4.33.2,<= 4.37.2 torch>=2.2.0 onnx>=1.14.0 datasets>=2.8.0 protobuf==3.20.2 +psutil \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 956922dc83d51..9bbe0d7380406 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -7,8 +7,8 @@ soundfile librosa optimum onnxruntime-extensions>=0.9.0 +onnx>=1.15.0 protobuf==3.20.2 numpy==1.23.3 -onnx>=1.15.0 psutil py3nvml From 2e13d5f0ab54c726ee2400d38983000de7f61b8e Mon Sep 17 00:00:00 2001 From: inisis <46103969+inisis@users.noreply.github.com> Date: Tue, 5 Mar 2024 01:41:36 +0800 Subject: [PATCH 188/207] fix split shape inference error for opset >= 13 (#19756) ### Description Get the Split operator's 'split' sizes from the attribute or from the second input, depending on the opset version. ### Motivation and Context For opset 13 and higher, the 'split' sizes are provided as an optional input rather than an attribute, so shape inference must read them from the right place.
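For illustration (this snippet is not part of the change, and the tensor and node names are made up), the same two-way Split is declared with a `split` attribute before opset 13 and with an optional second input from opset 13 on:
```
from onnx import TensorProto, helper

# opset < 13: the output sizes are carried by the 'split' attribute
split_v12 = helper.make_node("Split", inputs=["X"], outputs=["Y0", "Y1"], axis=0, split=[2, 2])

# opset >= 13: the output sizes arrive as an optional second input (an int64 tensor)
split_sizes = helper.make_tensor("split_sizes", TensorProto.INT64, dims=[2], vals=[2, 2])
split_v13 = helper.make_node("Split", inputs=["X", "split_sizes"], outputs=["Y0", "Y1"], axis=0)
```
Symbolic shape inference therefore has to read the sizes from the attribute for opset < 13 and from the second input otherwise, which is what the change below does.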
--- onnxruntime/python/tools/symbolic_shape_infer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 4b56bc1e8d828..4b029f9b172b0 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -1940,8 +1940,17 @@ def _infer_SoftmaxCrossEntropyLoss(self, node): # noqa: N802 def _infer_Split_Common(self, node, make_value_info_func): # noqa: N802 input_sympy_shape = self._get_sympy_shape(node, 0) axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) - split = get_attribute(node, "split") - if not split: + op_set = get_opset(self.out_mp_) + + # Depending on op-version 'split' are provided as attribute or via 2nd input + if op_set < 13: + split = get_attribute(node, "split") + assert self._try_get_value(node, 1) is None + else: + split = self._try_get_value(node, 1) + assert get_attribute(node, "split") is None + + if split is None: num_outputs = len(node.output) split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs self._update_computed_dims(split) From 27b1dc91abb71b71fe6a26e1b4ebd30e13524baf Mon Sep 17 00:00:00 2001 From: raoanag <127366241+raoanag@users.noreply.github.com> Date: Mon, 4 Mar 2024 11:55:35 -0800 Subject: [PATCH 189/207] [DML] MatrixMultiplyIntegerToFloat (#19608) ### Description DML Implementation for [com.microsoft.MatMulIntegerToFloat](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.MatMulIntegerToFloat) ``` .\onnxruntime_test_all.exe --gtest_filter="*MatMulIntegerToFloat.*" Note: Google Test filter = *MatMulIntegerToFloat.* [==========] Running 22 tests from 1 test suite. [----------] Global test environment set-up. 
[----------] 22 tests from MatMulIntegerToFloat [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8S8 (620 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8S8 (497 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8S8 (488 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8S8 (503 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8U8 (495 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8U8 (488 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8U8 (492 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8X8 (502 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_S8U8 (452 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_S8U8 (454 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_S8U8 (446 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_S8U8 (508 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_NoBias_test_U8S8 (456 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_HasBias_test_U8S8 (455 ms) [ RUN ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 [ OK ] MatMulIntegerToFloat.NoZeroPoint_NoBias_test_U8S8 (447 ms) [ RUN ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8 [ OK ] MatMulIntegerToFloat.HasZeroPoint_HasBias_test_U8S8 (465 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8U8 (111 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_U8S8 (115 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8S8 (114 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16_S8U8 (110 ms) [ RUN ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16 [ OK ] MatMulIntegerToFloat.MatMulIntegerToFloat_FP16 (112 ms) [ RUN ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint [ OK ] MatMulIntegerToFloat.MatMulInteger_With_ZeroPoint (337 ms) [----------] 22 tests from MatMulIntegerToFloat (8679 ms total) [----------] Global test environment tear-down [==========] 22 tests from 1 test suite ran. (8680 ms total) [ PASSED ] 22 tests. memleakdbg: ----- No memory leaks detected ----- ``` ### Motivation and Context * `CalculateMatMulIntegerToFloat` to replace CPU EP run reference * Added more FP32 testcases to isolate all input datatype combinations * Added fixed input to `MatMulIntegerToFloat_FP16*` test cases as for FP16 test cases. 
* onnxruntime/test/testdata/matmul_integer_to_float.py` is capable of generating FP16 models, but we do not produce any for now --- docs/ContribOperators.md | 2 +- docs/OperatorKernels.md | 1 + .../graph/contrib_ops/quantization_defs.cc | 2 +- .../core/optimizer/graph_transformer_utils.cc | 5 +- .../core/optimizer/matmul_integer_to_float.cc | 23 +- .../src/External/DirectMLHelpers/ApiTraits.h | 12 +- .../External/DirectMLHelpers/DirectMLSchema.h | 37 +- .../DirectMLHelpers/GeneratedSchemaHelpers.h | 36 +- .../DmlOperatorMatMulIntegerToFloat.cpp | 111 +++++ .../src/Operators/OperatorRegistration.cpp | 9 + .../dml/OperatorAuthorHelper/OperatorHelper.h | 2 +- .../OperatorAuthorHelper/OperatorVersions.h | 1 + .../matmul_integer_to_float_test.cc | 414 +++++++++++++++--- .../test/optimizer/graph_transform_test.cc | 18 + .../test/testdata/matmul_integer_to_float.py | 60 ++- .../matmul_integer_to_float_int8.onnx | 4 +- .../matmul_integer_to_float_int8_bias.onnx | 4 +- .../matmul_integer_to_float_int8_int8.onnx | 4 +- ...atmul_integer_to_float_int8_int8_bias.onnx | 4 +- .../matmul_integer_to_float_uint8.onnx | 4 +- .../matmul_integer_to_float_uint8_bias.onnx | 4 +- .../fusion/matmul_integer_to_float.onnx | Bin 1520 -> 1520 bytes .../matmul_integer_to_float16_int8.onnx | 51 +++ 23 files changed, 664 insertions(+), 144 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp create mode 100644 onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f523e97293427..e295dfa203ae5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2795,7 +2795,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Constrain input A data type to 8-bit integer tensor.
T2 : tensor(int8), tensor(uint8)
Constrain input B data type to 8-bit integer tensor.
-
T3 : tensor(float)
+
T3 : tensor(float), tensor(float16)
Constrain input a_scale, b_scale and output Y data type as float tensor.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1eaf0fb6dad76..0e60b4622f2fb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1268,6 +1268,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| +|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)| diff --git a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc index 4313fae767fe5..22a79ef652515 100644 --- a/onnxruntime/core/graph/contrib_ops/quantization_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/quantization_defs.cc @@ -434,7 +434,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Output(0, "Y", "Matrix multiply results from A * B", "T3") .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input A data type to 8-bit integer tensor.") .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input B data type to 8-bit integer tensor.") - .TypeConstraint("T3", {"tensor(float)"}, + .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input a_scale, b_scale and output Y data type as float tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 2, 0); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 8376b87aee6b2..f319e7254568d 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -278,7 +278,8 @@ InlinedVector> GenerateTransformers( onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider, onnxruntime::kJsExecutionProvider}; - + const InlinedHashSet cpu_dml_eps = {onnxruntime::kCpuExecutionProvider, + onnxruntime::kDmlExecutionProvider}; #ifdef MLAS_TARGET_AMD64_IX86 const bool avx2_precision_mode = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsAvx2PrecisionMode, "0") == "1" && MlasPlatformU8S8Overflow(); @@ -296,7 +297,7 @@ InlinedVector> GenerateTransformers( } transformers.emplace_back(std::make_unique(cpu_ep)); - transformers.emplace_back(std::make_unique(cpu_ep)); + transformers.emplace_back(std::make_unique(cpu_dml_eps)); transformers.emplace_back(std::make_unique(cpu_ep)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_acl_armnn_js_eps)); diff --git a/onnxruntime/core/optimizer/matmul_integer_to_float.cc b/onnxruntime/core/optimizer/matmul_integer_to_float.cc index 56e51cb787931..4fee1a6ce224e 100644 --- a/onnxruntime/core/optimizer/matmul_integer_to_float.cc +++ b/onnxruntime/core/optimizer/matmul_integer_to_float.cc @@ -31,6 +31,24 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) { return bias_last_dim > 1; } +bool HasElementDataType(const NodeArg& node_arg, int32_t data_type) { + if (!node_arg.Exists()) { + return false; + } + + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto) { + return false; + } + + int32_t actual_data_type; + if (!utils::TryGetElementDataType(*type_proto, actual_data_type)) { + return false; + } + + return data_type == actual_data_type; +} + /** MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat: @@ -63,9 +81,10 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g auto& mul_node = *node_ptr; ORT_RETURN_IF_ERROR(Recurse(mul_node, modified, graph_level, logger)); - + const bool is_dml_ep = node_ptr->GetExecutionProviderType() == kDmlExecutionProvider; if (!graph_utils::IsSupportedOptypeVersionAndDomain(mul_node, "Mul", {7, 13, 14}) || - !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders())) { + !graph_utils::IsSupportedProvider(mul_node, GetCompatibleExecutionProviders()) || + (!is_dml_ep && 
HasElementDataType(*mul_node.InputDefs()[0], ONNX_NAMESPACE::TensorProto_DataType_FLOAT16))) { continue; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index e1e7eacfbd85d..7c25755a7d09e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -879,6 +879,12 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1041,12 +1047,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; -}; - template <> struct OperatorDescTraits { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 5fe6603c2a0bf..da57c2aa235fd 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -1885,6 +1885,25 @@ constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHE DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + static_cast(DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT), + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA_FIELDS[11] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, @@ -2395,24 +2414,6 @@ constexpr DML_OPERATOR_SCHEMA 
DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHE DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { - "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", - DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 8, - DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, -}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 4be41ad3924a2..86c66d8cca26c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -1139,6 +1139,19 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_MATRIX_MU OperatorField(&DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + 
OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_CONVOLUTION_INTEGER_OPERATOR_DESC& desc) { return { @@ -1487,19 +1500,6 @@ inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_P OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } -inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), - OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), - }; -} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1829,6 +1829,7 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE1: return DML_RESAMPLE1_OPERATOR_SCHEMA; case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER: return DML_MATRIX_MULTIPLY_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY: return DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_CONVOLUTION_INTEGER: return DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION: return DML_QUANTIZED_LINEAR_CONVOLUTION_OPERATOR_SCHEMA; case DML_OPERATOR_ELEMENT_WISE_BIT_AND: return DML_ELEMENT_WISE_BIT_AND_OPERATOR_SCHEMA; @@ -1856,7 +1857,6 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; - case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -2360,6 +2360,10 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_MATRIX_MULTIPLY_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return 
AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_CONVOLUTION_INTEGER: return AbstractOperatorDesc( &DML_CONVOLUTION_INTEGER_OPERATOR_SCHEMA, @@ -2468,10 +2472,6 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); - case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: - return AbstractOperatorDesc( - &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp new file mode 100644 index 0000000000000..b5a3dd0960b86 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMatMulIntegerToFloat.cpp @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ + +class DmlOperatorMatMulIntegerToFloat : public DmlOperator +{ + enum OrtInputTensors : uint32_t + { + ortA, + ortB, + ortAScale, + ortBScale, + ortAZeroPoint, + ortBZeroPoint, + ortBias, + ortInputCount + }; + + enum DmlInputIndex : uint32_t + { + dmlA, + dmlAScale, + dmlAZeroPoint, + dmlB, + dmlBScale, + dmlBZeroPoint, + dmlBias, + dmlInputCount, + }; + +public: + DmlOperatorMatMulIntegerToFloat(const MLOperatorKernelCreationContext& kernelInfo) + : DmlOperator(kernelInfo) + { + std::vector> inputIndices = { OrtInputTensors::ortA, OrtInputTensors::ortAScale, OrtInputTensors::ortAZeroPoint, OrtInputTensors::ortB, OrtInputTensors::ortBScale, OrtInputTensors::ortBZeroPoint, OrtInputTensors::ortBias }; + DmlOperator::Initialize(kernelInfo, inputIndices); + + std::vector inputShape0 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortA); + std::vector inputShape1 = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortB); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); + + OperatorHelper::MatMulShapeMapping(inputShape0, inputShape1, outputShape); + + // Initialize the input descriptions with broadcasting + m_inputTensorDescs[DmlInputIndex::dmlA] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortA, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape0); + m_inputTensorDescs[DmlInputIndex::dmlB] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortB, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, inputShape1); + + // Broadcast Bias tensor to the shape of the output tensor. + if(kernelInfo.IsInputValid(OrtInputTensors::ortBias)) { + m_inputTensorDescs[DmlInputIndex::dmlBias] = CreateTensorDescFromInput(kernelInfo, OrtInputTensors::ortBias, TensorAxis::DoNotCoerce, + TensorAxis::W, TensorAxis::RightAligned, outputShape); + } + + uint32_t dmlDimSize = m_inputTensorDescs[DmlInputIndex::dmlA].GetDimensionCount(); + // Resize the A Scale to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. 
+ m_inputTensorDescs[DmlInputIndex::dmlAScale] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAScale, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + + // Resize the A ZeroPoint to be the same dimension as the input tensor. + // The 1D tensor needs to be moved to the H channel. + if (kernelInfo.IsInputValid(OrtInputTensors::ortAZeroPoint)) + { + m_inputTensorDescs[DmlInputIndex::dmlAZeroPoint] = CreateTensorDescFromInput( + kernelInfo, + OrtInputTensors::ortAZeroPoint, + TensorAxis::DoNotCoerce, + TensorAxis::H, + TensorAxis::LeftAligned, + std::nullopt, + dmlDimSize + ); + } + + // B Zeropoint and BScale are already aligned in the W dimension so no need to align them + + // Initialize the output description while overriding the shape + m_outputTensorDescs[0] = CreateTensorDescFromOutput(kernelInfo, 0, TensorAxis::DoNotCoerce, TensorAxis::W, TensorAxis::RightAligned, outputShape); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC matMulDesc = {}; + matMulDesc.ATensor = &inputDescs[DmlInputIndex::dmlA]; + matMulDesc.AScaleTensor = &inputDescs[DmlInputIndex::dmlAScale]; + matMulDesc.AZeroPointTensor = inputDescs[DmlInputIndex::dmlAZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlAZeroPoint] : nullptr; + matMulDesc.BTensor = &inputDescs[DmlInputIndex::dmlB]; + matMulDesc.BScaleTensor = &inputDescs[DmlInputIndex::dmlBScale]; + matMulDesc.BZeroPointTensor = inputDescs[DmlInputIndex::dmlBZeroPoint].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBZeroPoint] : nullptr; + matMulDesc.BiasTensor = inputDescs[DmlInputIndex::dmlBias].Desc != nullptr ? &inputDescs[DmlInputIndex::dmlBias] : nullptr; + matMulDesc.OutputTensor = &outputDescs[0]; + + DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, &matMulDesc }; + SetDmlOperatorDesc(opDesc, kernelInfo); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(MatMulIntegerToFloat, DmlOperatorMatMulIntegerToFloat); + +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 9c136ed8c9484..f08151b61197a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -503,6 +503,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul); DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat); DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear); DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger); +DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat); DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger); DML_OP_EXTERN_CREATION_FUNCTION(Trilu); @@ -622,6 +623,13 @@ constexpr static std::array supportedTypeListQLinea SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 }; + +constexpr static std::array supportedTypeListMatMulIntegerToFloat = { + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Ints8Bit, + SupportedTensorDataTypes::Float16to32 +}; + constexpr static std::array supportedTypeListQLinearConv = { SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, @@ -1083,6 +1091,7 @@ constexpr static 
OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)}, {REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)}, {REG_INFO( 10, MatMulInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, + {REG_INFO_MS( 1, MatMulIntegerToFloat, typeNameListThree, supportedTypeListMatMulIntegerToFloat, DmlGraphSupport::Supported)}, {REG_INFO( 10, ConvInteger, typeNameListThree, supportedTypeListInteger, DmlGraphSupport::Supported)}, {REG_INFO( 11, DynamicQuantizeLinear, typeNameListTwo, supportedTypeListDynamicQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 7, LayerNormalization, typeNameListLayerNormContrib, supportedTypeListLayerNormalizationContrib, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryLayerNormalization)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 1b2521a86613f..06bacc1b28c99 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -870,7 +870,6 @@ class QLinearMatMulHelper : public MatMulHelperBase QLinearMatMulHelper(const Info_t& info, const Shape_t& shape) : MatMulHelperBase(info, shape, 0, 3) {} }; - class TopKHelper { void Initialize( @@ -1776,6 +1775,7 @@ using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; +using ShapeInferenceHelper_MatMulIntegerToFloat = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; using ShapeInferenceHelper_QLinearAdd = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_DynamicQuantizeLinear = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index e725ba085113d..d081aa2e29148 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -449,6 +449,7 @@ namespace OperatorHelper static const int sc_sinceVer_FusedMatMulActivation = 1; static const int sc_sinceVer_QLinearSigmoid = 1; static const int sc_sinceVer_Attention = 1; + static const int sc_sinceVer_MatMulIntegerToFloat = 1; static const int sc_sinceVer_MultiHeadAttention = 1; static const int sc_sinceVer_SkipLayerNormalization = 1; static const int sc_sinceVer_EmbedLayerNormalization = 1; diff --git a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc index 26ce5272d25ee..6f3ca7e239671 100644 --- a/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_integer_to_float_test.cc @@ -23,135 +23,407 @@ using namespace std; namespace onnxruntime { namespace test { -template -void TestMatMulIntegerToFloat(const std::vector& A_dims, - std::vector B_dims, - const std::string& reference_model, - bool is_matrix_b_constant, +template +static void CalculateMatMulIntegerToFloat(const int64_t M, const int64_t N, const int64_t K, + const std::vector& A_data, const std::vector& 
A_scale, + const std::vector& A_zero_point, const std::vector& B_data, + std::vector& B_scale, std::vector& B_zero_point, + const std::vector& Bias, std::vector& Y_data, + bool per_column, bool has_zp, bool has_bias) { + if (!per_column) { + B_zero_point.resize(N, B_zero_point[0]); + B_scale.resize(N, B_scale[0]); + } + + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + float A_dequantized = has_zp ? (static_cast(A_data[m * K + k]) - static_cast(A_zero_point[0])) * A_scale[0] : A_data[m * K + k] * A_scale[0]; + float B_dequantized = has_zp ? (static_cast(B_data[k * N + n]) - static_cast(B_zero_point[n])) * B_scale[n] : B_data[k * N + n] * B_scale[n]; + + sum += A_dequantized * B_dequantized; + } + if (has_bias) { + sum += Bias[n]; + } + Y_data[m * N + n] = static_cast(sum); + } + } +} + +template +void TestMatMulIntegerToFloat(bool is_matrix_b_constant, bool per_column = false, bool has_zp = true, bool has_bias = false) { // create rand inputs RandomValueGenerator random{}; - + int64_t M = 4; + int64_t N = 128; + int64_t K = 128; + std::vector A_dims{M, K}; + std::vector B_dims{K, N}; + std::vector Y_dims{M, K}; std::vector A_data; - std::vector tmp_A_data = random.Uniform(A_dims, - std::numeric_limits::lowest(), - std::numeric_limits::max()); - std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> WType { + std::vector tmp_A_data = random.Uniform(A_dims, + std::numeric_limits::lowest(), + std::numeric_limits::max()); + std::transform(tmp_A_data.begin(), tmp_A_data.end(), std::back_inserter(A_data), [](int32_t v) -> IType { return static_cast(v); }); std::vector B_data; - std::vector tmp_B_data = random.Uniform(B_dims, - std::numeric_limits::lowest(), - std::numeric_limits::max()); + + std::vector tmp_B_data; + tmp_B_data = random.Uniform(B_dims, + std::is_signed::value ? std::numeric_limits::lowest() / 2 : std::numeric_limits::lowest(), + std::numeric_limits::max() / 2); std::transform(tmp_B_data.begin(), tmp_B_data.end(), std::back_inserter(B_data), [](int32_t v) -> WType { return static_cast(v); }); - std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f); + std::vector A_scale = random.Uniform(AsSpan({1}), -0.1f, 0.1f); std::vector A_zero_point{(std::numeric_limits::lowest() + std::numeric_limits::max() + IType(2)) / 2}; int64_t b_scale_zp_size = per_column ? 
B_dims.back() : 1; - std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); + std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); std::for_each(B_zero_point.begin(), B_zero_point.end(), [&random](WType& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::lowest(), - std::numeric_limits::max())[0]); + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::lowest(), + std::numeric_limits::max())[0]); }); - std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); + std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("MatMulIntegerToFloat", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("a_scale", {1}, A_scale); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale); if (has_zp) { test.AddInput("a_zero_point", {1}, A_zero_point); test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); test.AddOptionalInputEdge(); } if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + test.AddInput("bias", {B_dims.back()}, Bias); } else { - test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); } - test.AddReferenceOutputs(reference_model); - test.SetOutputRelErr("Y", 1e-4f); - test.Run(); -} + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + per_column, has_zp, has_bias); -template -void RunMatMulIntegerToFloatTest(const string& model_path) { - std::vector A_dims{4, 128}; - std::vector B_dims{128, 128}; - std::vector Y_dims{4, 128}; + if (std::is_same_v) { + test.AddOutput("Y", {M, N}, Y_data); + test.SetOutputRelErr("Y", 0.02f); + } else { + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + test.SetOutputAbsErr("Y", 0.5f); + } - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + // Only DML EP supports these data type combinations for now + if (std::is_same_v || + (std::is_same_v && + std::is_same_v && + std::is_same_v)) { + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.Run(); + } +} + +template +void RunMatMulIntegerToFloatTest() { + TestMatMulIntegerToFloat( + false, /*is_matrix_b_constant*/ + false, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + true, /*is_matrix_b_constant*/ + false, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + false, /*is_matrix_b_constant*/ + true, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); - TestMatMulIntegerToFloat(A_dims, - B_dims, - model_path, - true, /*is_matrix_b_constant*/ - true, /*per_column*/ - 
HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ + TestMatMulIntegerToFloat( + true, /*is_matrix_b_constant*/ + true, /*per_column*/ + HasZeroPoint, /*has_zp*/ + HasBias /*has_bias*/ ); } -TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8X8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8.onnx"); - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8.onnx"); +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8X8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_bias.onnx"); - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_uint8_bias.onnx"); +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8S8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8.onnx"); +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); } -TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8S8) { - RunMatMulIntegerToFloatTest("testdata/matmul_integer_to_float_int8_int8_bias.onnx"); +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8X8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_U8S8) { + RunMatMulIntegerToFloatTest(); +} + +// DML EP supports Float16 output type and Signed A Matrix and Unsigned B Matric for Float32 output +#if defined(USE_DML) + +TEST(MatMulIntegerToFloat, HasZeroPoint_NoBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_HasBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, NoZeroPoint_NoBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, HasZeroPoint_HasBias_test_S8U8) { + RunMatMulIntegerToFloatTest(); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8U8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {1, 5, 2, 1, 9, + 1, 1, 3, 7, 2}; + std::vector B_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector A_scale = ToFloat16({3.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {1}; + std::vector B_zero_point = {1}; + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, {}, Y_data, + false, true, false); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + 
std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_U8S8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector B_data = {2, -1, -9, 1, 1, + -1, 0, -3, 1, -4}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {1}; + std::vector B_zero_point = {3}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, {}, Y_data, + false, true, false); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8S8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, -2, 1, 1, + 2, -1, -9, 1, 1}; + std::vector B_data = {2, -1, -9, 1, 1, + -1, 0, -3, 1, -4}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {-1}; + std::vector B_zero_point = {3}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + false, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16_S8U8) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 5; + int64_t N = 5; + int64_t K = 2; + + std::vector A_data = {3, 7, -2, 1, 1, + 2, -1, -9, 1, 1}; + std::vector B_data = {3, 7, 2, 1, 1, + 2, 1, 9, 1, 1}; + std::vector A_scale = ToFloat16({-4.0f}); + std::vector B_scale = ToFloat16({2.0f}); + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + std::vector A_zero_point = {-1}; + std::vector B_zero_point = {1}; + std::vector Bias = ToFloat16({11.0f, -17.0f, 1.0f, -3.0f, 12.0f}); + + test.AddInput("a_scale", {1}, A_scale); + test.AddInput("b_scale", {1}, B_scale); + test.AddInput("a_zero_point", {1}, A_zero_point); + test.AddInput("b_zero_point", {1}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + 
B_data, B_scale, B_zero_point, Bias, Y_data, + false, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +TEST(MatMulIntegerToFloat, MatMulIntegerToFloat_FP16) { + OpTester test("MatMulIntegerToFloat", 1, kMSDomain); + int64_t M = 2; + int64_t N = 2; + int64_t K = 3; + + std::vector A_data = {11, -2, 5, + -1, 3, 10}; + std::vector B_data = {-13, -2, + 9, 55, + -1, 23}; + std::vector A_scale = ToFloat16({0.910f}); + std::vector B_scale = ToFloat16({1.10f, 1.123f}); + + std::vector A_zero_point = {113}; + std::vector B_zero_point = {98, 71}; + + std::vector Bias = ToFloat16({0.10f, 1.123f}); + + test.AddInput("A", {M, K}, A_data); + test.AddInput("B", {K, N}, B_data); + + test.AddInput("a_scale", {}, {A_scale}); + test.AddInput("b_scale", {N}, B_scale); + test.AddInput("a_zero_point", {}, {A_zero_point}); + test.AddInput("b_zero_point", {N}, B_zero_point); + test.AddInput("bias", {N}, Bias); + + std::vector Y_data(M * N); + CalculateMatMulIntegerToFloat(M, N, K, A_data, A_scale, A_zero_point, + B_data, B_scale, B_zero_point, Bias, Y_data, + true, true, true); + + test.AddOutput("Y", {M, N}, ToFloat16(Y_data)); + test.SetOutputRelErr("Y", 2e-2f); + std::vector> execution_providers; + execution_providers.push_back(DefaultDmlExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +#endif TEST(MatMulIntegerToFloat, MatMulInteger_With_ZeroPoint) { auto test_case = [&](const std::vector& input_shape, diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 16f38bac62713..1535e2b60a3bd 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -5679,6 +5679,24 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) { EXPECT_EQ(op_to_count["Add"], 1); } +#ifdef USE_DML +TEST_F(GraphTransformationTests, MatMulIntegerToFloat16Test) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/matmul_integer_to_float16_int8.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + for (auto& node : graph.Nodes()) { + node.SetExecutionProviderType(kDmlExecutionProvider); + } + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); +} +#endif // USE_DML + #endif #ifndef DISABLE_CONTRIB_OPS diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index b898390044cf4..e6c51009018f9 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -4,7 +4,7 @@ from onnx import TensorProto, helper -def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: N802 +def GenerateModel(model_name, sign_i, sign_w, output_type_fp16, has_zp=True, bias=False): # noqa: N802 nodes = [ # subgraph helper.make_node( "MatMulInteger", @@ -13,7 +13,13 @@ def 
GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: "MatMulInteger", ), helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"), - helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1), + helper.make_node( + "Cast", + ["matmul_output_int32"], + ["matmul_output_float"], + "cast", + to=TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, + ), helper.make_node( "Mul", ["matmul_output_float", "multiplier"], @@ -25,8 +31,8 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: inputs = [ # inputs helper.make_tensor_value_info("A", TensorProto.INT8 if sign_i else TensorProto.UINT8, ["M", "K"]), helper.make_tensor_value_info("B", TensorProto.INT8 if sign_w else TensorProto.UINT8, ["K", "N"]), - helper.make_tensor_value_info("a_scale", TensorProto.FLOAT, [1]), - helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("a_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("b_scale", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["C"]), ] if has_zp: @@ -48,14 +54,22 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: if bias: nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend([helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["N"])]) + inputs.extend( + [ + helper.make_tensor_value_info( + "bias", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["N"] + ) + ] + ) graph = helper.make_graph( nodes, "DynamicQuantizeMatMul_fusion", # name inputs, [ # outputs - helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info( + "Y", TensorProto.FLOAT16 if output_type_fp16 else TensorProto.FLOAT, ["M", "N"] + ), ], ) @@ -64,10 +78,32 @@ def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): # noqa: if __name__ == "__main__": - GenerateModel("matmul_integer_to_float_int8.onnx", False, True) - GenerateModel("matmul_integer_to_float_uint8.onnx", False, False) - GenerateModel("matmul_integer_to_float_int8_bias.onnx", False, True, False, True) - GenerateModel("matmul_integer_to_float_uint8_bias.onnx", False, False, False, True) + GenerateModel("matmul_integer_to_float16_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=True) + GenerateModel("matmul_integer_to_float_int8.onnx", sign_i=False, sign_w=True, output_type_fp16=False) + GenerateModel("matmul_integer_to_float_uint8.onnx", sign_i=False, sign_w=False, output_type_fp16=False) + GenerateModel( + "matmul_integer_to_float_int8_bias.onnx", + sign_i=False, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + ) + GenerateModel( + "matmul_integer_to_float_uint8_bias.onnx", + sign_i=False, + sign_w=False, + output_type_fp16=False, + has_zp=False, + bias=True, + ) - GenerateModel("matmul_integer_to_float_int8_int8.onnx", True, True) - GenerateModel("matmul_integer_to_float_int8_int8_bias.onnx", True, True, False, True) + GenerateModel("matmul_integer_to_float_int8_int8.onnx", sign_i=True, sign_w=True, output_type_fp16=False) + GenerateModel( + "matmul_integer_to_float_int8_int8_bias.onnx", + sign_i=True, + sign_w=True, + output_type_fp16=False, + has_zp=False, + bias=True, + ) diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx index 9f4465a914963..906dec542a4fa 100644 --- 
a/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx index 01b7e15aa4a1f..16cdf03c7ae59 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx index 9d38828e25d6a..55102757a0b57 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx index 4d9a55af50a87..d9d7222a1acaa 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_int8_int8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx index a4c6d20d59be8..5373ce145688e 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8.onnx @@ -1,4 +1,4 @@ -:Ì + :Ì U A B @@ -44,4 +44,4 @@ mul_bottom"MulDynamicQuantizeMatMul_fusionZ  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx index a5be0c63f4dcb..e407414b23b24 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx +++ b/onnxruntime/test/testdata/matmul_integer_to_float_uint8_bias.onnx @@ -1,4 +1,4 @@ -:Ä + :Ä 9 A Bmatmul_output_int32 MatMulInteger" MatMulInteger @@ -41,4 +41,4 @@ mul_bottom"Mul  M -NB \ No newline at end of file +NB \ No newline at end of file diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.onnx index 7ea69c580ee435be09f12b949f14fdb2efe3d403..aa8e67bcbc59e53d3418000c23ef35c75dfd76c6 100644 GIT binary patch delta 13 Ucmeys{ehc_gL5O(TUJJ403a9x!vFvP delta 13 Ucmeys{ehc_gMA~@TUJIM03ZVcx&QzG diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx new file mode 100644 index 0000000000000..22293b0d10756 --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float16_int8.onnx @@ -0,0 +1,51 @@ + :Ì +U +A +B + a_zero_point + 
b_zero_pointmatmul_output_int32 MatMulInteger" MatMulInteger +. +a_scale +b_scale +multiplier mul_right"Mul +A +matmul_output_int32matmul_output_floatcast"Cast* +to +  +5 +matmul_output_float + +multiplierY +mul_bottom"MulDynamicQuantizeMatMul_fusionZ +A + + +M +KZ +B + + +K +NZ +a_scale + + + +Z +b_scale +  + +CZ + a_zero_point + + +Z + b_zero_point +  +Cb +Y + + + +M +NB \ No newline at end of file From 0cdf36faeb4eafcf543bd84dd6f543a55df738c1 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Mon, 4 Mar 2024 13:46:51 -0800 Subject: [PATCH 190/207] Expose SessionOtions.DisablePerSessionThreads (#19730) ### Description ### Motivation and Context ML.NET needs to run mltiple sessions on a single threadpool. --- .../src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs | 5 +++++ .../Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs | 9 +++++++++ .../InferenceTest.cs | 5 ++++- 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 4128524b30483..8a8426a0b3054 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -362,6 +362,7 @@ static NativeMethods() OrtDisableMemPattern = (DOrtDisableMemPattern)Marshal.GetDelegateForFunctionPointer(api_.DisableMemPattern, typeof(DOrtDisableMemPattern)); OrtEnableCpuMemArena = (DOrtEnableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.EnableCpuMemArena, typeof(DOrtEnableCpuMemArena)); OrtDisableCpuMemArena = (DOrtDisableCpuMemArena)Marshal.GetDelegateForFunctionPointer(api_.DisableCpuMemArena, typeof(DOrtDisableCpuMemArena)); + OrtDisablePerSessionThreads = (DOrtDisablePerSessionThreads)Marshal.GetDelegateForFunctionPointer(api_.DisablePerSessionThreads, typeof(DOrtDisablePerSessionThreads)); OrtSetSessionLogId = (DOrtSetSessionLogId)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogId, typeof(DOrtSetSessionLogId)); OrtSetSessionLogVerbosityLevel = (DOrtSetSessionLogVerbosityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogVerbosityLevel, typeof(DOrtSetSessionLogVerbosityLevel)); OrtSetSessionLogSeverityLevel = (DOrtSetSessionLogSeverityLevel)Marshal.GetDelegateForFunctionPointer(api_.SetSessionLogSeverityLevel, typeof(DOrtSetSessionLogSeverityLevel)); @@ -992,6 +993,10 @@ IntPtr[] outputValues /* An array of output value pointers. 
Array must be alloca public delegate IntPtr /*(OrtStatus*)*/ DOrtDisableCpuMemArena(IntPtr /* OrtSessionOptions* */ options); public static DOrtDisableCpuMemArena OrtDisableCpuMemArena; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtDisablePerSessionThreads(IntPtr /* OrtSessionOptions* */ options); + public static DOrtDisablePerSessionThreads OrtDisablePerSessionThreads; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /*(OrtStatus*)*/ DOrtSetSessionLogId(IntPtr /* OrtSessionOptions* */ options, byte[] /* const char* */ logId); public static DOrtSetSessionLogId OrtSetSessionLogId; diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index 7a68246c9b67a..30d005b3c4236 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -696,6 +696,15 @@ public bool EnableCpuMemArena } private bool _enableCpuMemArena = true; + /// + /// Disables the per session threads. Default is true. + /// This makes all sessions in the process use a global TP. + /// + public void DisablePerSessionThreads() + { + NativeApiStatus.VerifySuccess(NativeMethods.OrtDisablePerSessionThreads(handle)); + } + /// /// Log Id to be used for the session. Default is empty string. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs index fd8feda359f90..d6a6b9627f418 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/InferenceTest.cs @@ -55,6 +55,9 @@ public void TestSessionOptions() Assert.Equal(0, opt.InterOpNumThreads); Assert.Equal(GraphOptimizationLevel.ORT_ENABLE_ALL, opt.GraphOptimizationLevel); + // No get, so no verify + opt.DisablePerSessionThreads(); + // try setting options opt.ExecutionMode = ExecutionMode.ORT_PARALLEL; Assert.Equal(ExecutionMode.ORT_PARALLEL, opt.ExecutionMode); @@ -98,7 +101,7 @@ public void TestSessionOptions() Assert.Contains("[ErrorCode:InvalidArgument] Config key is empty", ex.Message); // SessionOptions.RegisterOrtExtensions can be manually tested by referencing the - // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. + // Microsoft.ML.OnnxRuntime.Extensions nuget package. After that is done, this should not throw. 
ex = Assert.Throws(() => { opt.RegisterOrtExtensions(); }); Assert.Contains("Microsoft.ML.OnnxRuntime.Extensions NuGet package must be referenced", ex.Message); From 2a5c9b86ebbdba8fb76f79de26524a2fdd2e5c2a Mon Sep 17 00:00:00 2001 From: zhijiang <43435212+zhijxu-MS@users.noreply.github.com> Date: Tue, 5 Mar 2024 10:11:19 +0800 Subject: [PATCH 191/207] Zhijxu/fix conv1d replacement (#19758) remove the constraint - "group number should be less than 3"; add more condition to make sure the conv1d replacement only happens on conv1d instead of conv2d/conv3d; add more tests; --- .../core/optimizer/conv1d_replacement.cc | 63 +++++++++++------- .../test/optimizer/graph_transform_test.cc | 64 ++++++++++++++++--- 2 files changed, 96 insertions(+), 31 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc index 0412000e04e1b..ff220fcb067b8 100644 --- a/orttraining/orttraining/core/optimizer/conv1d_replacement.cc +++ b/orttraining/orttraining/core/optimizer/conv1d_replacement.cc @@ -42,30 +42,45 @@ */ namespace onnxruntime { bool NodeCanBeReplacedByMatmul(const Node& node) { - // If node type is Conv, and attr "dilations" is 1, "kernel_shape" is 1, "stride" is 1, group is 1 or 2, - // then it can be replaced by MatMul - // Kernel_shape is 1 means it is conv1d + /* + If node type is Conv, and satisfy the following conditions then it can be replaced by MatMul: + - not bias as input which means only has 2 inputs: input and weight + - "dilations" should be [1] + size 1 means conv1d + - "strides" should be [1] + - "pads" should be [0,0] + - "autopad" should be "NOTSET" + - "kernel_shape" should be [1] + */ if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Conv", {1, 11})) { return false; } - const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); - const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); - const auto* stride = graph_utils::GetNodeAttribute(node, "strides"); - const auto* group = graph_utils::GetNodeAttribute(node, "group"); - if (dilations == nullptr || kernel_shape == nullptr || stride == nullptr || group == nullptr) { + + // TODO: bias input can also be supported if needed + if (node.InputDefs().size() != 2) { return false; } - if ((dilations->ints_size() && dilations->ints(0) != 1) || - (kernel_shape->ints_size() && kernel_shape->ints(0) != 1) || - (stride->ints_size() && stride->ints(0) != 1) || - group->i() >= 3) { + + const auto* dilations = graph_utils::GetNodeAttribute(node, "dilations"); + const auto* strides = graph_utils::GetNodeAttribute(node, "strides"); + const auto* pads = graph_utils::GetNodeAttribute(node, "pads"); + const auto* autopad = graph_utils::GetNodeAttribute(node, "auto_pad"); + const auto* kernel_shape = graph_utils::GetNodeAttribute(node, "kernel_shape"); + if (dilations == nullptr || strides == nullptr || pads == nullptr || autopad == nullptr || kernel_shape == nullptr) { return false; } - return true; + if ((dilations->ints_size() == 1 && dilations->ints(0) == 1) && + (strides->ints_size() == 1 && strides->ints(0) == 1) && + (autopad->s() == "NOTSET") && + (pads->ints_size() == 2 && pads->ints(0) == 0 && pads->ints(1) == 0) && + (kernel_shape->ints_size() == 1 && kernel_shape->ints(0) == 1)) { + return true; + } + return false; } -void Conv1dToMatmul(Graph& graph, Node& conv) { +void Conv1dToMatmul(Graph& graph, Node& conv, const std::string transformer_name) { // Shape of conv1d input: [batch_size, in_channels, 
in_length] // Shape of conv1d weight:[output_channels, input_channels/group, kernel_shape], kernel_shape is 1 // We need to split the input into "group", and squeeze&split the weight, and then do MatMul @@ -83,7 +98,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_input_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("input_split_output"), nullptr)); } - auto& input_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, {conv1d_input}, + auto& input_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {conv1d_input}, {conv1d_input_splitted_outputs}); input_split.SetExecutionProviderType(execution_provider_type); input_split.AddAttribute("axis", int64_t(1)); @@ -93,23 +108,25 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { } // 2. Squeeze conv weight auto conv1d_weight = conv.MutableInputDefs()[1]; + // auto con1d_bias = xx; auto weight_squeeze_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("weight_squeeze_output"), nullptr); - auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName("WeightSqueeze"), "Squeeze", + auto& weight_squeeze = graph.AddNode(graph.GenerateNodeName(transformer_name + "WeightSqueeze"), "Squeeze", node_description, {conv1d_weight}, {weight_squeeze_output}); + int64_t weight_squeeze_axis = 2; if (onnx_opset_version > 12) { // After onnx version 12, squeeze node has axes as input instead of attribute ONNX_NAMESPACE::TensorProto initializer_proto; - initializer_proto.set_name(graph.GenerateNodeName("ConstAsInitializer")); + initializer_proto.set_name(graph.GenerateNodeName(transformer_name + "ConstAsInitializer")); initializer_proto.add_dims(static_cast(1)); initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); - InlinedVector initializer_proto_value{2}; + InlinedVector initializer_proto_value{weight_squeeze_axis}; initializer_proto.set_raw_data(initializer_proto_value.data(), initializer_proto_value.size() * sizeof(int64_t)); auto& axes_input = graph_utils::AddInitializer(graph, initializer_proto); // Squeeze node doesn't have opschema here, so we need to set input args count manually weight_squeeze.MutableInputArgsCount().resize(2); graph_utils::AddNodeInput(weight_squeeze, 1, axes_input); } else { - weight_squeeze.AddAttribute("axes", std::vector{2}); + weight_squeeze.AddAttribute("axes", std::vector{weight_squeeze_axis}); } weight_squeeze.SetExecutionProviderType(execution_provider_type); // 3. 
Split conv weight @@ -118,7 +135,7 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { conv1d_weight_splitted_outputs.push_back(&graph.GetOrCreateNodeArg( graph.GenerateNodeArgName("weight_split_output"), nullptr)); } - auto& weight_split = graph.AddNode(graph.GenerateNodeName("Split"), "Split", node_description, + auto& weight_split = graph.AddNode(graph.GenerateNodeName(transformer_name + "Split"), "Split", node_description, {weight_squeeze_output}, {conv1d_weight_splitted_outputs}); weight_split.AddAttribute("axis", int64_t(0)); weight_split.SetExecutionProviderType(execution_provider_type); @@ -130,13 +147,13 @@ void Conv1dToMatmul(Graph& graph, Node& conv) { for (int i = 0; i < group_num; i++) { auto matmul_output = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName("matmul_output"), nullptr); matmul_outputs.push_back(matmul_output); - auto& matmul = graph.AddNode(graph.GenerateNodeName("Matmul"), "MatMul", node_description, + auto& matmul = graph.AddNode(graph.GenerateNodeName(transformer_name + "Matmul"), "MatMul", node_description, {conv1d_weight_splitted_outputs[i], conv1d_input_splitted_outputs[i]}, {matmul_output}); matmul.SetExecutionProviderType(execution_provider_type); } // 5. Concat matmul outputs - auto& concat_node = graph.AddNode(graph.GenerateNodeName("Concat"), "Concat", node_description, + auto& concat_node = graph.AddNode(graph.GenerateNodeName(transformer_name + "Concat"), "Concat", node_description, matmul_outputs, {}); concat_node.SetExecutionProviderType(execution_provider_type); concat_node.AddAttribute("axis", int64_t(1)); @@ -155,7 +172,7 @@ Status Conv1dReplacement::ApplyImpl(Graph& graph, bool& modified, int graph_leve ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); if (NodeCanBeReplacedByMatmul(node)) { LOGS(logger, VERBOSE) << "lora conv1d replacement, node name: " + node.Name(); - Conv1dToMatmul(graph, node); + Conv1dToMatmul(graph, node, Name()); modified = true; } } diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index bab7c09839273..109937ff96d1d 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1200,7 +1200,7 @@ TEST_P(QDQFusionTestsParameterized, CheckModelComposition) { ASSERT_EQ(op_to_count_post_fusion["com.microsoft.FakeQuant"], 1); } -TEST_F(GraphTransformationTests, Conv1dReplacement) { +TEST_F(GraphTransformationTests, Conv1dReplacement_TakeEffect) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); @@ -1208,7 +1208,7 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { }; for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { - for (auto group : {1, 2}) { + for (auto group : {1, 2, 4}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; @@ -1222,6 +1222,8 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(group)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; auto post_graph_checker = [&](Graph& graph) { @@ -1243,28 +1245,64 @@ TEST_F(GraphTransformationTests, Conv1dReplacement) { } } -TEST_F(GraphTransformationTests, 
Conv1dReplacement_NoTakeEffect) { +// node has bias input so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect1) { auto pre_graph_checker = [&](Graph& graph) { auto op_count_map = CountOpsInGraph(graph); TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); return Status::OK(); }; - // "group" is 3 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); auto out_channel = 64; auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); - auto* weight_arg = builder.MakeInitializer({out_channel, in_channel / 3, 1}, {-1.0f, 1.0f}); + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); + auto* bias_arg = builder.MakeInitializer({out_channel}, {-1.0f, 1.0f}); + auto* conv_output = builder.MakeOutput(); + + auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg, bias_arg}, {conv_output}); + conv_node.AddAttribute("dilations", std::vector{1}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); + conv_node.AddAttribute("strides", std::vector{1}); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, + pre_graph_checker, pre_graph_checker)); + } +} + +// "auto_pad " is not NOTSET so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect2) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; + + for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); + auto out_channel = 64; + auto* data_arg = builder.MakeInput({{batch_size, in_channel, in_length}}); + + auto* weight_arg = builder.MakeInitializer({out_channel, in_channel, 1}, {-1.0f, 1.0f}); auto* conv_output = builder.MakeOutput(); auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); - conv_node.AddAttribute("group", static_cast(3)); + conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{0, 0}); + conv_node.AddAttribute("auto_pad", "VALID"); }; std::unique_ptr transformer = std::make_unique(); @@ -1272,8 +1310,16 @@ TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect) { TransformerLevel::Level1, 1, pre_graph_checker, pre_graph_checker)); } +} + +// pads is not all zero, so conv not replaced +TEST_F(GraphTransformationTests, Conv1dReplacement_NoTakeEffect3) { + auto pre_graph_checker = [&](Graph& graph) { + auto op_count_map = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_map["Conv"] == 1); + return Status::OK(); + }; - // "kernel_shape" is not 1 so conv not replaced for (auto opset : {11, 12, 13, 14, 15, 16, 17, 18}) { auto build_test_case = [&](ModelTestBuilder& builder) { auto [batch_size, in_channel, in_length] = std::make_tuple(8, 16, 128); @@ -1285,9 +1331,11 @@ TEST_F(GraphTransformationTests, 
Conv1dReplacement_NoTakeEffect) { auto& conv_node = builder.AddNode("Conv", {data_arg, weight_arg}, {conv_output}); conv_node.AddAttribute("dilations", std::vector{1}); - conv_node.AddAttribute("kernel_shape", std::vector{2}); + conv_node.AddAttribute("kernel_shape", std::vector{1}); conv_node.AddAttribute("strides", std::vector{1}); conv_node.AddAttribute("group", static_cast(1)); + conv_node.AddAttribute("pads", std::vector{1, 0}); + conv_node.AddAttribute("auto_pad", "NOTSET"); }; std::unique_ptr transformer = std::make_unique(); From 7e613ee821405b1192d0b71b9434a4f94643f1e4 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Tue, 5 Mar 2024 11:45:45 +0800 Subject: [PATCH 192/207] [quant] supports act_order inputs in Matmulnbits and new quantization algorithm "hqq" (#19106) ### Description 1. Support quantized GPTQ weight in huggingface like [TheBloke/Llama-2-7B-Chat-GPTQ](https://huggingface.co/TheBloke/Llama-2-7B-Chat-GPTQ) 2. Support Act_order for GPTQ 3. Support [HQQ](https://mobiusml.github.io/hqq_blog/) algorithm to quantize matmul weight and add quant script ### Motivation and Context --- docs/ContribOperators.md | 43 +- docs/OperatorKernels.md | 4 +- .../cpu/quantization/matmul_nbits.cc | 105 ++++- .../cpu/quantization/matmul_nbits_impl.cc | 108 +++++ .../cpu/quantization/matmul_nbits_impl.h | 23 ++ .../cuda/quantization/dequantize_blockwise.cu | 159 ++++++-- .../quantization/dequantize_blockwise.cuh | 6 +- .../cuda/quantization/matmul_nbits.cc | 170 ++++---- .../cuda/quantization/matmul_nbits.h | 41 ++ .../core/graph/contrib_ops/contrib_defs.cc | 38 +- .../quantization/matmul_4bits_quantizer.py | 379 ++++++++++++++++-- .../test/contrib_ops/matmul_4bits_test.cc | 78 +++- .../test/python/quantization/op_test_utils.py | 3 +- .../quantization/test_op_matmul_4bits.py | 19 +- 14 files changed, 942 insertions(+), 234 deletions(-) create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc create mode 100644 onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h create mode 100644 onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e295dfa203ae5..5f0100fad95a2 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2808,22 +2808,23 @@ This version of the operator has been available since version 1 of the 'com.micr And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. - Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: - - n_blocks_per_col = (K + block_size - 1) / block_size - - blob_size = block_size / 8 * bits + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. - For a block blob. 
It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] - Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 - + Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. #### Version @@ -2844,17 +2845,19 @@ This version of the operator has been available since version 1 of the 'com.micr
number of groupsize used for weight quantization,(default 128). It needs to be a power of 2 and not smaller than 16.
-#### Inputs (3 - 4) +#### Inputs (3 - 5)
A : T1
The input tensor, not quantized
B : T2
-1-dimensional data blob
+1 or 2 dimensional data blob
scales : T1
quantization scale
-zero_points (optional) : T2
+zero_points (optional) : T3
quantization zero points
+g_idx (optional) : T4
+group_idx
#### Outputs @@ -2869,8 +2872,12 @@ This version of the operator has been available since version 1 of the 'com.micr
T1 : tensor(float), tensor(float16)
Constrain input and output types to float/half_float tensors.
-T2 : tensor(uint8)
-Constrain quantized weight types to uint8.
+T2 : tensor(uint8), tensor(int32)
+Constrain quantized weight types to uint8/int32.
+T3 : tensor(uint8), tensor(int32), tensor(float16), tensor(float)
+Constrain quantized zero point types to uint8/int32/float16/float.
+T4 : tensor(int32)
+the index tensor.
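
Editor's note: to make the storage rules above concrete, here is a minimal NumPy sketch of the 4-bit shape arithmetic and of per-column dequantization with an optional g_idx reorder. It follows the layout described in this section and the reference CPU kernel added later in this patch, but it is not part of the patch and not an ORT API; the helper names and the per-column zero-point slicing are illustrative assumptions.

```python
# Illustrative sketch only: unpack_u4 / dequant_column are hypothetical helpers, not ORT APIs.
import numpy as np

def unpack_u4(packed):
    """Split each uint8 into two 4-bit codes, low nibble first (as the reference kernel reads them)."""
    low = packed & 0x0F
    high = packed >> 4
    return np.stack([low, high], axis=-1).reshape(-1)

def dequant_column(packed_col, scales_col, packed_zp_col, g_idx, K, block_size):
    """Dequantize one output column (length K) of a 4-bit quantized weight.

    Assumes packed_zp_col holds this column's zero points contiguously,
    which simplifies the real layout for illustration.
    """
    n_blocks_per_col = (K + block_size - 1) // block_size
    q = unpack_u4(packed_col)[:K].astype(np.float32)                     # one 4-bit code per element of K
    zp = unpack_u4(packed_zp_col)[:n_blocks_per_col].astype(np.float32)  # one zero point per block
    if g_idx is None:
        # Without act_order, element k simply belongs to block k // block_size.
        g_idx = np.arange(K) // block_size
    return (q - zp[g_idx]) * scales_col[g_idx]

# Shape arithmetic for an N x K weight quantized to 4 bits:
N, K, block_size, bits = 4096, 4096, 32, 4
n_blocks_per_col = (K + block_size - 1) // block_size
blob_size = (block_size * bits + 7) // 8
packed_B_shape = (N, n_blocks_per_col, blob_size)                        # [N][n_blocks_per_col][blob_size]
scales_shape = (N * n_blocks_per_col,)
packed_zero_points_shape = (((N * n_blocks_per_col + 1) * bits + 7) // 8,)

# Quick check: all-zero codes with zero point 8 and unit scales give -8.0 everywhere.
col = dequant_column(np.zeros(K // 2, dtype=np.uint8),
                     np.ones(n_blocks_per_col, dtype=np.float32),
                     np.full((n_blocks_per_col + 1) // 2, 0x88, dtype=np.uint8),
                     None, K, block_size)
```
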
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0e60b4622f2fb..71b0def659741 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -470,7 +470,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int64)|
 |MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br> **T2** = tensor(int16)<br> **T3** = tensor(int32)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float)|
-|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)|
+|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(uint8)<br> **T3** = tensor(float), tensor(uint8)<br> **T4** = tensor(int32)|
 |MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
 |MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br> **T2** = tensor(int32), tensor(uint32)|
@@ -855,7 +855,7 @@ Do not modify directly.*
 |Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
-|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
+|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T3**<br> *in* g_idx:**T4**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br> **T2** = tensor(uint8)|
 |MoE|*in* input:**T**<br> *in* router_probs:**T**<br> *in* fc1_experts_weights:**T**<br> *in* fc2_experts_weights:**T**<br> *in* fc1_experts_bias:**T**<br> *in* fc2_experts_bias:**T**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
 |NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 166f5c8f52f54..602dd98d8c0d6 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -1,6 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include + +#include "core/common/common.h" #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/framework/op_kernel.h" @@ -50,6 +56,17 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level } } // namespace +bool GetType(const NodeArg& node_arg, int32_t& type) { + type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; + const auto* type_proto = node_arg.TypeAsProto(); + if (!type_proto || !type_proto->has_tensor_type() || !type_proto->tensor_type().has_elem_type()) { + return false; + } + + type = type_proto->tensor_type().elem_type(); + return true; +} + class MatMulNBits final : public OpKernel { public: MatMulNBits(const OpKernelInfo& info) @@ -59,6 +76,17 @@ class MatMulNBits final : public OpKernel { block_size_{narrow(info.GetAttr("block_size"))}, nbits_{narrow(info.GetAttr("bits"))}, accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr("accuracy_level"))} { + const auto& node = info.node(); + auto input_defs = node.InputDefs(); + // g_idx + if (input_defs.size() > 4) { + act_order_ = true; + } + int32_t type; + if (input_defs.size() > 3 && GetType(*input_defs[3], type)) { + zero_point_is_not_quant_ = type != ONNX_NAMESPACE::TensorProto_DataType_UINT8; + } + ORT_ENFORCE(nbits_ == 4, "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); #ifdef ORT_NEURAL_SPEED @@ -88,6 +116,8 @@ class MatMulNBits final : public OpKernel { const size_t N_; const size_t block_size_; const size_t nbits_; + bool act_order_{false}; + bool zero_point_is_not_quant_{false}; const int64_t accuracy_level_; const bool column_wise_quant_{true}; IAllocatorUniquePtr packed_b_; @@ -105,7 +135,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; - + if (act_order_ || zero_point_is_not_quant_) { + return Status::OK(); + } #if defined(ORT_NEURAL_SPEED) if (!all_constant_) { @@ -212,7 +244,6 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep Status MatMulNBits::Compute(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - const Tensor* a = ctx->Input(0); const auto* a_data = a->Data(); @@ -257,11 +288,14 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { #endif // defined(ORT_NEURAL_SPEED) const Tensor* scales = ctx->Input(2); - const Tensor* zero_points = ctx->Input(3); + const Tensor* zero_points = ctx->InputCount() > 3 ? ctx->Input(3) : nullptr; + const Tensor* reorder_idx = ctx->InputCount() > 4 ? ctx->Input(4) : nullptr; + const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); MatMulComputeHelper helper; ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true)); @@ -281,8 +315,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + const bool has_single_b_matrix = + (!act_order_) && (!zero_point_is_not_quant_) && + std::all_of(helper.RightOffsets().begin(), helper.RightOffsets().end(), [](size_t offset) { return offset == 0; }); if (has_single_b_matrix) { const auto compute_type = static_cast(accuracy_level_); @@ -328,22 +363,50 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const uint8_t* b_data = b->Data(); const size_t ldb = helper.Ldb(true); - AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - // dequantize b, only 4b quantization is supported for now - MlasDequantizeBlockwise( - tmp_b_data_ptr.get(), // dequantized output - b_data, // quantized input - scales_data, // quantization scales - zero_points_data, // quantization zero points - static_cast(block_size_), // quantization block size - column_wise_quant_, // columnwise quantization or row-wise - static_cast(K_), // number of rows in quantized input - static_cast(N_), // number of columns in quantized input - thread_pool); - + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + // dequantize b, only 4b quantization is supported for now + MlasDequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); + // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! 
+ if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } else { + DequantizeBlockwise( + tmp_b_data_ptr.get(), // dequantized output + b_data, // quantized input + scales_data, // quantization scales + static_cast(zero_points_data), // quantization zero points + reorder_idx_data, + static_cast(block_size_), // quantization block size + column_wise_quant_, // columnwise quantization or row-wise + static_cast(K_), // number of rows in quantized input + static_cast(N_), // number of columns in quantized input + thread_pool); + } + } #if 0 // for debug auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); @@ -374,7 +437,9 @@ ONNX_OPERATOR_KERNEL_EX( kCpuExecutionProvider, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc new file mode 100644 index 0000000000000..f92e59e990ba5 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" + +#include +#include +#include +#include +#include + +#include "core/common/common.h" +#include "core/framework/float16.h" +#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void Dequantize4BitsKernelReOrder( + T* output, const uint8_t* quant_data, const T* scale_data, + const zeroT* zero_points, const int32_t* reorder_idx, int block_size, + int groups_per_threadblock, int total_groups, int out_rows, int out_cols, + int blockIdx_x, int threadIdx_x) { + const int group_id = blockIdx_x * groups_per_threadblock + ((threadIdx_x * 8) / block_size); + if (group_id >= total_groups) { + return; + } + const int scales_shape_x = (out_cols + block_size - 1) / block_size; + const int zero_point_shape_x = (scales_shape_x + 1) / 2; + + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx_x * 8) & (block_size - 1)); + + const int out_x = element_offset % (scales_shape_x * block_size); + const int out_y = element_offset / (scales_shape_x * block_size); + if (out_y >= out_rows || out_x >= out_cols) { + return; + } + T* output_i = output + out_y * out_cols + out_x; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + const int remain_x = std::min(8, out_cols - out_x); + for (int i = 0; i < remain_x; i++) { + int32_t rid = reorder_idx ? 
reorder_idx[kb_idx * block_size + i] : kb_idx; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + float zp_f = 8; + if (zero_points) { + if constexpr (std::is_same_v) { + zp_f = *(zero_points + n_idx * scales_shape_x + rid); + } else { + uint8_t zp = 8; + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * MLFloat16(zp_f); + output_i[i] = static_cast((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * zp_f; + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // reorder_idx for groupwise quantization + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* pool) { + auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; + constexpr int element_per_thread = 8; + int groups_per_threadblock = 256 * element_per_thread / block_size; + int groups_per_K = ceildiv(K, block_size); + int total_groups = N * groups_per_K; // total elemenets in quant_data + int blocks_per_grid = static_cast(ceildiv(total_groups, groups_per_threadblock)); + concurrency::ThreadPool::TrySimpleParallelFor( + pool, static_cast(blocks_per_grid), + [&](std::ptrdiff_t block_id) { + for (int j = 0; j < 256; j++) { + Dequantize4BitsKernelReOrder(output, quant_data, scales_data, zero_points, + reorder_idx, block_size, groups_per_threadblock, + total_groups, N, K, static_cast(block_id), j); + } + }); +} + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const float* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h new file mode 100644 index 0000000000000..5061ac5c800a6 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+#include "core/providers/common.h" +#include "core/platform/threadpool.h" + +namespace onnxruntime { +namespace contrib { + +template +void DequantizeBlockwise( + inputT* output, // dequantized output + const uint8_t* quant_data, // quantized input + const inputT* scales_data, // quantization scales + const zeroT* zero_points, // quantization zero points + const int32_t* reorder_idx, // quantization zero points + int32_t block_size, // quantization block size + bool, // columnwise quantization or row-wise + int32_t K, // number of rows in quantized input + int32_t N, // number of columns in quantized input + onnxruntime::concurrency::ThreadPool* thread_pool); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu index 6b66f1d84e221..cd6593352008b 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cu @@ -2,10 +2,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include #include #include #include #include +#include #include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" @@ -56,41 +58,94 @@ __device__ __forceinline__ void DequantizeEightElements(uint32_t values_quant, f } template -__global__ void Dequantize4BitsKernel( +__global__ void Dequantize4BitsKernelReOrder( T* output, const uint8_t* quant_data, const T* scale_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int block_size, - int blocks_per_K, - int blocks_per_threadblock, - int total_blks, - int shift) { - int block_id = blockIdx.x * blocks_per_threadblock + ((threadIdx.x * 8) >> shift); - if (block_id >= total_blks) { + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int group_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (group_id >= total_groups) { return; } - int n_idx = block_id / blocks_per_K; - int kb_idx = block_id % blocks_per_K; - int element_offset = block_id * block_size + ((threadIdx.x * 8) & ((1 << shift) - 1)); + // T __shared__ zero_points_after_reorder[];//K + // T __shared__ scales_after_reorder[]; // K + // const int num_r_per_thread = k / 256; + + const int zero_point_shape_x = (groups_per_K + 1) / 2; + const int scales_shape_x = groups_per_K; + int n_idx = group_id / scales_shape_x; + int kb_idx = group_id % scales_shape_x; + int element_offset = group_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); + T* output_i = output + element_offset; + uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); + for (int i = 0; i < 8; i++) { + int32_t rid = reorder_idx[kb_idx * block_size + i]; + T scale = *(scale_data + n_idx * scales_shape_x + rid); + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; + zp = (rid & 0x01) ? 
(zp >> 4) : (zp & 0x0f); + } + + if constexpr (std::is_same_v) { + T zp_adjust = -scale * __short2half_rn(zp); + output_i[i] = __uint2half_rn((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } else { + T zp_adjust = -scale * T(zp); + output_i[i] = T((quant_value >> (4 * i)) & 0xF) * scale + zp_adjust; + } + } +} + +template +__global__ void Dequantize4BitsKernel( + T* output, + const uint8_t* quant_data, + const T* scale_data, + const ZeroT* zero_points, + int block_size, + int groups_per_K, + int groups_per_threadblock, + int total_groups) { + int block_id = blockIdx.x * groups_per_threadblock + ((threadIdx.x * 8) / block_size); + if (block_id >= total_groups) { + return; + } + int element_offset = block_id * block_size + ((threadIdx.x * 8) & (block_size - 1)); uint32_t quant_value = *(reinterpret_cast(quant_data + element_offset / 2)); T scale = *(scale_data + block_id); - uint8_t zp = 8; - if (zero_points) { - zp = zero_points[n_idx * ((blocks_per_K + 1)/2) + kb_idx / 2]; - zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + T zero_point_value; + if constexpr (std::is_same_v) { + const int scales_shape_x = groups_per_K; + const int zero_point_shape_x = (groups_per_K + 1) / 2; + int kb_idx = block_id % scales_shape_x; + int n_idx = block_id / scales_shape_x; + uint8_t zp = 8; + if (zero_points) { + zp = zero_points[n_idx * zero_point_shape_x + kb_idx / 2]; + zp = (kb_idx & 0x01) ? (zp >> 4) : (zp & 0x0f); + } + zero_point_value = static_cast(zp); + } else { + zero_point_value = zero_points? *(zero_points + block_id):static_cast(8); } output = output + element_offset; - DequantizeEightElements(quant_value, scale, static_cast(zp), output); + DequantizeEightElements(quant_value, scale, zero_point_value, output); } -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, // shape: [N, (block_per_K + 1)/2] + const ZeroT* zero_points, // shape: [N, (block_per_K + 1)/2] + const int32_t* reorder_idx, int k, int n, int block_size, @@ -98,47 +153,79 @@ Status Dequantize4Bits( // k is padded and equal to block_per_K * block_size ORT_ENFORCE(k % block_size == 0, "k must be a multiplier of block_size"); constexpr int element_per_thread = 8; - int blocks_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; - int blocks_per_K = k / block_size; - int total_blks = n * blocks_per_K; - int blocks_per_grid = static_cast(CeilDiv(n * blocks_per_K, blocks_per_threadblock)); - int shift = static_cast(log2f(float(block_size))); - - Dequantize4BitsKernel<<>>( - output, - quant_data, - scales_data, - zero_points, - block_size, - blocks_per_K, - blocks_per_threadblock, - total_blks, - shift); + int groups_per_threadblock = GridDim::maxThreadsPerBlock * element_per_thread / block_size; + int groups_per_K = k / block_size; + int total_groups = n * groups_per_K; // total elemenets in quant_data + int groups_per_grid = static_cast(CeilDiv(total_groups, groups_per_threadblock)); + if (!reorder_idx) { + Dequantize4BitsKernel<<>>( + output, + quant_data, + scales_data, + zero_points, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } else { + // static_assert(std::is_same_v, "ZeroT must be uint8_t"); + Dequantize4BitsKernelReOrder<<>>( + output, + quant_data, + scales_data, + (const uint8_t*)zero_points, + reorder_idx, + block_size, + groups_per_K, + groups_per_threadblock, + total_groups); + } return Status::OK(); } -template Status Dequantize4Bits( +template Status Dequantize4Bits( 
float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); -template Status Dequantize4Bits( +template Status Dequantize4Bits( half* output, const uint8_t* quant_data, const half* scales_data, const uint8_t* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); +template Status Dequantize4Bits( + float* output, + const uint8_t* quant_data, + const float* scales_data, + const float* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - +template Status Dequantize4Bits( + half* output, + const uint8_t* quant_data, + const half* scales_data, + const half* zero_points, + const int32_t* reorder_idx, + int k, + int n, + int block_size, + cudaStream_t stream); /////////////////////////////////////////////////////////////////////////////// // A more general block-wise dequantization implementation that supports // different block sizes and block orientations (row-wise/column-wise). diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh index f9c09c55fd893..580b5087f3fa3 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise.cuh @@ -7,18 +7,18 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template Status Dequantize4Bits( T* output, const uint8_t* quant_data, const T* scales_data, - const uint8_t* zero_points, + const ZeroT* zero_points, + const int32_t* reorder_idx, int k, int n, int block_size, cudaStream_t stream); - /** * @brief Dequantize a block-wise quantized matrix, and store the result in a * column major matrix for use in subsequent GEMM. This implementation supports diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc index 015df70c8ec3c..1cec6f6a12f1c 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc @@ -1,15 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-// -// This module define MatMulFp32Q4 operator, it is basically -// matmul float32 with right hand side being a 2-D matrix -// pre-packed and block-compacted into int4 -// - -#include "core/common/safeint.h" -#include "core/providers/cuda/cuda_kernel.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "contrib_ops/cuda/quantization/matmul_nbits.h" + +#include + +#include "core/common/status.h" +#include "core/framework/float16.h" #include "core/providers/cpu/math/matmul_helper.h" #include "matmul_nbits.cuh" #include "dequantize_blockwise.cuh" @@ -19,40 +16,19 @@ namespace contrib { namespace cuda { using namespace onnxruntime::cuda; -template -class MatMulNBits final : public CudaKernel { - public: - MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { - ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); - ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); - ORT_ENFORCE(nbits_ == 4, - "Only 4b quantization is supported for MatMulNBits op," - " additional bits support is planned."); - } - - Status ComputeInternal(OpKernelContext* context) const override; - - private: - int64_t K_; - int64_t N_; - int64_t block_size_; - int64_t nbits_; - bool column_wise_quant_blk_{true}; -}; - template Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { const Tensor* a = ctx->Input(0); const Tensor* b = ctx->Input(1); const Tensor* scales = ctx->Input(2); const Tensor* zero_points = ctx->Input(3); + const Tensor* reorder_idx = ctx->Input(4); const auto* a_data = a->Data(); const uint8_t* blob_data = b->Data(); const auto* scales_data = scales->Data(); - const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->Data(); + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); + const auto* reorder_idx_data = reorder_idx == nullptr ? 
nullptr : reorder_idx->Data(); typedef typename ToCudaType::MappedType CudaT; @@ -67,77 +43,99 @@ Status MatMulNBits::ComputeInternal(OpKernelContext* ctx) const { // Bail out early if the output is going to be empty if (Y->Shape().Size() == 0) return Status::OK(); - bool is_4bit_done = TryMatMul4Bits( - reinterpret_cast(Y->MutableData()), - reinterpret_cast(a_data), - blob_data, - reinterpret_cast(scales_data), - zero_points_data, - SafeInt(helper.M()), - SafeInt(helper.N()), - SafeInt(helper.K()), - SafeInt(block_size_), - SafeInt(GetDeviceProp().sharedMemPerBlock), - static_cast(ctx->GetComputeStream()->GetHandle())); - if (!is_4bit_done) { - int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; - IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); - auto* b_data = b_data_ptr.get(); - if (column_wise_quant_blk_) { - // column-wise block + bool is_4bit_done = (reorder_idx_data == nullptr) && + (!zero_points || !zero_points->IsDataType()) && + TryMatMul4Bits( + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + blob_data, + reinterpret_cast(scales_data), + static_cast(zero_points_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + SafeInt(GetDeviceProp().sharedMemPerBlock), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (is_4bit_done) { + return Status::OK(); + } + + int64_t K_padded = (K_ + block_size_ - 1) / block_size_ * block_size_; + IAllocatorUniquePtr b_data_ptr = GetScratchBuffer(N_ * K_padded, ctx->GetComputeStream()); + auto* b_data = b_data_ptr.get(); + if (column_wise_quant_blk_) { + if (reorder_idx) { + ORT_ENFORCE(K_padded == reorder_idx->Shape()[0], "K_padded != g_idx->Shape()[0]"); + } + // column-wise block + if ((zero_points && zero_points->IsDataType())) { ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, + (const CudaT*)zero_points_data, + reorder_idx_data, SafeInt(K_padded), SafeInt(N_), SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } else { - // row-wise block - K_padded = K_; - - ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + ORT_RETURN_IF_ERROR(Dequantize4Bits( reinterpret_cast(b_data), blob_data, reinterpret_cast(scales_data), - zero_points_data, - SafeInt(block_size_), - column_wise_quant_blk_, - SafeInt(K_), + (const uint8_t*)zero_points_data, + reorder_idx_data, + SafeInt(K_padded), SafeInt(N_), + SafeInt(block_size_), static_cast(ctx->GetComputeStream()->GetHandle()))); } + } else { + // row-wise block + K_padded = K_; + + ORT_RETURN_IF_ERROR(DequantizeBlockwise4b( + reinterpret_cast(b_data), + blob_data, + reinterpret_cast(scales_data), + (const uint8_t*)zero_points_data, + SafeInt(block_size_), + column_wise_quant_blk_, + SafeInt(K_), + SafeInt(N_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + } #if 0 - cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); - T* b_data_cpu = new T[K_ * N_]; - cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); - delete[] b_data_cpu; +cudaStreamSynchronize(static_cast(ctx->GetComputeStream()->GetHandle())); +T* b_data_cpu = new T[K_ * N_]; +cudaMemcpy(b_data_cpu, b_data, K_ * N_ * sizeof(T), cudaMemcpyDeviceToHost); +delete[] b_data_cpu; #endif - const CudaT alpha = ToCudaType::FromFloat(1.f); - const CudaT zero = ToCudaType::FromFloat(0.f); - - if (helper.OutputOffsets().size() == 1) { - CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( - 
GetCublasHandle(ctx), - CUBLAS_OP_T, - CUBLAS_OP_N, - SafeInt(helper.N()), - SafeInt(helper.M()), - SafeInt(helper.K()), - &alpha, - reinterpret_cast(b_data), - SafeInt(K_padded), - reinterpret_cast(a_data), - helper.Lda(transa), - &zero, - reinterpret_cast(Y->MutableData()), - helper.Ldc(), - GetDeviceProp(), - UseTF32())); - } + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + if (helper.OutputOffsets().size() == 1) { + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_data), + SafeInt(K_padded), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp(), + UseTF32())); } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..f5c2c6c4e4fdf --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +// This module define MatMulNBits operator, it is basically +// matmul float with right hand side being a 2-D matrix +// pre-packed and block-compacted into int4 +// +#pragma once +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulNBits final : public CudaKernel { + public: + MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("bits", &nbits_)); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t nbits_; + bool column_wise_quant_blk_{true}; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index e33ce20737f80..f06a3785f362d 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3343,22 +3343,23 @@ MatMulNBits is a MatMul with weight quantized with N bits(e.g., 2, 3, 4, 5, 6, 7 And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. 3. Input B's scale and zero point are specified by input scales and zero_points. -Input B is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: -- n_blocks_per_col = (K + block_size - 1) / block_size -- blob_size = block_size / 8 * bits - - For a block blob. 
It is stored in format: - struct Blob { - uint8 one_bits[(bits & 0x1) * 1 * block_size / 8]; // highest 1 bit for 3, 5, 7 bits quantization - uint8 two_bits[(bits & 0x2) * 2 * block_size / 8]; // high 2 bits for 2, 6, 7 bits quantization - uint8 four_bits[(bits & 0x4) * 4 * block_size / 8]; // low 4 bits for 4, 5, 6 bits quantization - } + Input is stored as uint8_t with shape: [N][n_blocks_per_col][blob_size] in which: + - n_blocks_per_col = (K + block_size - 1) / block_size + - blob_size = CeilDiv(block_size * bits, bitsof(uint8_t)<8>) + For all bits from 2-8, a row of data is stored squeezely and represented by uint8_t. + - for 2,4,8 bits, 4x2bit,2x4bit,1x8bit are stored in one uint8_t. + 4bit example: + |.|.|.|.| .|.|.|.| =uint8_t (2x4bit) + - for 3,5,6,7 bits, 32x3bit,32x5bit,16x6bit,32x7bit are stored in 12xuint8_t,20xuint8_t,12xuint8_t,28xuint8_t separately. no bits are wasted. + 3bit example: + |.|.|. |.|.|. |.|.|. = 9bit, which across 2 uint8_t, the highest bit for the second uint8_t is used. + The last uint_8 may have some bits unused. -Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] -Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored as one unit8_t. If bits > 4, one zero point is stored with one unit8_t. Thus, its shape is: - - [(N * n_blocks_per_col + 1) / 2] if bits <=4 - - [N * n_blocks_per_col] if bits > 4 +Input scales is stored in same type as original type of B(float32, float16) with shape like: [N * n_blocks_per_col] +Input zero_points is stored as uint8_t or same as type(A). It has the same packing method as input B. + - [CeilDiv((N * n_blocks_per_col + 1) *bits, 8)] + If zero_points has same type as A, it's not packed and has the same shape as Scales. )DOC"; ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulNBits) @@ -3377,12 +3378,15 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored "type T1.", AttributeProto::INT, static_cast(0)) .Input(0, "A", "The input tensor, not quantized", "T1") - .Input(1, "B", "1-dimensional data blob", "T2") + .Input(1, "B", "1 or 2 dimensional data blob", "T2") .Input(2, "scales", "quantization scale", "T1") - .Input(3, "zero_points", "quantization zero points", "T2", OpSchema::Optional) + .Input(3, "zero_points", "quantization zero points", "T3", OpSchema::Optional) + .Input(4, "g_idx", "group_idx", "T4", OpSchema::Optional) .Output(0, "Y", "tensor. The output tensor has the same rank as the input. 
", "T1") .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") - .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeConstraint("T2", {"tensor(uint8)", "tensor(int32)"}, "Constrain quantized weight types to uint8/int32.") + .TypeConstraint("T3", {"tensor(uint8)", "tensor(int32)", "tensor(float16)", "tensor(float)"}, "Constrain quantized zero point types to uint8/int32/float16/float.") + .TypeConstraint("T4", {"tensor(int32)"}, "the index tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { // Type inference propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index eb7bbec997d59..a1916e806c5c0 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -65,7 +65,7 @@ def __init__( self, calibration_data_reader: CalibrationDataReader, percdamp=0.01, - blocksize=128, + block_size=128, actorder=False, mse=False, perchannel=True, @@ -79,7 +79,7 @@ def __init__( a calibration data reader. It enumerates calibration data and generates inputs for the original model. percdamp: percent of the average Hessian diagonal to use for dampening. - blocksize (int, optional): + block_size (int, optional): channel number in one block to execute a GPTQ quantization iteration. actorder (bool, optional): whether rearrange Hessian matrix considering the diag's value. @@ -93,42 +93,285 @@ def __init__( ) self.calibration_data_reader = calibration_data_reader self.percdamp = percdamp - self.blocksize = blocksize + self.block_size = block_size self.actorder = actorder self.mse = mse self.perchannel = perchannel -class MatMul4BitsQuantizer: - """Perform 4b quantization of constant MatMul weights""" +class HQQWeightOnlyQuantConfig(WeightOnlyQuantConfig): + def __init__( + self, + block_size=128, + bits=4, + axis=1, + ): + """ + This is a class for HQQ algorithm Weight Only Quant Configuration. + HQQ algorithm quant weight without needing calibrate data. + + Args: + block_size (int, optional): + channel number in one block to execute a GPTQ quantization iteration. + bits (int, optional): + how many bits to represent weight. + axis (int, optional): + 0 or 1. which axis to quantize. 
https://arxiv.org/pdf/2309.15531.pdf + """ + super().__init__( + algorithm="HQQ", + ) + self.block_size = block_size + self.bits = bits + self.axis = axis + +class DefaultWeightOnlyQuantConfig(WeightOnlyQuantConfig): def __init__( self, - model: ModelProto | str, - block_size: int, - is_symmetric: bool, + block_size: int = 128, + is_symmetric: bool = False, accuracy_level: int | None = None, - nodes_to_exclude=None, - algo_config: WeightOnlyQuantConfig = None, ): - if nodes_to_exclude is None: - nodes_to_exclude = [] - self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) - self.model_path = model if isinstance(model, str) else None + super().__init__(algorithm="DEFAULT") self.block_size = block_size self.is_symmetric = is_symmetric + self.bits = 4 self.accuracy_level = accuracy_level - self.nodes_to_exclude = set(nodes_to_exclude) - self.algo_config = algo_config + + +def is_divisible(val1, val2): + return int(val2 * np.ceil(val1 / val2)) == val1 + + +class HQQWeightOnlyQuantizer: + def __init__( + self, + config: HQQWeightOnlyQuantConfig, + ): + self.config = config + + # Proximal solver || weight - dequantize(quantize(weight))||_p^p + @staticmethod + def optimize_weights( + tensor, + scale, + zero, + min_max: list[int], + axis: int = 0, + opt_params: dict = None, # noqa: RUF013 + verbose=False, + ): + import torch + + opt_params = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} if opt_params is None else opt_params + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + + dtype = torch.float16 if tensor.is_cuda else torch.float32 + w_f = tensor.to(dtype) + scale = scale.to(dtype) + zero = zero.to(dtype) + + if lp_norm == 1: + + def shrink_op(x, beta): + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + + else: + + def shrink_op(x, beta, p=lp_norm): + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1) + ) + + best_error = 1e4 + for i in range(iters): + w_q = torch.round(w_f * scale + zero).clamp(min_max[0], min_max[1]) + w_r = (w_q - zero) / scale + w_e = shrink_op(w_f - w_r, beta) + zero = torch.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(w_f - w_r).mean()) + if verbose: + print(i, np.round(current_error, 6)) + if current_error < best_error: + best_error = current_error + else: + break + + del w_f, w_q, w_r, w_e + + return scale, zero @staticmethod - def __get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: - for gid in range(len(graph_path) - 1, -1, -1): - graph = graph_path[gid] - for tensor in graph.initializer: - if tensor.name == name: - return tensor, graph - return None, None + def pack_on_row_fast_248bit(pack_tensor, ori_int_tensor, bits): + if pack_tensor.shape[0] == ori_int_tensor.shape[0]: + ori_int_tensor = ori_int_tensor.T + pack_tensor = pack_tensor.T + if bits in [2, 4, 8]: + compress_ratio = pack_tensor.element_size() * 8 // bits + for j in range(0, compress_ratio): + pack_tensor[0:] |= ori_int_tensor[j::compress_ratio] << (bits * (j)) + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + # from Official implementation of Half-Quadratic Quantization (HQQ) + def quantize_internal( + self, tensor, bits=4, channel_wise=True, group_size=64, optimize=True, round_zero=True, axis=1 + ): + import torch + + weight = tensor.float() + ori_shape = 
weight.shape + + pad_len = (group_size - ori_shape[axis] % group_size) % group_size + if axis == 1: + weight = torch.nn.functional.pad(weight, (0, pad_len), "constant", 0) + else: + weight = torch.nn.functional.pad(weight, (0, 0, 0, pad_len), "constant", 0) + shape = weight.shape + + # Reshape for grouping + if (group_size is not None) and channel_wise: + weight = weight.reshape([-1, group_size]) if (axis == 1) else weight.reshape([group_size, -1]) + + # Get min/max values + if channel_wise is False: + _min, _max = weight.min(), weight.max() + optimize = False + else: + _min = weight.min(axis=axis, keepdim=True)[0] + _max = weight.max(axis=axis, keepdim=True)[0] + + max_v = 2**bits - 1 + min_v = 0 + min_max = [min_v, max_v] + + # Note: here we work with the inverse of the scale to avoid division and quantize instead via weight*scale + zero, the scale is inverted later on. + # clamp to avoid half-precision problems + scale = (max_v / (_max - _min)).clamp(max=2e4) + #!!!!!!!!!!!!!!! + min_max_axis = _max - _min + if (min_max_axis == 0).sum().item() > 0: + min_max_axis[min_max_axis == 0] = max_v + scale = (max_v / min_max_axis).clamp(max=2e4) + zero = -_min * scale + + if round_zero: + zero = torch.round(zero) + + # Fine-tune weights + if optimize: + scale, zero = self.optimize_weights(tensor=weight, scale=scale, zero=zero, min_max=min_max, axis=axis) + + # Quantize + # Necessary for fake quantization backprop + w_q = torch.round(weight * scale + zero).clamp(min_max[0], min_max[1]) + w_q = w_q.reshape(shape).int() + + scale = 1.0 / scale + if axis == 1: + scale = scale.reshape(shape[0], -1) + zero = zero.reshape(shape[0], -1) + else: + scale = scale.reshape(-1, shape[-1]) + zero = zero.reshape(-1, shape[-1]) + # cleanup + del weight, _min, _max + + return w_q, scale.to(tensor.dtype), zero.to(tensor.dtype) + + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]): + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + if node.op_type != "MatMul": + return node # only care about MatMul for now + import torch + + logger.info(f"start to quantize {node.name} ...") + inputB = node.input[1] # noqa: N806 + b_pb, bs_graph = get_initializer(inputB, graph_stack) + if b_pb is None: + logger.info("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + b_array = onnx.numpy_helper.to_array(b_pb) + if len(b_array.shape) != 2: + logger.info("MatMul weight is not 2D. 
Skip to quantize") + return node # can only process 2-D matrix + b_array_torch = torch.from_numpy(b_array) + if torch.cuda.is_available(): + b_array_torch = b_array_torch.cuda() + quant_weight_torch, scales_torch, zero_points_torch = self.quantize_internal( + b_array_torch.T, bits=self.config.bits, group_size=self.config.block_size + ) + quant_weight_torch = quant_weight_torch.contiguous() + scales_torch = scales_torch.contiguous() + zero_points_torch = zero_points_torch.contiguous() + + packed_torch = torch.zeros( + (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2), + dtype=torch.uint8, + device=quant_weight_torch.device, + ) + self.pack_on_row_fast_248bit(packed_torch, quant_weight_torch, self.config.bits) + scales = scales_torch.cpu().numpy() + zero_points = zero_points_torch.cpu().numpy() + b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy()) + b_quant.name = b_pb.name + "_Q4" + for input in bs_graph.input: + if input.name == inputB: + bs_graph.input.remove(input) + break + + scales_tensor = onnx.numpy_helper.from_array(scales) + scales_tensor.name = b_pb.name + "_scales" + bs_graph.initializer.extend([b_quant, scales_tensor]) + + input_names = [node.input[0], b_quant.name, scales_tensor.name] + zp_tensor = onnx.numpy_helper.from_array(zero_points) + zp_tensor.name = b_pb.name + "_zero_points" + bs_graph.initializer.extend([zp_tensor]) + input_names.append(zp_tensor.name) + + kwargs = {} + rows, cols = b_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["bits"] = self.config.bits + kwargs["block_size"] = self.config.block_size + + matmul_q4_node = onnx.helper.make_node( + "MatMulNBits", + inputs=input_names, + outputs=[node.output[0]], + name=node.name + "_Q4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.info(f"complete quantization of {node.name} ...") + + return matmul_q4_node + + +def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + +class DefaultWeightOnlyQuantizer: + def __init__(self, config: DefaultWeightOnlyQuantConfig): + self.config = config def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: """4b quantize fp32 weight to a blob""" @@ -137,7 +380,7 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: raise ValueError("Current int4 block quantization only supports 2D tensors!") rows, cols = fp32weight.shape - block_size = self.block_size + block_size = self.config.block_size blob_size = block_size // 2 k_blocks = (rows + block_size - 1) // block_size padded_rows = k_blocks * block_size @@ -149,23 +392,19 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> np.ndarray: packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype) zero_point = np.zeros(cols * ((k_blocks + 1) // 2), dtype="uint8") - quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.is_symmetric) + quantize_matmul_4bits(packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric) return (packed, scales, zero_point) - def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: + def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> NodeProto: """If the node is MatMul with fp32 const weight, quantize the 
weight with int4, and return the new node""" if node.op_type != "MatMul": return node # only care about MatMul for now logger.info(f"start to quantize {node.name} ...") - if node.name in self.nodes_to_exclude: - logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") - return node - inputB = node.input[1] # noqa: N806 - B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + B, Bs_graph = get_initializer(inputB, graph_stack) # noqa: N806 if B is None: logger.info("MatMul doesn't have const weight. Skip to quantize") return node # only care about constant weight @@ -188,7 +427,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) Bs_graph.initializer.extend([B_quant, scales_tensor]) input_names = [node.input[0], B_quant.name, scales_tensor.name] - if not self.is_symmetric: + if not self.config.is_symmetric: zp_tensor = onnx.numpy_helper.from_array(zero_points) zp_tensor.name = B.name + "_zero_points" Bs_graph.initializer.extend([zp_tensor]) @@ -199,8 +438,8 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) kwargs["K"] = rows kwargs["N"] = cols kwargs["bits"] = 4 - kwargs["block_size"] = self.block_size - if self.accuracy_level is not None: + kwargs["block_size"] = self.config.block_size + if self.config.accuracy_level is not None: kwargs["accuracy_level"] = self.accuracy_level matmul_q4_node = onnx.helper.make_node( @@ -216,6 +455,38 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: list[GraphProto]) return matmul_q4_node + +class MatMul4BitsQuantizer: + """Perform 4b quantization of constant MatMul weights""" + + def __init__( + self, + model: ModelProto | str, + block_size: int = 128, + is_symmetric: bool = False, + accuracy_level: int | None = None, + nodes_to_exclude=None, + algo_config: WeightOnlyQuantConfig = None, + ): + if nodes_to_exclude is None: + nodes_to_exclude = [] + self.model = ONNXModel(onnx.load(model)) if isinstance(model, str) else ONNXModel(model) + self.model_path = model if isinstance(model, str) else None + self.block_size = block_size + self.is_symmetric = is_symmetric + self.accuracy_level = accuracy_level + self.nodes_to_exclude = set(nodes_to_exclude) + self.node_quantizer = None + if algo_config is None: + algo_config = DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric, accuracy_level=accuracy_level + ) + self.algo_config = algo_config + if algo_config.algorithm == "HQQ": + self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config) + elif algo_config.algorithm == "DEFAULT": + self.node_quantizer = DefaultWeightOnlyQuantizer(self.algo_config) + def _process_subgraph(self, graph_stack: list[GraphProto]): new_nodes = [] graph = graph_stack[-1] @@ -246,8 +517,15 @@ def _process_subgraph(self, graph_stack: list[GraphProto]): node = onnx.helper.make_node( # noqa: PLW2901 node.op_type, node.input, node.output, name=node.name, **kwargs ) - - new_nodes.append(self._q4_matmul_node_weight(node, graph_stack)) + out_node = None + if node.name in self.nodes_to_exclude: + logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + out_node = node + elif self.algo_config is not None and self.algo_config.algorithm == "HQQ": + out_node = self.node_quantizer.quantize(node, graph_stack) + else: + out_node = self.node_quantizer.quantize(node, graph_stack) + new_nodes.append(out_node) graph.ClearField("node") graph.node.extend(new_nodes) @@ -300,7 +578,7 @@ def inc_dataloader(): 
from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize kwargs["percdamp"] = self.algo_config.percdamp - kwargs["blocksize"] = self.algo_config.blocksize + kwargs["blocksize"] = self.algo_config.block_size kwargs["actorder"] = self.algo_config.actorder kwargs["mse"] = self.algo_config.mse kwargs["perchannel"] = self.algo_config.perchannel @@ -316,7 +594,7 @@ def inc_dataloader(): logger.info(f"complete quantization of model with {algorithm} algorithm.") def process(self): - if self.algo_config is None: + if self.algo_config.algorithm in ["HQQ", "DEFAULT"]: # use a stack to keep track of sub-graphs graph_stack = [self.model.graph()] opset_import = self.model.opset_import() @@ -327,7 +605,6 @@ def process(self): has_ms_domain = True if not has_ms_domain: opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) - self._process_subgraph(graph_stack) self.model.clean_initializers() else: @@ -366,6 +643,14 @@ def parse_args(): parser.add_argument("--input_model", required=True, help="Path to the input model file") parser.add_argument("--output_model", required=True, help="Path to the output model file") parser.add_argument("--block_size", required=False, default=32, type=int, help="Block size for quantization") + parser.add_argument( + "--quant_method", + default="default", + type=str, + choices=["default", "hqq"], + help="the algorithm used to quantize weight", + ) + parser.add_argument("--bits", default=4, type=int, help="the target bits to represent weight") parser.add_argument( "--symmetric", required=False, @@ -411,12 +696,24 @@ def parse_args(): raise Exception(f"file {output_model_path} already exists") model = onnx.load(input_model_path) + if args.quant_method == "hqq": + quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits) + elif args.quant_method == "default": + quant_config = DefaultWeightOnlyQuantConfig( + block_size=args.block_size, is_symmetric=args.symmetric, accuracy_level=args.accuracy_level + ) + elif args.quant_method == "rtn": + quant_config = RTNWeightOnlyQuantConfig() + elif args.quant_method == "gptq": + quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size) + else: + raise ValueError(f"Unsupported quantization method: {args.quant_method}") + quant = MatMul4BitsQuantizer( model=model, - block_size=args.block_size, - is_symmetric=args.symmetric, accuracy_level=args.accuracy_level, nodes_to_exclude=args.nodes_to_exclude, + algo_config=quant_config, ) quant.process() quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 2ad20eafc2ef1..d294fd4e2b0e0 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
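The quant_method branch added to the command-line entry point above maps each choice to a config class, and the same classes drive the Python API directly. A minimal usage sketch (file paths are placeholders):

```python
import onnx
from onnxruntime.quantization import matmul_4bits_quantizer

model = onnx.load("model_fp32.onnx")  # placeholder path

# Pick a config the same way the CLI's --quant_method switch does.
use_hqq = True
if use_hqq:
    algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=32, bits=4)
else:
    algo_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
        block_size=32, is_symmetric=False
    )

quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("model_int4.onnx", True)  # placeholder path
```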
#ifndef ORT_MINIMAL_BUILD +#include #include "core/common/span_utils.h" #include "core/framework/tensor.h" @@ -66,7 +67,9 @@ void QuantizeDequantize(std::vector& raw_vals, } void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accuracy_level, - bool has_zeropoint, bool use_float16, float fp16_abs_error = 0.02f) { + bool has_zeropoint, bool use_float16, bool has_g_idx = false, + bool zp_is_4bit = true, float fp16_abs_error = 0.02f) { + zp_is_4bit = zp_is_4bit | has_g_idx; RandomValueGenerator random{1234}; std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); std::vector input1_f_vals(random.Gaussian(std::vector({K, N}), 0.0f, 0.25f)); @@ -113,12 +116,40 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddAttribute("block_size", block_size); test.AddAttribute("bits", QBits); test.AddAttribute("accuracy_level", accuracy_level); + auto ceildiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; }; + if (use_float16) { test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, ToFloat16(scales), true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + if (zp_is_4bit) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes * 2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size() - 1; + while (zp_f.size() != q_scale_size) { + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size / N + 1; + } + + test.AddInput("zero_points", {static_cast(q_scale_size)}, ToFloat16(zp_f), true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { + g_idx[i] = gsl::narrow(i / block_size); + } + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); } test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); @@ -132,9 +163,34 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura test.AddInput("B", {q_cols, q_rows}, input1_vals, true); test.AddInput("scales", {static_cast(q_scale_size)}, scales, true); if (has_zeropoint) { - test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); - } + if (zp_is_4bit) { + test.AddInput("zero_points", {static_cast(q_zp_size_in_bytes)}, zp, true); + } else { + std::vector zp_f; + zp_f.reserve(q_zp_size_in_bytes * 2); + for (size_t i = 0; i < zp.size(); i++) { + zp_f.push_back(static_cast(zp[i] & 0xf)); + zp_f.push_back(static_cast((zp[i] >> 4) & 0xf)); + } + size_t ind = zp_f.size() - 1; + while (zp_f.size() != q_scale_size) { + zp_f.erase(zp_f.begin() + ind); + ind -= q_scale_size / N + 1; + } + test.AddInput("zero_points", {static_cast(q_scale_size)}, zp_f, true); + } + } else { + test.AddInput("", {0}, {}); + } + if (has_g_idx) { + int K_pad = gsl::narrow(ceildiv(K, block_size) * block_size); + std::vector g_idx(K_pad); + for (int64_t i = 0; i < K_pad; i++) { + g_idx[i] = gsl::narrow(i / block_size); + } + test.AddInput("g_idx", {static_cast(K_pad)}, g_idx, true); + } test.AddOutput("Y", {M, N}, expected_vals); if (accuracy_level == 4) { test.SetOutputAbsErr("Y", 0.1f); @@ -158,6 +214,8 @@ TEST(MatMulNBits, Float32) { for (auto accuracy_level : 
{0}) { RunTest(M, N, K, block_size, accuracy_level, false, false); RunTest(M, N, K, block_size, accuracy_level, true, false); + RunTest(M, N, K, block_size, accuracy_level, false, false, true); + RunTest(M, N, K, block_size, accuracy_level, true, false, false, false); } #endif } @@ -172,8 +230,10 @@ TEST(MatMulNBits, Float16) { for (auto N : {1, 2, 32, 288}) { for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { for (auto block_size : {16, 32, 64, 128}) { - RunTest(M, N, K, block_size, 0, false, true); - RunTest(M, N, K, block_size, 0, true, true); + for (auto has_gidx : {true, false}) { + RunTest(M, N, K, block_size, 0, false, true, has_gidx); + RunTest(M, N, K, block_size, 0, true, true, has_gidx, false); + } } } } @@ -183,9 +243,9 @@ TEST(MatMulNBits, Float16) { TEST(MatMulNBits, Float16Large) { for (auto block_size : {16, 32, 64, 128}) { for (auto symmetric : {false, true}) { - RunTest(1, 4096, 4096, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 4096, 11008, block_size, 0, symmetric, true, 0.05f); - RunTest(1, 11008, 4096, block_size, 0, symmetric, true, 0.05f); + RunTest(1, 4096, 4096, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 4096, 11008, block_size, 0, symmetric, true, false, true, 0.05f); + RunTest(1, 11008, 4096, block_size, 0, symmetric, true, false, true, 0.05f); } } } diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index c1bbb49f10c7e..b30282f2ab41f 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -358,6 +358,7 @@ def check_model_correctness( model_onnx = onnx.load(f) ops_set = set(node.op_type for node in model_onnx.graph.node) check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) + check_target_evaluator = False with open(model_path_to_check, "rb") as f: model_check = onnx.load(f) @@ -413,7 +414,7 @@ def check_model_correctness( check_sign_f8_quantization(model_path_origin, model_path_to_check) # Verifies the expected outputs. 
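The non-4-bit zero-point branch added to RunTest above expands the packed uint8 zero points into one float value per block before feeding them as a typed input. A simplified NumPy equivalent of that expansion (the test itself builds and then trims a flat list; the per-column reshape here is an illustrative shortcut):

```python
import numpy as np

def unpack_zero_points(zp_packed, n_cols, k_blocks):
    """Two 4-bit zero points per byte, (k_blocks + 1) // 2 bytes per column; the high
    nibble of the last byte is padding when k_blocks is odd."""
    per_col = (k_blocks + 1) // 2
    zp = np.asarray(zp_packed, dtype=np.uint8).reshape(n_cols, per_col)
    pairs = np.stack([zp & 0x0F, zp >> 4], axis=-1).reshape(n_cols, -1)
    return pairs[:, :k_blocks].astype(np.float32)
```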
- if check_reference_evaluator and onnx_recent_enough: + if check_target_evaluator and onnx_recent_enough: if op_matmul: reference_new_ops = [QLinearMatMul] else: diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py index 73dae08af8ece..88e5052db4e2e 100644 --- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py +++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py @@ -125,7 +125,10 @@ def quant_test( from onnxruntime.quantization import matmul_4bits_quantizer model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) - quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric) + quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig( + block_size=block_size, is_symmetric=is_symmetric + ) + quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config) quant.process() quant.model.save_model_to_file(model_int4_path, False) @@ -165,6 +168,9 @@ def quant_test_with_algo( elif algorithm == "GPTQ": # test GPTQ algorithm algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader) + elif algorithm == "HQQ": + # test HQQ algorithm + algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size) model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config) @@ -227,6 +233,17 @@ def test_quantize_matmul_int4_using_gptq_algo(self): data_reader = self.input_feeds(1, {"input": [100, 52]}) self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False) + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits" + ) + def test_quantize_matmul_int4_using_hqq_algo(self): + if not find_spec("torch"): + self.skipTest("skip test_hqq_quant since torch is not installed") + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, symmetric=False) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False) + if __name__ == "__main__": unittest.main() From cd56ea4a74ee41c040899d702667d2c86bee4ef0 Mon Sep 17 00:00:00 2001 From: guyang3532 <62738430+guyang3532@users.noreply.github.com> Date: Tue, 5 Mar 2024 13:15:30 +0800 Subject: [PATCH 193/207] enable embedding sparse optimization by default (#19714) --- docs/ORTModule_Training_Guidelines.md | 2 +- .../training/ortmodule/_graph_execution_manager.py | 14 +++++++++----- .../python/training/ortmodule/options.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index f50b18b736936..84631bd1f6555 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -246,7 +246,7 @@ to standard outputs. #### ORTMODULE_ENABLE_EMBEDDING_SPARSE_OPTIMIZER - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, this is disabled. This env var can be used for enabling or disabling the embedding input +- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the embedding input data sparsity based performance optimizations. 
```bash diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index fda6e345da235..e189ffff9cc7f 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -681,11 +681,15 @@ def _enable_conditional_optimizations( ) if self._runtime_options.enable_embedding_sparse_optimizer and len(embed_sparsity_results) > 0: - graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) - self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) - self._runtime_options.embed_sparsity_ratio = ",".join( - [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] - ) + if detected_device.type == "cuda": + # Embedding sparsity optimization is only supported on CUDA devices. + graph_transformer_config.sparse_embedding_input_names = list(embed_sparsity_results.keys()) + self._logger.info("Embedding sparsity-based optimization is ON for %s", embed_sparsity_results) + self._runtime_options.embed_sparsity_ratio = ",".join( + [f"{k}:{v:.0f}%" for k, v in embed_sparsity_results.items()] + ) + else: + self._logger.info("Embedding sparsity-based optimization is not supported on non-CUDA devices.") # If users don't want to print input density, disable the input density observer to avoid overhead # when looping through inputs during training. diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 539859a0d58a6..93d24a34df6bd 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -271,7 +271,7 @@ def __init__(self, logger: Logger): self.enable_sparse_optimizer = True self.label_sparsity_ratio = "" self.embed_sparsity_ratio = "" - self.enable_embedding_sparse_optimizer = False # TODO(pengwa): remove once validation on more models are done. + self.enable_embedding_sparse_optimizer = True # Configuration for memory optimization. self.memory_optimization_level = ( From bdf678df93cb257e311de3fa82fe6409be2854ff Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Tue, 5 Mar 2024 17:09:42 +0100 Subject: [PATCH 194/207] Fix CUDA BatchNorm bugs and add support for NHWC (#19742) ### Description - Fix incorrect running_mean / running_var in training mode due to incorrect momentum and missing input mean/var. runnig_var could be correct, but has a too high epsilon. 
- Fix incorrect checks when using NHWC - Pass NHWC flag to NormalizeDims to get correct new dimensions from x_shape - Register missing double operations to get parity between NHWC/NCHW --- .../core/providers/cpu/nn/batch_norm_helper.h | 41 +++++++++++++------ .../providers/cuda/cuda_execution_provider.cc | 18 +++++--- .../core/providers/cuda/cuda_nhwc_kernels.cc | 16 ++++++++ .../core/providers/cuda/nn/batch_norm.cc | 11 ++++- .../providers/cpu/nn/batch_norm_op_test.cc | 1 + 5 files changed, 66 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h index a5d46aff83b50..ccecbabfa3db3 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm_helper.h @@ -25,6 +25,8 @@ class BatchNormHelper { const Tensor* var, bool is_spatial = true, bool is_nhwc = false) { + // NHWC dependent shape: X + // All other shapes are assumed to be in NCHW layout? const auto& x_dims = X->Shape().GetDims(); // If x_dims size < 2, num_channels defaults to 1. @@ -48,16 +50,22 @@ class BatchNormHelper { // validate 'scales' shape const auto& scale_dims = scale->Shape().GetDims(); if (static_cast(scale_dims.size()) != kNumInputScaleDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input scale: NumDimensions() != ", kNumInputScaleDimensions); } if (scale_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: 0th dimension != ", num_channels); } + // N & C do not belong to features + // skip the first element for NHWC and the first two elements for NCHW. + int feature_offset = is_nhwc ? 
1 : 2; + // in non-spatial cases - the other dims of 'scale' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (scale_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (scale_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input scale: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -65,7 +73,8 @@ class BatchNormHelper { // validate 'B' shape const auto& B_dims = B->Shape().GetDims(); if (static_cast(B_dims.size()) != kNumInputBiasDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input B: NumDimensions() != ", kNumInputBiasDimensions); } if (B_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: 0th dimension != ", num_channels); @@ -73,8 +82,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'B' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (B_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (B_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input B: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -82,16 +92,19 @@ class BatchNormHelper { // validate 'mean' shape const auto& mean_dims = mean->Shape().GetDims(); if (static_cast(mean_dims.size()) != kNumInputMeanDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: NumDimensions() != ", kNumInputMeanDimensions); } if (mean_dims[0] != num_channels) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: 0th dimension != ", num_channels); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input mean: 0th dimension != ", num_channels); } // in non-spatial cases - the other dims of 'mean' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (mean_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (mean_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input mean: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } @@ -99,7 +112,8 @@ class BatchNormHelper { // validate 'var' shape const auto& var_dims = var->Shape().GetDims(); if (static_cast(var_dims.size()) != kNumInputVarianceDimensions) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Invalid input var: NumDimensions() != ", kNumInputVarianceDimensions); } if (var_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: 0th 
dimension != ", num_channels); @@ -107,8 +121,9 @@ class BatchNormHelper { // in non-spatial cases - the other dims of 'var' must be validated if (!is_spatial) { for (int feature = 0; feature < num_feature_dims; ++feature) { - if (var_dims[1 + feature] != x_dims[2 + feature]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), " dimension != ", x_dims[2 + feature]); + if (var_dims[1 + feature] != x_dims[feature_offset + feature]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid input var: ", (1 + feature), + " dimension != ", x_dims[feature_offset + feature]); } } } diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 1ce089fd93044..8ba282031a5d4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1202,9 +1202,12 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, LSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, LSTM); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, double, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME( + kCudaExecutionProvider, kOnnxDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, ReduceMin); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, ReduceMin); @@ -2107,9 +2110,12 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index f416caecd115f..64edc319e15ac 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -18,10 +18,14 @@ namespace onnxruntime::cuda { class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, 
BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 9, 13, MLFloat16, BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, @@ -72,10 +76,14 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 12, MLFloat16, MaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, float, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, double, + BatchNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 14, 14, MLFloat16, BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, float, BatchNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, double, + BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization); @@ -86,18 +94,26 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 7, 8, MLFloat16, BatchNormalization)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo::ComputeInternal(OpKernelContext* p_op_kernel_context) CudnnTensor data_desc; vector new_dims; - BatchNormHelper::NormalizeDims(x_shape, new_dims); + BatchNormHelper::NormalizeDims(x_shape, new_dims, NHWC); ORT_RETURN_IF_ERROR(data_desc.Set(new_dims, CudnnTensor::GetDataType(), NHWC)); // For half data type, the alpha, beta, scale, B, mean, var need to be float type @@ -137,6 +137,12 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) auto saved_mean_data = reinterpret_cast(saved_mean->MutableData()); auto saved_inv_var_data = reinterpret_cast(saved_var->MutableData()); + auto stream = static_cast(p_op_kernel_context->GetComputeStream()->GetHandle()); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_mean_data, mean_data, mean->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDA_RETURN_IF_ERROR( + cudaMemcpyAsync(running_var_data, var_data, var->SizeInBytes(), cudaMemcpyDeviceToDevice, stream)); + CUDNN_RETURN_IF_ERROR(BatchNormalizationForwardTrainingHelper( GetCudnnHandle(p_op_kernel_context), cudnn_batch_norm_mode_, @@ -149,7 +155,7 @@ Status BatchNorm::ComputeInternal(OpKernelContext* p_op_kernel_context) bn_tensor_desc, scale_data, b_data, - momentum_, + 1.0 - momentum_, running_mean_data, running_var_data, epsilon_, @@ -186,6 +192,7 @@ SPECIALIZED_COMPUTE(MLFloat16, kOnnxDomain, false) #ifdef ENABLE_CUDA_NHWC_OPS SPECIALIZED_COMPUTE(float, kMSInternalNHWCDomain, true) +SPECIALIZED_COMPUTE(double, kMSInternalNHWCDomain, true) 
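The `1.0 - momentum_` change above accounts for the two conventions involved: cuDNN's exponential average factor weights the new mini-batch statistics, whereas the ONNX `momentum` attribute weights the previous running statistics (which is also why the input mean/var are now copied into the running buffers before the cuDNN call). A NumPy sketch of the intended update, assuming the ONNX convention:

```python
import numpy as np

def update_running_stats(input_mean, input_var, batch_mean, batch_var, momentum=0.9):
    """ONNX BatchNormalization training-mode update: previous running statistics are
    weighted by momentum, current mini-batch statistics by (1 - momentum)."""
    running_mean = momentum * input_mean + (1.0 - momentum) * batch_mean
    running_var = momentum * input_var + (1.0 - momentum) * batch_var
    return running_mean, running_var

# Toy NCHW input: statistics are computed per channel over N, H, W.
x = np.random.randn(8, 3, 4, 4).astype(np.float32)
rm, rv = update_running_stats(np.zeros(3), np.ones(3),
                              x.mean(axis=(0, 2, 3)), x.var(axis=(0, 2, 3)))
```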
SPECIALIZED_COMPUTE(MLFloat16, kMSInternalNHWCDomain, true) #endif } // namespace cuda diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index cbb4531a50b7c..54e5c71bd753a 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -916,6 +916,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // exclude CUDA Execution Provider due to flakiness // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", + // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1 {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } From 06e684c9f2f8495de5259967cc12bab24da3d522 Mon Sep 17 00:00:00 2001 From: Chen Fu <1316708+chenfucn@users.noreply.github.com> Date: Tue, 5 Mar 2024 09:37:45 -0800 Subject: [PATCH 195/207] Adding cuda kernel (optimized for sm80) for block-wise 4b quantized float 16 GEMM. (#18619) ### Description Adding a CUDA kernel for block-wise 4b quantized float16 GEMM, specially optimized for Nvidia Ampere GPUs. ### Motivation and Context Trying to improve quantized LLM inference performance on Nvidia Ampere GPUs. ### Note: This is implemented by extending CUTLASS, so it has a hard dependency on CUTLASS. However, in the current build system, loading of the CUTLASS dependency is guarded with: (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) If both of these options are turned off, then compilation will fail. Why is the CUTLASS dependency guarded at all? It's a header-only library that does not introduce any binary if not instantiated. What's the downside of removing all the guards and just including CUTLASS unconditionally?
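What the new kernel computes can be written as a NumPy reference with a simplified column-wise blocking layout. Shapes and array names here are illustrative only; the real kernel works on packed 4-bit data, accumulates in fp32 as described below, and supports other blocking shapes:

```python
import numpy as np

def blkq4_gemm_reference(a_fp16, b_q4, b_scales, b_offsets, block_size):
    """a_fp16: [M, K] fp16; b_q4: [K, N] values in 0..15;
    b_scales / b_offsets: [K // block_size, N] (one entry per block along K)."""
    k, n = b_q4.shape
    b_dequant = np.empty((k, n), dtype=np.float32)
    for blk in range(k // block_size):
        rows = slice(blk * block_size, (blk + 1) * block_size)
        # 8 is the implicit 4-bit offset used elsewhere in this patch series when none is given.
        offset = b_offsets[blk] if b_offsets is not None else 8
        b_dequant[rows] = (b_q4[rows].astype(np.float32) - offset) * b_scales[blk]
    # Accumulate in fp32, then cast back to fp16 like the kernel's epilogue.
    return (a_fp16.astype(np.float32) @ b_dequant).astype(np.float16)
```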
--- .lintrunner.toml | 1 + cmake/CMakeLists.txt | 5 +- cmake/onnxruntime_providers_cuda.cmake | 2 +- cmake/onnxruntime_unittests.cmake | 1 + onnxruntime/core/mickey/README.md | 4 + .../core/mickey/blk_q4/f16_gemm_sm80.h | 208 +++ .../{prepack_sm80.h => f16_prepack_sm80.h} | 2 +- .../cutlass_ext/q4gemm/device/quantb_gemm.h | 481 ++++++ .../q4gemm/kernel/default_quantb_gemm.h | 255 ++++ .../cutlass_ext/q4gemm/kernel/quantb_gemm.h | 462 ++++++ .../q4gemm/threadblock/default_quantb_mma.h | 248 ++++ .../threadblock/default_quantb_mma_core.h | 340 +++++ .../optional_predicated_tile_access_iter.h | 314 ++++ .../optional_regular_tile_access_iter.h | 224 +++ .../threadblock/quantb_mma_multistage.h | 1290 +++++++++++++++++ .../warp/default_quantb_mma_tensor_op.h | 112 ++ .../quantb_meta_mma_tensor_op_tile_iterator.h | 883 +++++++++++ .../q4gemm/warp/quantb_mma_tensor_op.h | 361 +++++ onnxruntime/core/util/matrix_layout.h | 1 - .../test/cuda_host/blkq4_fp16_quant_sm80.h | 203 +++ .../cuda/test_cases/blkq4_fp16_gemm_sm80.h | 188 +++ .../test_cases/blkq4_fp16_gemm_sm80_test.cc | 330 +++++ .../test_cases/blkq4_fp16_gemm_sm80_testcu.cu | 344 +++++ .../blkq4_fp16_sm80_prepack_test.cc | 507 ------- .../cuda_execution_provider_test.cc | 4 +- 25 files changed, 6257 insertions(+), 513 deletions(-) create mode 100644 onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h rename onnxruntime/core/mickey/blk_q4/{prepack_sm80.h => f16_prepack_sm80.h} (99%) create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h create mode 100644 onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h create mode 100644 onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc create mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu delete mode 100644 onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc diff --git a/.lintrunner.toml b/.lintrunner.toml index 4e5d077b08ff4..be95e03479cf9 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -132,6 +132,7 @@ exclude_patterns = [ 'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code 'onnxruntime/core/graph/contrib_ops/quantization_defs.cc', 'onnxruntime/core/mlas/**', # Contains assembly code + 'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS lib recommends NO automatic code formatting 'winml/lib/Api.Image/shaders/**', # Contains data chunks ] command = [ diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 
8453da19ce3a6..0d55d4cab9826 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -727,6 +727,9 @@ if (onnxruntime_USE_CUDA) set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() + if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4) + message( FATAL_ERROR "Failed to build due to CUDA compiler version < 11.4") + endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) @@ -747,8 +750,8 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) endif() - endif() + + if (onnxruntime_USE_VITISAI) list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_VITISAI=1) diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 0f6d48bdb6ec8..7f295a59a0931 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -201,7 +201,7 @@ endif() include(cutlass) - target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) + target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include) target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CUDAToolkit_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 88f662075e177..b004054c616a5 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -774,6 +774,7 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) onnxruntime_add_shared_library_module(onnxruntime_providers_cuda_ut ${onnxruntime_test_providers_cuda_ut_src} $) config_cuda_provider_shared_module(onnxruntime_providers_cuda_ut) onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock) + target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey) target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut) endif() diff --git a/onnxruntime/core/mickey/README.md b/onnxruntime/core/mickey/README.md index 7e8d30cd1805b..735ec4b80daf3 100644 --- a/onnxruntime/core/mickey/README.md +++ b/onnxruntime/core/mickey/README.md @@ -4,3 +4,7 @@ Playful name for a template library of high performance cuda code that are often shared by various AI operators. The intention is to make this header files only, with no binary impact unless it is instantiated where it is needed. + +Currently CUDA code is scattered in multiple locations in the repo. +Hopefully this can be the starting point for consolidating all CUDA +code. diff --git a/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h new file mode 100644 index 0000000000000..52bff7e40dbe3 --- /dev/null +++ b/onnxruntime/core/mickey/blk_q4/f16_gemm_sm80.h @@ -0,0 +1,208 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blk_q4/f16_gemm_sm80.h + * + * Abstract: + * Entry point for Q4F16 GEMM kernel for SM80 devices.
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/device/quantb_gemm.h" + +namespace onnxruntime { +namespace cuda { + +// +// This is the implementation of the quantized GEMM kernel for 16b float x blocked quantized 4b data type +// +template < + typename ElementDequant_, // <- data type of dequantized elements for gemm, fp16 or bf16 + typename QuantBlocking_, // <- weights block per scale, cutlass::MatrixShape + bool SmallM, // <- true if M <= 16 + bool kHasQuantOffset> +struct BlkQ4F16GemmImpl { + // + // Type definitions + // + + using ElementDequant = ElementDequant_; + using QuantBlocking = QuantBlocking_; + + static_assert(sizeof(ElementDequant) == 2, "q4f16gemm kernel only supports 16b operands!"); + + // Data types that are fixed for this kernel + using ElementAccumulator = float; + using ElementComputeEpilogue = ElementAccumulator; + using ElementInputA = ElementDequant; + using ElementOutput = ElementDequant; + + using ElementW = uint8_t; // <- Weight is int4, uint8 for two of them + + // We pack 4 weights into one 16b element, so as to leverage cutlass tile iterators + // for async shared memory loading and minimize bank conflicts + using ElementWPack = ElementDequant; + + using ElementQScale = ElementDequant; // <- data type of quantization scale + using ElementQOffset = uint8_t; + + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputWPack = cutlass::layout::ColumnMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // Layout of quantization scale and offset, oriented to be loaded using fewer instructions + // in a warp tile + using LayoutInputQScale = + typename std::conditional::type; // <- layout of quantization scale + + using ShapeMMAThreadBlock = + typename std::conditional, + cutlass::gemm::GemmShape<128, 256, 64>>::type; + + static constexpr int MinN = QuantBlocking::kColumn > 32 ? QuantBlocking::kColumn : 32; + using ShapeMMAWarp = + typename std::conditional, + cutlass::gemm::GemmShape<64, 64, 64>>::type; + + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // This code section describes the epilogue part of the kernel + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements.
This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + + // Number of pipelines you want to use + static constexpr int NumStages = 3; + + using Gemm = cutlass::gemm::device::QuantBGemm< + ElementInputA, + LayoutInputA, + ElementWPack, + LayoutInputWPack, + ElementQScale, + typename std::conditional::type, + LayoutInputQScale, + QuantBlocking, + ElementOutput, + LayoutOutput, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EpilogueOp, + SwizzleThreadBlock, + NumStages>; + + using Arguments = typename Gemm::Arguments; + + // Invoke gemm kernel (the version with quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_Qoffset_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (!kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_Qoffset_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } + + // Invoke gemm kernel (the version without quantization offset) + static cutlass::Status run( + cudaStream_t stream, + const cutlass::gemm::GemmCoord& problem_size_, + cutlass::TensorRef ref_A_, + cutlass::TensorRef ref_B_, + cutlass::TensorRef ref_Qscale_, + cutlass::TensorRef ref_C_, + cutlass::TensorRef ref_D_, + typename EpilogueOp::Params epilogue_ = typename EpilogueOp::Params()) { + if constexpr (kHasQuantOffset) { + return cutlass::Status::kErrorNotSupported; + } else { + if constexpr (ShapeMMAThreadBlock::kM == 16) { + if (problem_size_.m() > 16) { + // For M > 16, the caller should have picked the + // kernel with bigger M + return cutlass::Status::kErrorNotSupported; + } + } + + // Construct Gemm arguments + Arguments args{ + problem_size_, + ref_A_, + ref_B_, + ref_Qscale_, + ref_C_, + ref_D_, + epilogue_}; + + Gemm gemm_op; + + // Check if this GEMM can be run or not + cutlass::Status status = gemm_op.can_implement(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Launch the CUTLASS GEMM kernel. + return gemm_op(args, nullptr, stream); + } + } +}; + +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h similarity index 99% rename from onnxruntime/core/mickey/blk_q4/prepack_sm80.h rename to onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h index e291ab39e8aa3..a08cfb97eed4a 100644 --- a/onnxruntime/core/mickey/blk_q4/prepack_sm80.h +++ b/onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h @@ -3,7 +3,7 @@ * Licensed under the MIT License. 
* * Module Name: - * prepack_sm80.h + * blk_q4/f16_prepack_sm80.h * * Abstract: * Prepack weights and quantization parameters (scales and offsets) for diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h new file mode 100644 index 0000000000000..38795291b0328 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/device/quantb_gemm.h @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/device/gemm.h, boilerplate code passing input pointers to the kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/gemm/kernel/gemm.h" + +#include "cutlass_ext/q4gemm/kernel/default_quantb_gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" + +#include "cutlass/layout/permute.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/*! A specialized GEMM operator for quantized B GEMM. + + It is modified from cutlass::gemm::device::Gemm. Both this class and the original Gemm class + are pretty much boilerplate code that construct the Gemm kernel class, and pass parameters + and controls to it. 
The only difference is that this class has a few more template parameters + to support quantization. + + This implementation pretty much follows the design of cutlass. But this class seems to be + just a wrapper of the Gemm kernel class. Consider combining them in future iterations. + +*/ +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm80, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = + typename threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute> +class QuantBGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using 
ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp = EpilogueOutputOp_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp::kCount; + static bool const kSplitKSerial = SplitKSerial; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + // Quantization Parameters + static_assert(std::is_same::value, + "LayoutB, i.e. packed weights must appear ColumnMajor."); + static_assert(InstructionShape::kK == 16, + "InstructionShape::kK must be a multiple of 16 (2 tiles), required by 4b weight packing layout."); + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + static constexpr bool kHasQOffset = !(std::is_same::value); + + // TODO(chenfucn): consider moving to uint4_t or smaller for QOffset + static_assert(!kHasQOffset || std::is_same::value, "QOffset must be uint8_t"); + + /// Define the kernel + using GemmKernel = typename kernel::DefaultQuantBGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementQScale, + ElementQOffset, + LayoutQMeta, + QuantBlocking, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kSplitKSerial, + Operator, + GatherA, + GatherB, + ScatterD, + PermuteDLayout + >::GemmKernel; + + /// Argument structure + struct Arguments { + // + // Data members + // + + GemmCoord problem_size; + TensorRef ref_A; + TensorRef ref_B; + TensorRef ref_C; + TensorRef ref_D; + TensorRef ref_Qscale; + TensorRef ref_Qoffset; + + typename EpilogueOutputOp::Params epilogue; + + // split-K parallelism (etc.) are not yet supported, keeping this for future extension + int split_k_slices{1}; + // For gather+scatter operations + int const *gather_A_indices{nullptr}; + int const *gather_B_indices{nullptr}; + int const *scatter_D_indices{nullptr}; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0) {} + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(!kHasQOffset); + } + + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord problem_size_, + TensorRef ref_A_, + TensorRef ref_B_, + TensorRef ref_Qscale_, + TensorRef ref_Qoffset_, + TensorRef ref_C_, + TensorRef ref_D_, + typename EpilogueOutputOp::Params epilogue_ = + typename EpilogueOutputOp::Params()): + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_Qscale(ref_Qscale_), + ref_Qoffset(ref_Qoffset_), + ref_C(ref_C_), + ref_D(ref_D_), + epilogue(epilogue_) { + assert(kHasQOffset); + } + }; + + private: + /// Kernel parameters object + typename GemmKernel::Params params_; + + public: + /// Constructs the GEMM. 
+ QuantBGemm() { } + + /// Determines whether the GEMM can execute the given problem. + static Status can_implement(Arguments const &args) { + + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + + Status status = GemmKernel::can_implement( + args.problem_size, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial && args.split_k_slices > 1) { + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename GemmKernel::Params{ + args.problem_size, + grid_shape, + args.ref_A.non_const_ref(), + args.ref_B.non_const_ref(), + args.ref_Qscale.non_const_ref(), + args.ref_Qoffset.non_const_ref(), + args.ref_C.non_const_ref(), + args.ref_D, + args.epilogue, + static_cast(workspace), + args.gather_A_indices, + args.gather_B_indices, + args.scatter_D_indices + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A.reset(args.ref_A.non_const_ref().data()); + params_.ref_B.reset(args.ref_B.non_const_ref().data()); + params_.ref_Qscale.reset(args.ref_Qscale.non_const_ref().data()); + params_.ref_Qoffset.reset(args.ref_Qoffset.non_const_ref().data()); + params_.ref_C.reset(args.ref_C.non_const_ref().data()); + params_.ref_D.reset(args.ref_D.data()); + params_.output_op = args.epilogue; + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. 
+ Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(GemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + std::cerr << "Failed to obtain maximum shared memory size " << smem_size << " for kernel: " + << cudaGetErrorString(result) << "\n"; + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h new file mode 100644 index 0000000000000..2f4460bb59e9f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/default_quantb_gemm.h @@ -0,0 +1,255 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/default_gemm.h. templates for combining + * threadblock-scoped matrix multiply-add with the appropriate + * threadblock-scoped epilogue. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/kernel/quantb_gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "cutlass/layout/permute.h" + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) +#include "cutlass/epilogue/threadblock/default_epilogue_wmma_tensor_op.h" +#endif //CUTLASS_ARCH_WMMA_ENABLED + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout type for quant scales and offsets + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Access granularity of quant scales in units of elements + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + 
bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Scatter result D by using an index array + bool ScatterD = false, + /// Permute result D + typename PermuteDLayout = layout::NoPermute, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute, + /// + typename Enable = void +> +struct DefaultQuantBGemm; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout type for quant scales + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Access granularity of quant scales in units of elements + typename ElementC, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// If true, kernel is configured to support serial reduction in the + /// epilogue + bool SplitKSerial, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Scatter result D by using an index array + bool ScatterD, + /// Permute result D + typename PermuteDLayout, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout +> +struct DefaultQuantBGemm { + + static_assert((platform::is_same::value + || platform::is_same>::value), + "Epilogue in the kernel level must be row major"); + + /// Define the threadblock-scoped matrix multiply-accumulate + using Mma = typename cutlass::gemm::threadblock::DefaultQuantBMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape, WarpShape, InstructionShape, Stages, + Operator, false, GatherA, GatherB, + PermuteALayout, PermuteBLayout>::ThreadblockMma; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using RegularEpilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount, ScatterD, PermuteDLayout>::Epilogue; + + using Affine2Epilogue = + typename 
cutlass::epilogue::threadblock::DefaultEpilogueTensorOpAffineRankN< + 2, ThreadblockShape, typename Mma::Operator, kPartitionsK, EpilogueOutputOp, + EpilogueOutputOp::kCount>::Epilogue; + + using Epilogue = typename platform::conditional::value, + RegularEpilogue, + Affine2Epilogue>::type; + + /// Define the kernel-level GEMM operator. + using GemmKernel = kernel::QuantBGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h new file mode 100644 index 0000000000000..6e5ad8f406147 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/kernel/quantb_gemm.h @@ -0,0 +1,462 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_gemm.h + * @brief Modified from cutlass/gemm/kernel/gemm.h. + * Template for a pipelined GEMM kernel. Does not compute batching or support split-K. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! 
Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. +> +struct QuantBGemm { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using OutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static bool const kSplitKSerial = SplitKSerial; + + static constexpr bool kHasQOffset = Mma::kHasQOffset; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorA::TensorRef ref_A; + typename Mma::IteratorB::Params params_B; + typename Mma::IteratorB::TensorRef ref_B; + typename Mma::IteratorQScale::Params params_QScale; + typename Mma::IteratorQScale::TensorRef ref_QScale; + typename Mma::IteratorQOffset::Params params_QOffset; + typename Mma::IteratorQOffset::TensorRef ref_QOffset; + typename Epilogue::OutputTileIterator::Params params_C; + typename Epilogue::OutputTileIterator::TensorRef ref_C; + typename Epilogue::OutputTileIterator::Params params_D; + typename Epilogue::OutputTileIterator::TensorRef ref_D; + typename OutputOp::Params output_op; + int *semaphore; + int gemm_k_size; // how many k vectors are processed by this threadblock + // For gather+scatter operations + int const *gather_A_indices; + int const *gather_B_indices; + int const *scatter_D_indices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D, + typename OutputOp::Params output_op = typename OutputOp::Params(), + int *workspace = nullptr, + int const *gather_A_indices = nullptr, + int const *gather_B_indices = nullptr, + int const *scatter_D_indices = nullptr + ): + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + params_A(ref_A.layout()), + ref_A(ref_A), + params_B(ref_B.layout()), + ref_B(ref_B), + params_QScale(ref_QScale.layout()), + ref_QScale(ref_QScale), + params_QOffset(ref_QOffset.layout()), + ref_QOffset(ref_QOffset), + params_C(ref_C.layout()), + ref_C(ref_C), + params_D(ref_D.layout()), + ref_D(ref_D), + output_op(output_op), + gather_A_indices(gather_A_indices), + gather_B_indices(gather_B_indices), + scatter_D_indices(scatter_D_indices) { + int total_gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + + gemm_k_size = gemm_k_iterations * Mma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // 
+ + CUTLASS_HOST_DEVICE + QuantBGemm() { } + + /// Determines whether kernel satisfies alignment + CUTLASS_HOST_DEVICE + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::IteratorQScale::TensorRef ref_QScale, + typename Mma::IteratorQOffset::TensorRef ref_QOffset, + typename Epilogue::OutputTileIterator::TensorRef ref_C, + typename Epilogue::OutputTileIterator::TensorRef ref_D) { + + // TODO check problem_size K, N must be multiple of QuantBlocking + + static int const kAlignmentA = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = (platform::is_same>::value) + ? 32 + : (platform::is_same>::value) + ? 64 + : Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (problem_size.k() % Mma::Shape::kK != 0) { + // Currently we don't support this case due to the way + // predicate iterator works, it loads the partial tile + // in the first iteration and then the full tile in the + // remaining iterations. This will cause the blockwise + // quantization parameters to go out of step with the + // weights. We can fix this by adding a predicate iterator + // that loads the full tile in the first iterations and + // then the partial tile in the last iteration. + return Status::kErrorInvalidProblem; + } + + int qscale_k = problem_size.k() / Mma::QuantBlocking::kRow; + int qscale_n = problem_size.n() / Mma::QuantBlocking::kColumn; + if ((qscale_k == 0) || (qscale_k * Mma::QuantBlocking::kRow != problem_size.k())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + if ((qscale_n == 0) || (qscale_n * Mma::QuantBlocking::kColumn != problem_size.n())) { + // partial block not supported + return Status::kErrorInvalidProblem; + } + + if (!TensorRef_aligned(ref_QScale, Mma::IteratorQScale::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + + if constexpr(kHasQOffset) { + if (!TensorRef_aligned(ref_QOffset, Mma::IteratorQOffset::AccessType::kElements)) { + return Status::kErrorMisalignedOperand; + } + } + + if (!TensorRef_aligned(ref_C, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.k() * params.gemm_k_size, + }; + + cutlass::MatrixCoord tb_offset_B{ + (threadblock_tile_offset.k() * params.gemm_k_size) / 
2, + (threadblock_tile_offset.n() * Mma::Shape::kN) / 2 + }; + + // Problem size is a function of threadblock index in the K dimension + int problem_size_k = min( + params.problem_size.k(), + (threadblock_tile_offset.k() + 1) * params.gemm_k_size); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + params.ref_A.data(), + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A, + params.gather_A_indices); + + typename Mma::IteratorB iterator_B( + params.params_B, + params.ref_B.data(), + {problem_size_k/2, params.problem_size.n()/2}, + thread_idx, + tb_offset_B, + params.gather_B_indices); + + const int qscale_k = problem_size_k / Mma::QuantBlocking::kRow; + const int qscale_n = params.problem_size.n() / Mma::QuantBlocking::kColumn; + + // should have been verified by can_implement() + assert((qscale_k > 0) && (qscale_k * Mma::QuantBlocking::kRow == problem_size_k)); + assert((qscale_n > 0) && (qscale_n * Mma::QuantBlocking::kColumn == params.problem_size.n())); + + cutlass::MatrixCoord tb_offset_QScale{ + threadblock_tile_offset.k() * (params.gemm_k_size/Mma::QuantBlocking::kRow), + threadblock_tile_offset.n() * (Mma::Shape::kN/Mma::QuantBlocking::kColumn) + }; + + typename Mma::IteratorQScale iterator_QScale( + params.params_QScale, + params.ref_QScale.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale, + nullptr); + + typename Mma::IteratorQOffset iterator_QOffset( + params.params_QOffset, + params.ref_QOffset.data(), + {qscale_k, qscale_n}, + thread_idx, + tb_offset_QScale); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + const int warp_idx = canonical_warp_idx(); + const int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_QScale, iterator_QOffset, accumulators); + } + + // + // Epilogue + // + + OutputOp output_op(params.output_op); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + + // Tile iterator loading from source tensor. 
+ typename Epilogue::OutputTileIterator iterator_C( + params.params_C, + params.ref_C.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D( + params.params_D, + params.ref_D.data(), + params.problem_size.mn(), + thread_idx, + threadblock_offset, + params.scatter_D_indices + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_offset.k()); + + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, iterator_D, accumulators, iterator_C); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h new file mode 100644 index 0000000000000..0af604f090e1f --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma.h @@ -0,0 +1,248 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma.h + * @brief Modified from cutlass/gemm/threadblock/default_mma.h. + * Defining global memory data layout and iterators, combining with mma core and + * pipelined GEMM kernel. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/wmma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h" +#include "cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale_, + /// Element type for quant offsets + typename ElementQOffset_, + /// Layout for quant scales and offsets + typename LayoutQMeta_, + /// Blocking size for quantization + typename QuantBlocking_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved.
+ bool AccumulatorsInRowMajor = false, + /// Gather operand A by using an index array + bool GatherA = false, + /// Gather operand B by using an index array + bool GatherB = false, + /// Permute operand A + typename PermuteALayout = layout::NoPermute, + /// Permute operand B + typename PermuteBLayout = layout::NoPermute + > +struct DefaultQuantBMma; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row-major output (OperatorClass TensorOp) +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for quant scales + typename ElementQScale, + /// Element type for quant offsets + typename ElementQOffset, + /// Layout for quant scales and offsets + typename LayoutQMeta, + /// Blocking size for quantization + typename QuantBlocking, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Gather operand A by using an index array + bool GatherA, + /// Gather operand B by using an index array + bool GatherB, + /// Permute operand A + typename PermuteALayout, + /// Permute operand B + typename PermuteBLayout + > +struct DefaultQuantBMma { + + static_assert(platform::is_same::value + || platform::is_same>::value, + "simt epilogue must be row major"); + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ?
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the MmaCore components + using MmaCore = typename cutlass::gemm::threadblock::DefaultQuantBMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementQScale, ElementQOffset, LayoutQMeta, QuantBlocking, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA, GatherA, PermuteALayout>; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB, GatherB, PermuteBLayout>; + + // Define iterators over tiles from the quant scales + using ThreadMapQScale = typename MmaCore::IteratorThreadMapQScale; + using AccessTypeQScale = + cutlass::Array; + using IteratorQScale = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, + ElementQScale, LayoutQMeta, 0, ThreadMapQScale, AccessTypeQScale>; + + using ThreadMapQOffset = typename MmaCore::IteratorThreadMapQOffset; + using AccessTypeQOffset = + cutlass::Array; + using IteratorQOffset = + cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator< + typename MmaCore::ThreadblockQShape, ElementQOffset, LayoutQMeta, + 0, ThreadMapQOffset, AccessTypeQOffset, MmaCore::kThreads>; + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::QuantBMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + MmaCore::kCacheOpA, IteratorB, typename MmaCore::SmemIteratorB, + MmaCore::kCacheOpB, IteratorQScale, typename MmaCore::SmemIteratorQScale, + cutlass::arch::CacheOperation::Global, IteratorQOffset, + typename MmaCore::SmemIteratorQOffset, cutlass::arch::CacheOperation::Global, + ElementAccumulator, LayoutC, + typename MmaCore::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h new file mode 100644 index 0000000000000..ad322f6505200 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/default_quantb_mma_core.h @@ -0,0 +1,340 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_core.h + * @brief Modified from cutlass/gemm/threadblock/default_mma_core.h. + * Defining data layout in shared memory, and its iterators. + */ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" + +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" +#include "cutlass/layout/tensor_op_multiplicand_sm80.h" + +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core.h" +#include "cutlass/gemm/threadblock/default_multistage_mma_complex_core_sm80.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_tensor_op_sm80.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear.h" +#include "cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template defininng default matrix multiply operators inferred from threadblock tile size, +/// global memory data layout, and target math instruction. 
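+///
+/// Illustrative instantiation (every tile size and element type below is an
+/// example value chosen for this sketch, not something prescribed by this
+/// change): a core for fp16 activations and prepacked 4-bit weights with one
+/// scale per column for every 32 rows could be written as
+///
+///   using Core = cutlass::gemm::threadblock::DefaultQuantBMmaCore<
+///       cutlass::gemm::GemmShape<128, 256, 64>,      // threadblock tile
+///       cutlass::gemm::GemmShape<64, 64, 64>,        // warp tile
+///       cutlass::gemm::GemmShape<16, 8, 16>,         // tensor op instruction
+///       cutlass::half_t, cutlass::layout::RowMajor,  // A
+///       uint8_t, cutlass::layout::ColumnMajor,       // prepacked B
+///       cutlass::half_t,                             // quant scale element
+///       uint8_t,                                     // quant offset element
+///       cutlass::layout::ColumnMajor,                // quant meta layout
+///       cutlass::MatrixShape<32, 1>,                 // quant blocking
+///       float, cutlass::layout::RowMajor,            // accumulator
+///       cutlass::arch::OpClassTensorOp,
+///       3,                                           // stages
+///       cutlass::arch::OpMultiplyAdd>;
+///
+/// With these example values each k-slice of the threadblock carries
+/// (64 / 32) x (256 / 1) = 2 x 256 quantization scales, which is the
+/// ThreadblockQShape defined further below.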
+template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Element data type of quant scale + typename ElementQScale, + /// Element data type of quant offset + typename ElementQOffset, + /// Layout of quant scale + typename LayoutQMeta, + /// Blocking dimensions for quantization + typename QuantBlocking, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DefaultQuantBMmaCore; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: tensor op class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Element data type of quant scale + typename ElementQScale_, + /// Element data type of quant offset + typename ElementQOffset_, + /// Layout of quant scale + typename LayoutQMeta_, + /// Blocking dimensions for quantization + typename QuantBlocking_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator_, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB> +struct DefaultQuantBMmaCore { + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = 
layout::ColumnMajor; + + using ElementQScale = ElementQScale_; + using ElementQOffset = ElementQOffset_; + using LayoutQMeta = LayoutQMeta_; + using QuantBlocking = QuantBlocking_; + + using ElementC = ElementC_; + using LayoutC = LayoutC_; + static int const kStages = Stages; + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + /// Number of warps present + using WarpCount = GemmShape; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."); + + /// Number of threads per warp + static int const kWarpSize = warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + /// Size of a threadblock-scoped access + static int const kAccessSizeInBits = 128; + + /// Default Operator + using Operator = Operator_; + + // Warp thread arrangement + static int const kWarpThreadArrangementContiguousA = + Shape::kK / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedA = + kWarpSize / kWarpThreadArrangementContiguousA; + + static int const kWarpThreadArrangementContiguousB = + (Shape::kK / 2) / (kAccessSizeInBits / sizeof_bits::value); + + static int const kWarpThreadArrangementStridedB = + kWarpSize / kWarpThreadArrangementContiguousB; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK>; + + using SmemLayoutB = layout::ColumnMajorTensorOpMultiplicandCrosswise< + sizeof_bits::value, Shape::kK/2>; + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementA, SmemLayoutA, 0, + IteratorThreadMapA>; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearWarpRakedThreadMap< + layout::PitchLinearShape, kThreads, + layout::PitchLinearShape, + kAccessSizeInBits / sizeof_bits::value>; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIterator< + MatrixShape, ElementB, SmemLayoutB, 1, + IteratorThreadMapB>; + + using SmemLayoutQScale = LayoutQMeta; + using SmemLayoutQOffset = LayoutQMeta; + + /// Threadblock-level quantization meta data shape + using ThreadblockQShape = MatrixShape; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(ThreadblockQShape::kCount > 0, "QuantBlocking too big to fit in a thread block!"); + static_assert(QuantBlocking::kRow == 1 || QuantBlocking::kColumn == 1, + "Only support single column or row quantize blocking!"); + static_assert(QuantBlocking::kColumn != 1 || std::is_same::value, + "Quant scale matrix's major dimension must have more elements, to facilitate fast loading!"); + + /// Threadblock-level quantization meta data shape in pitch-linear layout + using TBQPitchLinearShape = typename std::conditional< + std::is_same::value, + layout::PitchLinearShape, + 
layout::PitchLinearShape>::type; + + /// By default we would like to use 128b load. However, we can't load more than + /// a column at a time in a column major layout. + static int const kElementsPerAccessQScale = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + + /// quant scale is tiny. Not all threads are needed. + static int const kAccessCntQScale = ThreadblockQShape::kCount / kElementsPerAccessQScale; + static int const kThreadsQScale = (kAccessCntQScale > kThreads) ? kThreads : kAccessCntQScale; + + using IteratorThreadMapQScale = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQScale, kElementsPerAccessQScale>; + + using SmemIteratorQScale = transform::threadblock::RegularTileAccessIterator< + ThreadblockQShape, ElementQScale, SmemLayoutQScale, 1, IteratorThreadMapQScale>; + + static int const kElementsPerAccessQOffset = + (kAccessSizeInBits / sizeof_bits::value) > TBQPitchLinearShape::kContiguous + ? TBQPitchLinearShape::kContiguous + : (kAccessSizeInBits / sizeof_bits::value); + static int const kAccessCntQOffset = ThreadblockQShape::kCount / kElementsPerAccessQOffset; + static int const kThreadsQOffset = (kAccessCntQOffset > kThreads) ? kThreads : kAccessCntQOffset; + + using IteratorThreadMapQOffset = transform::PitchLinearStripminedThreadMap< + TBQPitchLinearShape, kThreadsQOffset, kElementsPerAccessQOffset>; + + using SmemIteratorQOffset = transform::threadblock::OptionalRegularTileAccessIterator< + ThreadblockQShape, ElementQOffset, SmemLayoutQOffset, 1, IteratorThreadMapQOffset, kThreads>; + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level tensor op + using MmaTensorOp = typename cutlass::gemm::warp::DefaultQuantBMmaTensorOp< + WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB, + ElementQScale, SmemLayoutQScale, ElementQOffset, SmemLayoutQScale, QuantBlocking, + ElementC, LayoutC, Operator, WarpCount::kK>::Type; + + /// Policy used to define MmaPipelined + using MmaPolicy = MmaPolicy, + MatrixShape<0, 0>, WarpCount::kK>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h new file mode 100644 index 0000000000000..6f27a692a3a2e --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_predicated_tile_access_iter.h @@ -0,0 +1,314 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_predicated_tile_access_iter.h + * @brief Templates for loading and storing optional tiles of matrix data. + * This iterator is just a wrapper of PredicatedTileAccessIterator, with + * the option to turn it off at compile time and minimize its runtime + * footprint. Also, it utilize the higher numbered threads in the + * threadblock when the iterator can not utilize all the threads. 
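+ *
+ * Illustrative usage (the type names below are placeholders for this sketch,
+ * not identifiers defined by this change): the element type is switched to
+ * std::monostate at compile time to disable the iterator, e.g.
+ *
+ *   using QOffsetElement =
+ *       std::conditional_t<kHasQOffset, uint8_t, std::monostate>;
+ *   using IteratorQOffset =
+ *       cutlass::transform::threadblock::OptionalPredicatedTileAccessIterator<
+ *           ThreadblockQShape, QOffsetElement, LayoutQMeta,
+ *           0, ThreadMapQOffset, AccessTypeQOffset, kThreads>;
+ *
+ * When QOffsetElement is std::monostate the specialization at the bottom of
+ * this file is selected and every member function collapses to an empty body,
+ * so the symmetric (no zero-point) path pays no extra runtime cost.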
+ */ + +#pragma once + +#include + +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D matrix data loader, when element is std::monostate, the +/// iterator becomes no-op with minimal runtime footprint. Also, it utilize the +/// higher numbered threads in the threadblock when the iterator can not utilize +/// all the threads. +/// +template < + /// Tile shape of the iterator + typename Shape_, + /// Element data type of the iterator, no-op when it is std::monostate + typename Element_, + /// Layout of the source matrix + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + /// Number of threads in the threadblock, when provided, the iterator + /// will utilize the higher numbered threads + int kThreadBlockSize_ = -1> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kAdvanceRank = AdvanceRank_; + static constexpr int kThreadblockSize = kThreadBlockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized version below."); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + public: + Base base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_(params, pointer, extent, flip_thread_id(thread_id), threadblock_offset) {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : OptionalPredicatedTileAccessIterator(params, pointer, extent, thread_id, make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + 
base_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) { + base_.add_tile_offset(tile_offset); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + OptionalPredicatedTileAccessIterator self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + base_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + base_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) { + base_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) { + base_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return base_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for the disabled version +/// Reduce runtime overhead +/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Layout_, + int AdvanceRank_, + typename ThreadMap_, + typename AccessType_, + int kThreadBlockSize_> +class OptionalPredicatedTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + static int const kAdvanceRank = AdvanceRank_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + static constexpr int kThreadblockSize = kThreadBlockSize_; + + using Base = PredicatedTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using Mask = typename Base::Mask; + using TensorCoord = typename Base::TensorCoord; + using TensorRef = typename Base::TensorRef; + using Params = typename Base::Params; + using Pointer = typename Base::Pointer; + + static constexpr int kAccessesPerVector = Base::kAccessesPerVector; + + public: + std::monostate base_; + + /// Default constructor + OptionalPredicatedTileAccessIterator(): base_() {}; + + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : base_() {} + + /// Construct a PredicatedTileAccessIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void 
add_pointer_offset(LongIndex pointer_offset) {} + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_DEVICE + void add_tile_offset( + TensorCoord const &tile_offset) {} + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator &operator++() { + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + OptionalPredicatedTileAccessIterator operator++(int) { + return *this; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) {} + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() {} + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const &mask) {} + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask &mask) {} + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { return false; } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h new file mode 100644 index 0000000000000..4b0ae5317f8bb --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/optional_regular_tile_access_iter.h @@ -0,0 +1,224 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file optional_regular_tile_access_iter.h + * @brief Templates implementing the address computation of storing of tiles + * from pitch-linear rank=2 tensors. + * + * This iterator is just a wrapper of RegularTileAccessIterator, with the + * option to turn it off at compile time and minimize its runtime footprint. + * Also, it utilize the higher numbered threads in the threadblock when the + * iterator can not utilize all the threads. + * + * Must be used in conjunction with OptionalPredicatedTileAccessIterator, + * with the same template parameters. + */ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/regular_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Optional 2-D tile iterator, when element is std::monostate, the iterator +/// becomes no-op with minimal runtime footprint. Also, it utilize the higher +/// numbered threads in the threadblock when the iterator can not utilize all +/// the threads. 
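+///
+/// As a concrete illustration (thread counts here are example values): with
+/// kThreadblockSize = 256 and a thread map needing ThreadMap::kThreads = 32
+/// threads, flip_thread_id() maps thread 255 to 0, thread 254 to 1, ..., and
+/// thread 224 to 31, so the 32 highest-numbered threads of the threadblock
+/// service the small quantization-parameter tile while the lower-numbered
+/// threads are left to issue the much larger A/B tile copies.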
+/// +template < + /// Tile shape of the iterator + typename Shape_, + typename Element_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + /// Number of threads in the threadblock, when not -1, the iterator + /// will utilize the higher numbered threads + int ThreadblockSize_ = -1, + int Alignment = + sizeof_bits::value * ThreadMap_::kElementsPerAccess / 8> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + static_assert(!std::is_same::value, + "Disabled Iterator failed to match the specialized template"); + static_assert(kThreadblockSize == -1 || kThreadblockSize >= ThreadMap::kThreads, + "kThreadblockSize must be no smaller than ThreadMap::kThreads"); + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + CUTLASS_HOST_DEVICE + static int flip_thread_id(int thread_id){ + if constexpr (kThreadblockSize > 0) { + return kThreadblockSize - 1 - thread_id; + } + return thread_id; + } + + private: + + Base base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_(ref, flip_thread_id(thread_id)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + base_.set_iteration_index(index); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + base_.add_pointer_offset(pointer_offset); + } + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return base_.get(); + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + ++base_; + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + RegularTileAccessIterator prev(*this); + this->operator++(); + + return prev; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. 
+ CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) { + base_.add_tile_offset(coord); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization when Element is std::monostate, the iterator becomes no-op +/// +template < + typename Shape_, + typename Layout_, + int AdvanceRank, + typename ThreadMap_, + int ThreadblockSize_, + int Alignment> +class OptionalRegularTileAccessIterator{ + public: + + using Shape = Shape_; + using Element = std::monostate; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + static constexpr int kAlignment = Alignment; + static constexpr int kThreadblockSize = ThreadblockSize_; + + using Base = RegularTileAccessIterator; + + using LongIndex = typename Base::LongIndex; + using TensorRef = typename Base::TensorRef; + using TensorCoord = typename Base::TensorCoord; + using AccessType = typename Base::AccessType; + + private: + + std::monostate base_; + + public: + /// Construct a TileIterator with zero threadblock offset + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator(TensorRef ref, ///< Pointer to start of tensor + int thread_id ///< ID of each participating thread + ) + : base_() {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) {} + + /// Returns a pointer + CUTLASS_DEVICE + AccessType *get() const { + return nullptr; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator &operator++() { + return *this; + } + + /// Advances to the next tile in memory. + CUTLASS_HOST_DEVICE + OptionalRegularTileAccessIterator operator++(int) { + return *this; + } + + /// Adds a tile offset in the unit of tile. + /// In GEMM/Conv implementation, this is used to move in the k dimension in the shared memory. + /// Below layouts are the shared memory layouts. Current SM50 SIMT kernels only use col major A and row major B. + /// For row major A operand, k dimension is contiguous dimension; + /// For col major A operand, k dimension is strided dimension; + /// For row major B operand, k dimension is strided dimension; + /// For col major B operand, k dimension is contiguous dimension. + /// Below two classes map col/row major to the pitch linear coordinates used + /// in this base class. + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const &coord) {} +}; + +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h new file mode 100644 index 0000000000000..8b6bac8c5099a --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/threadblock/quantb_mma_multistage.h @@ -0,0 +1,1290 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_multistage.h + * @brief Modified from cutlass/gemm/threadblock/mma_multistage.h. + * Added the quantized data memory pipeline, dequantization, and feeding + * to tensor cores. Mainloop pipeline is heavily modified. + */ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +#include "cutlass/util/debug.h" +#include "cutlass/util/device_dump.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace{ + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Utilities for printing layout for the prepacked weights and quantization parameters +/// +template< + /// Data type of the prepacked weights + typename ElementWeight, + /// Data type of the quant scales + typename ElementQScale, + /// Data type of the quant offsets + typename ElementQOffset> +struct QuantBLayoutDebug{ + static constexpr bool debug_smem = true; + static constexpr bool debug_fragment = true; + ElementWeight* smem_b_ptr_; + ElementQScale* smem_qscale_ptr_; + ElementQOffset* smem_qoffset_ptr_; + int warp_id_; + int lane_id_; + int block_id_; + + template + CUTLASS_DEVICE + static void print_fragment(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + static_assert(Size % 4 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const Element* ptr = reinterpret_cast(&frag); + for (int i = 0; i < Size/4; i++, ptr+=4){ + if constexpr(std::is_integral::value){ + printf("T%.2d%c%d, %3d, %3d, %3d, %3d\n", + threadIdx.x, label, i, + ptr[0], ptr[1], ptr[2], ptr[3]); + } else { + printf("T%.2d%c%d, %.3f, %.3f, %.3f, %.3f\n", + threadIdx.x, label, i, + float(ptr[0]), float(ptr[1]), float(ptr[2]), float(ptr[3])); + } + } + } + } + } + + template + CUTLASS_DEVICE + static void 
print_as_int4(cutlass::Array const& frag, char label, int block_id, int warp_id, int lane_id){ + constexpr int I8Size = Size * cutlass::sizeof_bits::value / 8; + static_assert(I8Size % 2 == 0, "Size must be multiple of 4"); + if constexpr (debug_fragment){ + if (block_id == 1 && warp_id == 0){ + const uint8_t* ptr = reinterpret_cast(&frag); + for (int i = 0; i < I8Size/2; i++, ptr+=2){ + printf("T%.2dW%d, %d, %d, %d, %d\n", threadIdx.x, i, ptr[0] & 0x0f, ptr[0] >> 4, ptr[1] & 0x0f, ptr[1] >> 4); + } + } + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dummy type when quant offset is not used, to avoid compilation error, +/// and reduce runtime footprint +/// +struct DummyType{ + std::monostate dummy_; + public: + DummyType() = default; + + CUTLASS_HOST_DEVICE + void* data() const { + return nullptr; + } + + CUTLASS_HOST_DEVICE + std::monostate& operator[](int idx) { + return dummy_; + } +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
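+  /// (Illustrative numbers only: with a 64x64x64 warp-level GEMM shape and a
+  /// 16x8x16 tensor core MMA shape, kWarpGemmIterations defined below is
+  /// 64 / 16 = 4 warp-level MMA steps per threadblock k-slice.)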
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + static constexpr bool kHasQOffset = !std::is_same::value; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the prepacked weights + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // Tensor reference to the quantization scales + using TensorRefQScale = TensorRef; + using TensorRefQOffset = TensorRef; + + // Block size of the quantization (one set of quantization parameters per block of weights) + using QuantBlocking = typename Operator::QuantBlocking; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the prepacked weights in shared memory + using ShapeB = + MatrixShape; + + /// Shape of the quantization parameter matrix in shared memory + /// Validation done in mma core class ThreadblockQShape + using ShapeQScale = + MatrixShape<(Shape::kK / QuantBlocking::kRow) * kStages, + Shape::kN / QuantBlocking::kColumn>; + + using BufTypeQOffset = std::conditional_t, + DummyType>; + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for prepacked weights + AlignedBuffer operand_B; + + /// Buffer for quantization scales + AlignedBuffer operand_QScale; + + /// Buffer for quantization offsets + BufTypeQOffset operand_QOffset; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQScale LayoutQMeta() { + return Operator::SmemLayoutQScale::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + CUTLASS_HOST_DEVICE + static typename Operator::SmemLayoutQOffset LayoutQOffset() { + return Operator::SmemLayoutQOffset::packed({ShapeQScale::kRow, ShapeQScale::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the prepacked weights + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + + /// Returns a TensorRef to the quantization scales + CUTLASS_HOST_DEVICE + TensorRefQScale operand_QScale_ref() { + return TensorRefQScale{operand_QScale.data(), LayoutQMeta()}; + } + + CUTLASS_HOST_DEVICE + TensorRefQOffset operand_QOffset_ref() { + if constexpr (!kHasQOffset){ + return TensorRefQOffset(); + } else { + return TensorRefQOffset{operand_QOffset.data(), LayoutQOffset()}; + } + } + }; + + protected: + + // + // Data members + // + 
+ /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + /// Iterator to load a warp-scoped tile of quant scales from shared memory + typename Operator::IteratorQMeta warp_tile_iterator_QScale_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx), + warp_tile_iterator_QScale_(shared_storage.operand_QScale_ref(), + shared_storage.operand_QOffset_ref(), lane_idx) + {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterators over tiles of quant scales in global memory + typename IteratorQScale_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQScale_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQScale, + /// Iterators over tiles of quant scales in global memory + typename IteratorQOffset_, + /// Iterators over tiles of quant scales in shared memory + typename SmemIteratorQOffset_, + /// Cache operation for quant scales + cutlass::arch::CacheOperation::Kind CacheOpQOffset, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class QuantBMmaMultistage : + public QuantBMmaBase { +public: + ///< Base class + using Base = QuantBMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using 
SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + using IteratorQScale = IteratorQScale_; + using IteratorQOffset = IteratorQOffset_; + using SmemIteratorQScale = SmemIteratorQScale_; + using SmemIteratorQOffset = SmemIteratorQOffset_; + using QuantBlocking = typename Base::QuantBlocking; + + static cutlass::arch::CacheOperation::Kind const kCacheOpQScale = CacheOpQScale; + static cutlass::arch::CacheOperation::Kind const kCacheOpQOffset = CacheOpQOffset; + static constexpr bool kHasQOffset = Base::kHasQOffset; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of packed weights + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQScale = + IteratorQScale::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant scale + static int const kAccessesPerGroupQScale = + (AsyncCopyIterationsPerStageQScale + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + static int const AsyncCopyIterationsPerStageQOffset = + IteratorQOffset::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of quant offset + static int const kAccessesPerGroupQOffset = + (AsyncCopyIterationsPerStageQOffset + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. 
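+    // In the mainloop below (mac_loop_iter), enabling staged accumulation makes
+    // every warp-level MMA accumulate into tmp_accum_, which is folded into the
+    // final accumulator once per mainloop iteration and then cleared.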
+ static bool const kStagedAccumulation = arch::UseStagedAccumulation::value; + }; + + private: + + + // Structure encapsulating pipeline state live from one iteration to the next + struct PipeState { + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + /// Temporary accumulator to facilitate staged-accumulation + FragmentC tmp_accum_; + + /// Pair of A fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentA warp_loaded_frag_A_[2]; + + /// Pair of B fragments used to overlap shared memory loads and math instructions + WarpLoadedFragmentB warp_loaded_frag_B_; + WarpTransformedFragmentB warp_transformed_frag_B_[2]; + + using WarpLoadedFragmentQScale = typename Operator::FragmentQScale; + WarpLoadedFragmentQScale warp_loaded_frag_QScale_; + + using WarpLoadedFragmentQOffset = typename std::conditional::type; + WarpLoadedFragmentQOffset warp_loaded_frag_QOffset_; + }; + + + private: + + // + // Data members + // + + /// Warp-level MMA operator + Operator warp_mma_; + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to write threadblock-scoped tile of quant meta data to shared memory + SmemIteratorQScale smem_iterator_QScale_; + SmemIteratorQOffset smem_iterator_QOffset_; + + /// Shared memory write stage index + int smem_write_stage_idx_; + + /// Shared memory read stage index + int smem_read_stage_idx_; + + /// very small meta data tensor require less threads to load + bool const should_load_qscale_; + bool const should_load_qoffset_; + + /// Shared memory pointers for debug dumping + static constexpr bool debug_layout = false; + using LayoutDebugType = typename std::conditional, + std::monostate>::type; + LayoutDebugType layout_debug_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + QuantBMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx), + smem_iterator_QScale_(shared_storage.operand_QScale_ref(), thread_idx), + smem_iterator_QOffset_(shared_storage.operand_QOffset_ref(), thread_idx), + should_load_qscale_(thread_idx < IteratorQScale::ThreadMap::kThreads), + should_load_qoffset_(thread_idx >= IteratorQOffset::kThreadblockSize - IteratorQOffset::ThreadMap::kThreads), + smem_write_stage_idx_(0), + smem_read_stage_idx_(0) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + if constexpr(debug_layout){ + layout_debug_.smem_b_ptr_ = shared_storage.operand_B_ref().data(); + layout_debug_.smem_qscale_ptr_ = shared_storage.operand_QScale_ref().data(); + if 
constexpr(kHasQOffset){ + layout_debug_.smem_qoffset_ptr_ = shared_storage.operand_QOffset_ref().data(); + } else { + layout_debug_.smem_qoffset_ptr_ = nullptr; + } + layout_debug_.warp_id_ = warp_idx; + layout_debug_.lane_id_ = lane_idx; + layout_debug_.block_id_ = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + } + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_QScale_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Advance shared memory read-iterators to the next stage + CUTLASS_DEVICE + void advance_smem_read_stage() + { + ++smem_read_stage_idx_; + + if (smem_read_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + this->warp_tile_iterator_A_.add_tile_offset({0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + this->warp_tile_iterator_QScale_.add_tile_offset({-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, 0}); + + smem_read_stage_idx_ = 0; + } + } + + /// Advance global memory read-iterators and shared memory write-iterators to the stage + CUTLASS_DEVICE + void advance_smem_write_stage( + IteratorA &iterator_A, + IteratorB &iterator_B, + IteratorQScale &iterator_QScale, + IteratorQOffset &iterator_QOffset) + { + // Advance global iterators + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + iterator_QScale.add_tile_offset({1, 0}); + + // Advance shared iterators + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + smem_iterator_QScale_.add_tile_offset({1, 0}); + + if constexpr (kHasQOffset) { + iterator_QOffset.add_tile_offset({1, 0}); + smem_iterator_QOffset_.add_tile_offset({1, 0}); + } + + // Increment shared memory write stage index + ++smem_write_stage_idx_; + + if (smem_write_stage_idx_ == Base::kStages) { + // Wrap back around to the 'start' of the circular buffer in shared memory + smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_iterator_QScale_.add_tile_offset({-Base::kStages, 0}); + if constexpr (kHasQOffset) { + smem_iterator_QOffset_.add_tile_offset({-Base::kStages, 0}); + } + smem_write_stage_idx_ = 0; + } + } + + CUTLASS_DEVICE + void copy_qscale_tiles(IteratorQScale &iterator_QScale){ + // Quant scale matrix is 1/block_size of the B matrix, for a 64x64 warp tile, + // it's only 64x64/block_size elements. For blocking size 16 ~ 64, it only + // takes 4 ~ 16 cp.async instructions to load. One warp has 32 threads, so + // it should be loaded in less than one cp.async instruction per thread. + // Even less for quant offset matrix. 
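+    // Only the first IteratorQScale::ThreadMap::kThreads threads of the
+    // threadblock are meant to service this copy; the constructor records this
+    // in should_load_qscale_ and the prologue clears the predicate mask for the
+    // other threads. Quant offsets, when present, are instead handled by the
+    // highest-numbered threads (see should_load_qoffset_).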
+ static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, + "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, + "Quant scale should 1 access per vector!"); + + // Async Copy for quantization scale + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QScale.get(), iterator_QScale.valid()); + } + + CUTLASS_DEVICE + void copy_qoffset_tiles(IteratorQOffset & iterator_QOffset) { + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, + "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, + "Quant offset should 1 access per vector!"); + + if constexpr(kHasQOffset) { + // Async Copy for quantization offset + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = sizeof_bits::value * + IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB &iterator_B, + int group_start = 0) { + auto group_start_A = group_start * Detail::kAccessesPerGroupA; + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + auto group_start_B = group_start * Detail::kAccessesPerGroupB; + iterator_B.set_iteration_index(group_start_B * + IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + + ++iterator_B; + } + ++this->smem_iterator_B_; + } + } + } + + /// GEMM prologue. 
Bootstrap the global->shared memory pipeline by fetching + /// the global fragments needed by the first kStages-1 threadblock mainloop iterations + CUTLASS_DEVICE + void prologue( + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + + // Async Copy for quantization scale + static_assert(Detail::AsyncCopyIterationsPerStageQScale == 1, "Quant scale should be loaded in one shot!"); + static_assert(IteratorQScale::kAccessesPerVector == 1, "Quant scale should 1 access per vector!"); + + typename IteratorQScale::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QScale_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + IteratorQScale::ThreadMap::kElementsPerAccess / 8; + + auto gmem_ptr = iterator_QScale.get(); + + cutlass::arch::cp_async( + dst_ptr, gmem_ptr, iterator_QScale.valid()); + + if constexpr (kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + + // Async Copy for quantization offset + static_assert(Detail::AsyncCopyIterationsPerStageQOffset == 1, "Quant offset should be loaded in one shot!"); + static_assert(IteratorQOffset::kAccessesPerVector == 1, "Quant offset should 1 access per vector!"); + typename IteratorQOffset::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_QOffset_.get()); + + constexpr int kSrcBytes = + sizeof_bits::value * + 
IteratorQOffset::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async( + dst_ptr, iterator_QOffset.get(), iterator_QOffset.valid()); + } + + // Move to the next write stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + + /// Wait until we have at least one completed global fetch stage + CUTLASS_DEVICE + void gmem_wait() + { + // Wait until we have at least one committed global fetch stage. (#uncommitted = Base::kStages - 1 - #committed) + cutlass::arch::cp_async_wait(); + __syncthreads(); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_smem && layout_debug_.block_id_ == 1) { + if (threadIdx.x == 0){ + printf("stage: %d\n", smem_write_stage_idx_); + } + cutlass::debug::dump_shmem(layout_debug_.smem_qscale_ptr_, Base::SharedStorage::ShapeQScale::kCount); + if constexpr(kHasQOffset){ + cutlass::debug::dump_shmem(layout_debug_.smem_qoffset_ptr_, Base::SharedStorage::ShapeQScale::kCount); + } + } + } + } + + /// Perform a threadblock mainloop iteration of matrix multiply-accumulate + CUTLASS_DEVICE + void mac_loop_iter( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // Loading next warp-level tiles from shared memory. 
This can be skipped on the very + // last iteration where: + // (gemm_k_iterations == (1 - Base::kStages)) && (warp_mma_k == (Base::kWarpGemmIterations - 1)) + // However, evaluating this condition seems more expensive than simply loading the tiles + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k + 1) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k == 0) { + copy_qscale_tiles(iterator_QScale); + } + if (warp_mma_k == 1) { + copy_qoffset_tiles(iterator_QOffset); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. 
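+          // Roughly, the hand-off below proceeds as:
+          //   cp_async_fence();               -- close the group of cp.async copies issued during this iteration
+          //   advance_smem_write_stage(...);  -- new global fetches will fill the stage we are done reading
+          //   advance_smem_read_stage();      -- warp tile iterators move on to the next filled stage
+          //   --gemm_k_iterations;            -- one fewer K block left to fetch; masks are cleared when it hits 0
+          //   gmem_wait();                    -- block until at least one outstanding stage has landed in shared memory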
+ cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + /// Specialized mainloop iteration of matrix multiply-accumulate, for small M + CUTLASS_DEVICE + void mac_loop_iter_small_m( + PipeState &pipe_state, ///< [in|out] loop-carried pipeline state + FragmentC &accum, ///< [in|out] destination accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over quant scales in global memory + IteratorQOffset &iterator_QOffset, ///< [in|out] iterator over quant offsets in global memory + int &gemm_k_iterations) ///< [in|out] number of threadblock mainloop iterations remaining + { + // Unroll the warp-level MMA tiles of a threadblock's mainloop iteration + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + // In the case of small M, memory latency dominates. We try to move uses far + // from their definitions to hide latency. + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, warp_mma_k % Base::kWarpGemmIterations); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[(warp_mma_k) % 2], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + + // Loading next warp-level tiles from shared memory. 
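+      // (Note the ordering: unlike mac_loop_iter() above, the next warp tile is loaded only
+      // after the current B fragment has been transformed, and the qscale/qoffset global
+      // copies are all issued at the stage boundary further below.  This variant is selected
+      // by gemm_iters() when Shape::kM <= 32.)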
+ this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_A_; + + // All warp-tiles issue their share of global->shared fragment copies + copy_tiles_and_advance( + iterator_A, + iterator_B, + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // Execute the current warp-tile of MMA operations + if (Detail::kStagedAccumulation) { + warp_mma_( + pipe_state.tmp_accum_, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + pipe_state.tmp_accum_ + ); + + if (warp_mma_k == 0) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + pipe_state.tmp_accum_.clear(); + } + } else { + warp_mma_( + accum, + pipe_state.warp_loaded_frag_A_[warp_mma_k % 2], + pipe_state.warp_transformed_frag_B_[warp_mma_k % 2], + accum + ); + } + + // The second-to-last warp-tile also moves to the next global fetch stage + if (warp_mma_k == Base::kWarpGemmIterations - 2) { + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Move to the next global fetch stage + advance_smem_write_stage(iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + advance_smem_read_stage(); + + // Disable global fetching when done with global fetch iterations + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset){ + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + } + + } + } + + + /// Perform the specified number of threadblock mainloop iterations of matrix + /// multiply-accumulate. Assumes prologue has been initiated. 
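+  ///
+  /// A rough outline of what follows:
+  ///   clear the global-fetch masks if nothing is left to fetch
+  ///   load the first warp-level A/B (and quant meta) fragments from shared memory
+  ///   issue the first group of global->shared cp.async copies
+  ///   while (gemm_k_iterations > 1 - Base::kStages)
+  ///     mac_loop_iter(...) or mac_loop_iter_small_m(...)   // each pass consumes one K block
+  ///   commit and drain the remaining cp.async traffic, then __syncthreads()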
+ CUTLASS_DEVICE + void gemm_iters( + int gemm_k_iterations, ///< number of threadblock mainloop iterations + FragmentC &accum, ///< [in|out] accumulator tile + IteratorA &iterator_A, ///< [in|out] iterator over A operand in global memory + IteratorB &iterator_B, ///< [in|out] iterator over B operand in global memory + IteratorQScale &iterator_QScale, ///< [in|out] iterator over QScale operand in global memory + IteratorQOffset &iterator_QOffset) ///< [in|out] iterator over QOffset operand in global memory + { + PipeState pipe_state; + + // Disable global fetching if done with global fetch iterations + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + iterator_QScale.clear_mask(gemm_k_iterations == 0 || !should_load_qscale_); + if constexpr(kHasQOffset) { + iterator_QOffset.clear_mask(gemm_k_iterations == 0 || !should_load_qoffset_); + } + + // Load first warp-tile's B fragment from shared memory + this->warp_tile_iterator_QScale_.load( + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + ++this->warp_tile_iterator_QScale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(pipe_state.warp_loaded_frag_B_); + ++this->warp_tile_iterator_B_; + + // Load first warp-tile's A fragment from shared memory + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_A_.load(pipe_state.warp_loaded_frag_A_[0]); + ++this->warp_tile_iterator_A_; + + copy_tiles_and_advance(iterator_A, iterator_B, 0); + + if constexpr(Shape::kM > 32) { + // the case of bigger m + if constexpr(debug_layout) { + if (LayoutDebugType::debug_fragment && layout_debug_.block_id_ == 1 && layout_debug_.warp_id_ == 0 && layout_debug_.lane_id_ == 0){ + printf("LINE %d, warp_tile_B kgroup %d\n", __LINE__, 0); + } + LayoutDebugType::print_as_int4(pipe_state.warp_loaded_frag_B_, 'W', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QScale_), 'Q', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + if constexpr(kHasQOffset){ + LayoutDebugType::print_fragment(Operator::IteratorQScale::debug_expand(pipe_state.warp_loaded_frag_QOffset_), 'O', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } + + warp_mma_.transform( + pipe_state.warp_transformed_frag_B_[0], + pipe_state.warp_loaded_frag_B_, + pipe_state.warp_loaded_frag_QScale_, + pipe_state.warp_loaded_frag_QOffset_); + + if constexpr(debug_layout) { + LayoutDebugType::print_fragment(pipe_state.warp_transformed_frag_B_[0], 'B', layout_debug_.block_id_, layout_debug_.warp_id_, layout_debug_.lane_id_); + } + } else { + // the case of small m + copy_qscale_tiles(iterator_QScale); + copy_qoffset_tiles(iterator_QOffset); + } + + if (Detail::kStagedAccumulation) { + pipe_state.tmp_accum_.clear(); + } + + // Mainloop + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + if constexpr(Shape::kM > 32) { + mac_loop_iter( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } else { + mac_loop_iter_small_m( + pipe_state, + accum, + iterator_A, + iterator_B, + iterator_QScale, + iterator_QOffset, + gemm_k_iterations); + } + } + + if (Detail::kStagedAccumulation) { + plus plus_accum; + accum = plus_accum(accum, pipe_state.tmp_accum_); + } + + // Commit and drain all pending and predicated cp.async pnz from the 
GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< iterator over quant scales in global memory + IteratorQScale iterator_QScale, + ///< Iterator over quant offsets in global memory + IteratorQOffset iterator_QOffset, + ///< initial value of accumulator + FragmentC const &src_accum) { + + // Prologue (start fetching iterations of global fragments into shared memory) + prologue(iterator_A, iterator_B, iterator_QScale, iterator_QOffset, gemm_k_iterations); + + // Wait until we have at least one completed global fetch stage + gmem_wait(); + + // Initialize destination accumulators with source accumulators + accum = src_accum; + + // Perform the MAC-iterations + gemm_iters(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_QScale, iterator_QOffset); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..2c49888c94504 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/default_quantb_mma_tensor_op.h @@ -0,0 +1,112 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file default_quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/default_mma_tensor_op.h + * Default warp-level GEMM operators selected by data type, size, and layouts of operands. + */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h" + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for m-by-n-by-kgroup +template < + /// Shape of one matrix production operation (concept: GemmShape) + typename WarpShape_, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape_, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Data type of quant scales + typename ElementQScale, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale, + /// Data type of quant offsets + typename ElementQOffset, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset, + /// Blocking size of quantization + typename QuantBlocking, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC, + /// Operator describing the tensor operation + typename Operator_ = arch::OpMultiplyAdd, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false> +struct DefaultQuantBMmaTensorOp { + using Policy = cutlass::gemm::warp::MmaTensorOpPolicy< + cutlass::arch::Mma, + cutlass::MatrixShape<1, 1> >; + + // Define the warp-level tensor op + using Type = cutlass::gemm::warp::QuantBMmaTensorOp< + WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, QuantBlocking, ElementC, LayoutC, + Policy, PartitionsK, AccumulatorsInRowMajor>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/gemm/warp/mma_complex_tensor_op_tile_iterator_sm80.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h new file mode 100644 index 0000000000000..4ba39dda3db8d --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h @@ -0,0 +1,883 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + * + * @file quantb_meta_mma_tensor_op_tile_iterator.h + * @brief Templates for loading quantization meta data for operand B + * from shared memory to fragments. This is meant to be used in + * lock step with the operand B tile iterator. Containing logic + * to figure out the operand B layout in the tensor core, + * and deliver each meta data element to its corresponding + * operand B element for dequantization. + */ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor_op_multiplicand_sm75.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/fast_math.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace{ + +struct b32_pair{ + uint32_t a; + uint32_t b; +}; + +struct fp16_quad{ + cutlass::half_t a; + cutlass::half_t b; + cutlass::half_t c; + cutlass::half_t d; +}; + +struct b16_quad{ + int16_t a; + int16_t b; + int16_t c; + int16_t d; +}; + +union b64 { + uint64_t single; + b32_pair pair; + b16_quad quard; + fp16_quad fp16_quad; +}; + +static_assert(sizeof(b64) == 8, "b64 should be 64 bits"); + +/// Convert packed 4b weights into fp16(weight + 16) +/// Current bit hacking only supports fp16, need to add bf16 later. 
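+/// For illustration: each lop3 below computes (x & 0x03c003c0) | 0x4c004c00 per 32-bit pair,
+/// and 0x4c00 is the fp16 encoding of 16.0 with an all-zero mantissa.  Shifting a 4-bit
+/// weight w into mantissa bits [9:6] adds exactly w, since one mantissa step at this
+/// exponent is 1/64 and w << 6 contributes w * 64 steps; e.g. w = 5 gives 0x4d40 = 21.0.
+/// The matching "-(16 + offset) * scale" term is folded in during dequantization later.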
+/// +template +CUTLASS_DEVICE +void weights2Half(cutlass::Array const &weights, + cutlass::Array& dest) +{ + static_assert(Size % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const uint32_t* w_oct = reinterpret_cast(weights.data()); + + CUTLASS_PRAGMA_UNROLL + for (int oct_idx = 0; oct_idx < Size/8; oct_idx++, w_oct++, dest_pair += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + + // static_cast(16 + weight) + // 4b weights are prepacked into [0, 2, 4, 6, 1, 3, 5, 7], so that adjacent weights + // are in different 16b half words, making it easier to convert to fp16. + asm volatile( + "{\n\t" + " shl.b32 %0, %4, 6;\n" + " shl.b32 %1, %4, 2;\n" + " shr.u32 %2, %4, 2;\n" + " shr.u32 %3, %4, 6;\n" + " lop3.b32 %0, %0, 0x03c003c0, 0x4c004c00, 0xea;\n" // a & 0x03c0 | 0x4c00 + " lop3.b32 %1, %1, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %2, %2, 0x03c003c0, 0x4c004c00, 0xea;\n" + " lop3.b32 %3, %3, 0x03c003c0, 0x4c004c00, 0xea;\n" + "}\n" + : "=r"(dest_pair[0]), "=r"(dest_pair[1]), + "=r"(dest_pair[2]), "=r"(dest_pair[3]) + : "r"(*w_oct)); +#else + assert(0); +#endif + } + +} + +} // namespace + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +// Traits to describe the layout of quantization meta data layout in a MMA fragment +// Since operand B is quantized on a per block basis, it's one meta data per block. + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTile{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ArchMmaOperator = ArchMmaOperator_; + + static_assert(Threads == 32, "This iterator should work in a warp only."); + + /// Shape of the curresponding operand B tile iterator + using TileShapeB = MatrixShape; + + // Tensor core operand B layout is a column major 4x8 tile, divided + // into 32 threads (T0 ~ T31) as shown below. Each element of the tile is 32b, + // so for fp16 it becomes 8 x 8, and int8 it becomes 16 x 8. 
+ // T0 | T4 | T8 | T12 | T16 | T20 | T24 | T28 + // T1 | T5 | T9 | T13 | T17 | T21 | T25 | T29 + // T2 | T6 | T10 | T14 | T18 | T22 | T26 | T30 + // T3 | T7 | T11 | T15 | T19 | T23 | T27 | T31 + using CoreTile = layout::PitchLinearShape<4, 8>; + + /// Each thread holds a 32b fragment per tile: for half precision, it's 2 elements, 4 elements for int8 + static int const kNumBsPerCoreTileFragement = 32 / sizeof_bits::value; + + /// Each mma instruction can process either 1 or 2 tensor core operand B tiles (stacked on the k dimension) + static int const kBTilesPerMma = + sizeof_bits::value * ArchMmaOperator::FragmentB::kElements / 32; + static_assert(kBTilesPerMma == 1 || kBTilesPerMma == 2, "Only support 1 or 2 operand B tiles per mma."); + + /// Each operand B tile iterator load covers a number of mma instructions + static int const kMmaIterationsB = WarpShapeB::kColumn / ArchMmaOperator::Shape::kN; + + /// Number of B elements a fragment of meta data should cover + static int const kExpandedSize = kNumBsPerCoreTileFragement * kBTilesPerMma * kMmaIterationsB; + + // Now we figure out how many meta data elements to load for each TileShapeB + + /// Number of meta elements per CoreTile. + static int const kCoreTileFragementSize = (kNumBsPerCoreTileFragement + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Number of core tiles per mma instruction, different from kBTilesPerMma when blocking size on K dimension + /// exceeds the tile depth, so two tiles share the same meta data + static int const kTilesPerMma = ((kBTilesPerMma == 2) && + (BlockingShape::kRow <= kNumBsPerCoreTileFragement * CoreTile::kContiguous)) + ? 2 : 1; + + /// stride to reach the meta data for the next CoreTile on the K dimension + static int const kKTileStride = (kNumBsPerCoreTileFragement * CoreTile::kContiguous + BlockingShape::kRow - 1) / BlockingShape::kRow; + + /// Stride on N dimension should be the tile width, shrunk by blocking size on this dimension. + static int const kNStride = (CoreTile::kStrided + BlockingShape::kColumn - 1) / BlockingShape::kColumn; + + /// On N dimension, how many tiles share the same meta data + static int const kNRepeats = (BlockingShape::kColumn + CoreTile::kStrided - 1) / CoreTile::kStrided; + + /// Each fragment should cover kMmaIterationsB number of mma intructions on the N dimension. + /// When blocking size on this dimension exceeds the tile width, multiple iterations + /// would share the same data. + static int const kMmaIterations = (kMmaIterationsB + kNRepeats - 1) / kNRepeats; + + static int const kFragementSize = kCoreTileFragementSize * kTilesPerMma * kMmaIterations; + + CUTLASS_DEVICE + static MatrixCoord lane_position(int lane_id) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2 + && BlockingShape::kRow == 1){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking + // The scale and offset tensors are prepacked to reduce the number of load instructions. + return make_Coord((lane_id % CoreTile::kContiguous) * 4, + lane_id / CoreTile::kContiguous); + } else { + return make_Coord((lane_id % CoreTile::kContiguous) * kNumBsPerCoreTileFragement, + lane_id / CoreTile::kContiguous); + } + } +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// This tile iterator is to load quantization meta data for operand B from +/// shared memory to fragments (hopefully allocated to registers by compilers). 
+/// Examples of meta data include scale or offsets. The operand B matrix is +/// quantized on a per block basis, meaning one element of meta data per block. +/// +/// This is meant to be used in lock step with the operand B tile iterator. +/// So all parameters are logical positions in the operand B tiles. +/// The goal here is to deliver each meta data element to its corresponding +/// operand B element for dequantization. As a result, we need to figure +/// out the operand B layout in the tensor core. +/// +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the quant scales + typename ElementScale_, + /// Layout of the quant scales + typename LayoutScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Layout of quant offsets + typename LayoutOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads, + /// Number of partitions along K dimension + int PartitionsK_ = 1> +class QuantBMetaMmaTensorOpTileIterator; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using Layout = cutlass::layout::ColumnMajor; + using ElementOffset = ElementOffset_; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kRow == 1 && BlockingShape::kColumn > 1, + "Only support row blocking for column major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? 
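+  /// (With the row blocking of 1 required by this specialization, this is simply 2 when the
+  /// mma instruction covers two stacked core tiles and 1 otherwise.)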
+ static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + + using AccessTypeScale = Array; + using AccessTypeOffset = Array; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)){} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + if constexpr(kNumBsPerCoreTileFragement == 2 + && kBTilesPerMma == 2){ + // Optimize for a special case of: + // 16b gemm (kNumBsPerCoreTileFragement == 2) + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + // The scale and offset tensors are prepacked to reduce the number of load instructions needed + const int row = lane_position_.row(); + const int column = lane_position_.column() / BlockingShape::kColumn; + + Array *dst_ptr = reinterpret_cast*>(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr = reinterpret_cast*>(pointer_ + layout_({row, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + + if constexpr(kHasOffset){ + Array *dst_ptr_offset = reinterpret_cast*>(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + Array *src_ptr_offset = reinterpret_cast*>(pointer_offset_ + layout_offset_({row, c})); + *dst_ptr_offset = *src_ptr_offset; + dst_ptr_offset++; + } + } + + } else { + // Other cases, offsets and scales are not prepacked. 
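+      // Generic gather: for each of the kMmaIterations columns (stride kNStride) we read one
+      // AccessTypeScale per covered core tile (kTilesPerMma of them, row stride kKTileStride);
+      // the offsets, when present, follow exactly the same walk with AccessTypeOffset.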
+ + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + + AccessTypeScale* dst_ptr = reinterpret_cast(frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeScale* src_ptr = reinterpret_cast(pointer_ + layout_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + + if constexpr(kHasOffset){ + AccessTypeOffset* dst_ptr = reinterpret_cast(frag_offset.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0, c = column; n_idx < kMmaIterations; n_idx++, c += kNStride){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_idx = 0, r = row; mma_tile_idx < kTilesPerMma; mma_tile_idx++, r += kKTileStride){ + AccessTypeOffset* src_ptr = reinterpret_cast(pointer_offset_ + layout_offset_({r, c})); + *dst_ptr = *src_ptr; + dst_ptr++; + } + } + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm."); + static_assert(kExpandedSize % 8 == 0, "Weights should have been prepacked by 2x2 tiles, 2 weights per tile."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + if constexpr(kBTilesPerMma == 2){ + // Optimize for a special case of: + // 2 B operand tiles per mma (kBTilesPerMma == 2) + // (1,n) quantization blocking (BlockingShape::kRow == 1) + + uint32_t* dest_pair = reinterpret_cast(dest.data()); + const b64* scales_ptr = reinterpret_cast(scales.data()); + const ElementOffset* offsets_ptr = nullptr; + if constexpr(kHasOffset) { offsets_ptr = offsets.data(); } + + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + // dequantize: d = scale * (weight - offset) + // to use FMA, d = scale * weight + (scale * (-offset)) + + b64 offsets; + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets_ptr); + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, %4, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, %4, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), 
"r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + + offsets_ptr += 4; + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(offsets.pair.a), "=r"(offsets.pair.b) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + offsets.fp16_quad.a = scales_ptr->fp16_quad.a * static_cast(-16-8); + offsets.fp16_quad.b = scales_ptr->fp16_quad.b * static_cast(-16-8); + offsets.fp16_quad.c = scales_ptr->fp16_quad.c * static_cast(-16-8); + offsets.fp16_quad.d = scales_ptr->fp16_quad.d * static_cast(-16-8); +#endif + } + + CUTLASS_PRAGMA_UNROLL + for (int n_r = 0; n_r < kNRepeats; n_r++){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " fma.rn.f16x2 %0, %2, %0, %4;\n" // dest = scale * (16 + weight) + (scale * (-16 - offset)) + " fma.rn.f16x2 %1, %3, %1, %5;\n" + "}\n" + : "+r"(dest_pair[0]), "+r"(dest_pair[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(offsets.pair.a), "r"(offsets.pair.b)); +#else + assert(0); +#endif + dest_pair += 2; + } + scales_ptr++; + } + + } else { + // unoptiomized path for other cases, very slow + int out_idx = 0; + ElementScale offset; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma; + ElementScale s = scales[idx]; + if constexpr(kHasOffset){ + offset = s * static_cast(-16 - int(offsets[idx])); + } else { + offset = s * static_cast(-16-8); + } + dest[out_idx] = s * dest[out_idx] + offset; + out_idx++; + } + } + } + + } + + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for row major layout + +template < + /// Shape of the operand B matrix to load in a warp (concept: MatrixShape) + typename WarpShapeB_, + /// Block dimensions of the blockwise quantization. 
So the actual meta data + /// warp shape is WarpShapeB_ / BlockingShape_ + typename BlockingShape_, + /// Data type of the meta data elements + typename ElementScale_, + /// Data type of quant offsets + typename ElementOffset_, + /// Underlying matrix multiply operator (concept: arch::Mma) + typename ArchMmaOperator_, + /// Number of threads participating in one matrix operation + int Threads> +class QuantBMetaMmaTensorOpTileIterator{ +public: + + using WarpShapeB = WarpShapeB_; + using BlockingShape = BlockingShape_; + using ElementScale = ElementScale_; + using ElementOffset = ElementOffset_; + using Layout = cutlass::layout::RowMajor; + using ArchMmaOperator = ArchMmaOperator_; + + static constexpr bool kHasOffset = !(std::is_same::value); + + static_assert(BlockingShape::kColumn == 1 && BlockingShape::kRow > 1, + "Only support column blocking for row major layout"); + + using MetaTile = QuantBMetaMmaTile; + + /// Number of MMA instructions for this tile + static constexpr int kMmaIterationsB = MetaTile::kMmaIterationsB; + + /// Number of B elements per mma tile fragment (32b), 2 for half precision, 4 for int8 + static constexpr int kNumBsPerCoreTileFragement = MetaTile::kNumBsPerCoreTileFragement; + + /// Each mma instruction can process either 1 or 2 operand B tiles (stacked on the k dimension) + static constexpr int kBTilesPerMma = MetaTile::kBTilesPerMma; + + /// Number of B elements a fragment of meta data should cover + static constexpr int kExpandedSize = MetaTile::kExpandedSize; + + /// Number of meta elements per core tile fragment + static constexpr int kCoreTileFragementSize = MetaTile::kCoreTileFragementSize; + + /// stride for reaching the next core tile (if there is one) on the K dimension + static constexpr int kKTileStride = MetaTile::kKTileStride; + + /// do we need to load meta data for the next core tile on the K dimension? 
+ static constexpr int kTilesPerMma = MetaTile::kTilesPerMma; + + static constexpr int kNStride = MetaTile::kNStride; + static constexpr int kNRepeats = MetaTile::kNRepeats; + static constexpr int kMmaIterations = MetaTile::kMmaIterations; + + using TensorRefScale = TensorRef; + using TensorRefOffset = TensorRef; + using TensorCoord = typename Layout::TensorCoord; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using StrideIndex = typename Layout::Stride::Index; + + using FragmentScale = Array; + using FragmentOffset = typename std::conditional, + std::monostate>::type; + +private: + + ElementScale *pointer_; + Layout layout_; + + ElementOffset *pointer_offset_; + Layout layout_offset_; + + TensorCoord lane_position_; + +public: + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator() { } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator( + TensorRefScale const &ref, + TensorRefOffset const &ref_offset, + int lane_idx + ): + pointer_(ref.data()), + layout_(ref.layout()), + pointer_offset_(ref_offset.data()), + layout_offset_(ref_offset.layout()), + lane_position_(MetaTile::lane_position(lane_idx)) + {} + + /// Loads a fragment + CUTLASS_HOST_DEVICE + void load(FragmentScale &frag, FragmentOffset &frag_offset) { + const int row = lane_position_.row() / BlockingShape::kRow; + const int column = lane_position_.column() / BlockingShape::kColumn; + static_assert(kTilesPerMma * kCoreTileFragementSize == 1, "Only support one meta data per core tile"); + + ElementScale* src_ptr = pointer_ + layout_({row, column}); + ElementScale* dst_ptr = frag.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr[n_idx] = src_ptr[n_idx * kNStride]; + } + + if constexpr(kHasOffset){ + ElementOffset* src_ptr_offset = pointer_offset_ + layout_offset_({row, column}); + ElementOffset* dst_ptr_offset = frag_offset.data(); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){ + dst_ptr_offset[n_idx] = src_ptr_offset[n_idx * kNStride]; + } + } + } + + template + CUTLASS_HOST_DEVICE + static Array debug_expand(Array const &frag){ + Array ret; + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + int n_idx = n_out / kNRepeats; + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + int mma_tile_idx = mma_tile_out_idx / (kBTilesPerMma / kTilesPerMma); + CUTLASS_PRAGMA_UNROLL + for (int elem_out_idx = 0; elem_out_idx < kNumBsPerCoreTileFragement; elem_out_idx++){ + int elem_idx = elem_out_idx / BlockingShape::kRow; + int col = elem_idx + mma_tile_idx * kCoreTileFragementSize; + int idx = col * kMmaIterations + n_idx; + ret[out_idx] = frag[idx]; + out_idx++; + } + } + } + return ret; + } + + CUTLASS_HOST_DEVICE + static void dequant(FragmentScale const &scales, + FragmentOffset const &offsets, + Array const &weights, + Array& dest){ + static_assert(kNRepeats == 1, "This is implied by BlockingShape::kColumn == 1"); + static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm now."); + + // First convert 4b weight into fp16(weight + 16) + weights2Half(weights, dest); + + ElementScale addon[kMmaIterationsB]; + if constexpr (kMmaIterationsB % 4 == 0) { + const b64* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + if constexpr(kHasOffset){ + const uint32_t* p = reinterpret_cast(offsets.data()); + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < 
kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [d, c, b, a] --> [d, b, c, a] + " prmt.b32 rb2, %4, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " shr.u32 rb1, rb2, 2;\n" // rb1 = [x, d, x, c] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " lop3.b32 rb1, rb1, 0x03c003c0, 0xcc00cc00, 0xea;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - offset) + " mul.rn.f16x2 %1, %3, rb1;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b), + "r"(p[0])); +#else + assert(0); +#endif + scales_ptr++; + p++; + addon_ptr += 2; + } + } else { + CUTLASS_PRAGMA_UNROLL + for (int n_idx = 0; n_idx < kMmaIterationsB; n_idx += 4){ +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %2, rb0;\n" // offset = scale * (-16 - 8) + " mul.rn.f16x2 %1, %3, rb0;\n" + "}\n" + : "=r"(addon_ptr[0]), "=r"(addon_ptr[1]) + : "r"(scales_ptr->pair.a), "r"(scales_ptr->pair.b)); +#else + assert(0); +#endif + scales_ptr++; + addon_ptr += 2; + } + } + } else if constexpr (kMmaIterationsB % 2 == 0) { + const uint32_t* scales_ptr = reinterpret_cast(scales.data()); + uint32_t* addon_ptr = reinterpret_cast(addon); + + if constexpr (kHasOffset){ + // possible buffer over read 2 bytes here. + const uint32_t* p = reinterpret_cast(offsets.data()); +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0, rb1, rb2;\n" + + // offset from [?, ?, b, a] --> [?, b, ?, a] + " prmt.b32 rb2, %2, rb0, 0x3120;\n" + + // static_cast(-16 - offset) + // input [d, b, c, a], + " shl.b32 rb0, rb2, 6;\n" // rb0 = [x, b, x, a] << 6 + " lop3.b32 rb0, rb0, 0x03c003c0, 0xcc00cc00, 0xea;\n" // a & 0x03c0 | 0xcc00 + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - offset) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0]) + "r"(p[0])); +#else + assert(0); +#endif + } else { +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) + asm volatile( + "{\n\t" + " .reg .b32 rb0;\n" + " mov.u32 rb0, 0xce00ce00;\n" + " mul.rn.f16x2 %0, %1, rb0;\n" // offset = scale * (-16 - 8) + "}\n" + : "=r"(addon_ptr[0]) + : "r"(scales_ptr[0])); +#else + assert(0); +#endif + } + } else { + // kMmaIterationsB == 1 + if constexpr(kHasOffset){ + uint8_t zp = offsets[0]; + addon[0] = scales[0] * static_cast(-16 - static_cast(zp)); + } else { + addon[0] = scales[0] * static_cast(-16-8); + } + } + + int out_idx = 0; + CUTLASS_PRAGMA_UNROLL + for (int n_out = 0; n_out < kMmaIterationsB; n_out++){ + CUTLASS_PRAGMA_UNROLL + for (int mma_tile_out_idx = 0; mma_tile_out_idx < kBTilesPerMma; mma_tile_out_idx++){ + dest[out_idx] = scales[n_out] * dest[out_idx] + addon[n_out]; + dest[out_idx + 1] = scales[n_out] * dest[out_idx + 1] + addon[n_out]; + out_idx += 2; + } + } + } + + /// Advances the pointer + CUTLASS_HOST_DEVICE + QuantBMetaMmaTensorOpTileIterator &operator++() { + // This is for operand B, so advance on the K dimension + lane_position_ += make_Coord(MetaTile::TileShapeB::kRow, 0); + return *this; + } + + CUTLASS_DEVICE + QuantBMetaMmaTensorOpTileIterator &add_tile_offset( + TensorCoord const &tile_offset) { + int rows = tile_offset.row() * MetaTile::TileShapeB::kRow; + int columns = tile_offset.column() * 
MetaTile::TileShapeB::kColumn; + lane_position_ += TensorCoord(rows, columns); + return *this; + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// +} // namespace warp +} // namespace gemm +} // namespace cutlass diff --git a/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h new file mode 100644 index 0000000000000..f29cedf326a44 --- /dev/null +++ b/onnxruntime/core/mickey/cutlass_ext/q4gemm/warp/quantb_mma_tensor_op.h @@ -0,0 +1,361 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + * Modifications Copyright (c) Microsoft. + * Licensed under the MIT license. + * + * @file quantb_mma_tensor_op.h + * @brief Modified from cutlass/gemm/warp/mma_tensor_op.h + * Templates implementing warp-level matrix multiply-accumulate operations + * targeting tensor cores. 
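+ * Relative to the original, this variant adds quant scale/offset element types, their
+ * shared-memory layouts and the quantization blocking shape as template parameters, plus a
+ * transform() step that dequantizes the packed 4-bit B fragment (via
+ * QuantBMetaMmaTensorOpTileIterator::dequant) before the tensor core mma is issued.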
+ */ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/mma_sm75.h" +#include "cutlass/arch/mma_sm80.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" +#include "cutlass/gemm/warp/mma_tensor_op_policy.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" + +#include "cutlass_ext/q4gemm/warp/quantb_meta_mma_tensor_op_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Data type of quant scales + typename ElementQScale_, + /// Layout of quant scales (concept: MatrixLayout) + typename SmemLayoutQScale_, + /// Data type of quant offsets + typename ElementQOffset_, + /// Layout of quant offsets (concept: MatrixLayout) + typename SmemLayoutQOffset_, + /// Blocking dimensions of quantization + typename QuantBlocking_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK_ = 1, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Used for partial specialization + typename Enable = bool +> +class QuantBMmaTensorOp { +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename Policy::Operator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Architecture tag from underlying instruction + using ArchTag = typename ArchMmaOperator::ArchTag; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassTensorOp; + + /// Shape of underlying instruction + using InstructionShape = typename ArchMmaOperator::Shape; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = ComplexTransform::kNone; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Number of threads participating in warp-level matrix product + static int const kThreadCount = 32; + + /// Number of partitions along K dimension + static int const kPartitionsK = PartitionsK_; + +public: + + /// Iterates over the A operand in memory + using IteratorA = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kA, ElementA, LayoutA, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = + Array; + + /// Iterates over the B operand in memory + using IteratorB = MmaTensorOpMultiplicandTileIterator< + MatrixShape, Operand::kB, ElementB, LayoutB, + MatrixShape, + Policy::OpDelta::kRow, kThreadCount, kPartitionsK>; + // warp B MatrixShape<64, 64>, + // layout B cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<16, 64>, + // instruction op shape cutlass::MatrixShape<16, 8>, + // kPartitionsK 1 + // FragmentB::kElements 32 + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; // cutlass::Array + + /// Storage for transformed B tile + /// When loading weights, we packed 4 int4 weights into one 2-byte-element, when expanded + /// we multiply the number of elements by 4. 
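+  /// (For the configuration noted above, FragmentB::kElements == 32, so a transformed
+  /// fragment holds 32 * 4 = 128 dequantized values.)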
+ /// TODO: make sure ArchMmaOperator::ElementB same as dequantized ElementB + /// and change the transform function below to perform dequantization + using TransformedFragmentB = + Array; + + /// Iterates over the C operand in memory + using IteratorC = MmaTensorOpAccumulatorTileIterator< + MatrixShape, ElementC, LayoutC, + typename ArchMmaOperator::Shape, typename Policy::OpDelta>; + + /// Storage for C tile + using FragmentC = typename IteratorC::Fragment; + + using ElementQScale = ElementQScale_; + using SmemLayoutQScale = SmemLayoutQScale_; + using QuantBlocking = QuantBlocking_; + + using ElementQOffset = ElementQOffset_; + using SmemLayoutQOffset = SmemLayoutQOffset_; + + /// Iterates over the quantization parameters in memory + using WarpQScaleShape = MatrixShape<(Shape::kK / QuantBlocking::kRow), (Shape::kN / QuantBlocking::kColumn)>; + static_assert(Shape::kK % QuantBlocking::kRow == 0, "K must be multiple of QuantBlocking::kRow"); + static_assert(Shape::kN % QuantBlocking::kColumn == 0, "N must be multiple of QuantBlocking::kColumn"); + static_assert(WarpQScaleShape::kCount > 0, "QuantBlocking too big to fit in a warp block!"); + + // TODO This is an expanding iterator, it needs to replicate the quantization parameters + // to all threads in the warp. + using IteratorQMeta = QuantBMetaMmaTensorOpTileIterator< + MatrixShape, QuantBlocking, ElementQScale, SmemLayoutQScale, + ElementQOffset, SmemLayoutQOffset, + ArchMmaOperator, kThreadCount, kPartitionsK>; + + using FragmentQScale = typename IteratorQMeta::FragmentScale; + using FragmentQOffset = typename IteratorQMeta::FragmentOffset; + + /// Number of mma operations performed + using MmaIterations = MatrixShape< + (Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM, + (Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN + >; + +public: + + /// Underlying matrix multiply operator (concept: arch::Mma) + ArchMmaOperator mma; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + QuantBMmaTensorOp() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &D, + TransformedFragmentA const &A, + TransformedFragmentB const &B, + FragmentC const &C + ) const { + + using MmaOperandA = typename ArchMmaOperator::FragmentA; + using MmaOperandB = typename ArchMmaOperator::FragmentB; + using MmaOperandC = typename ArchMmaOperator::FragmentC; + + D = C; + + MmaOperandA const *ptr_A = reinterpret_cast(&A); + MmaOperandB const *ptr_B = reinterpret_cast(&B); + MmaOperandC *ptr_D = reinterpret_cast(&D); + + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + // Serpentine visitation order maximizing reuse of Rb + // The visitation order is like + // _ + // | | | | + // | | | | + // |_| |_| + // + // Down Up Down Up + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn]); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow]); + } + } + } + #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + // Serpentine visitation order maximizing reuse of Ra + // The visitation order is like + // _________ + // _________| + // |_________ + // __________| + // + // Right Left Right Left + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n); + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n_serpentine + m * MmaIterations::kColumn], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[n_serpentine + m * MmaIterations::kColumn]); + } else { + mma(ptr_D[m + n_serpentine * MmaIterations::kRow], + ptr_A[m], + ptr_B[n_serpentine], + ptr_D[m + n_serpentine * MmaIterations::kRow]); + } + } + } + #else + assert(0); + #endif + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentB &dst_B, + FragmentB const &B, + FragmentQScale const &scales, + FragmentQOffset const &offsets) const { + + Array const *ptr_B = + reinterpret_cast const *>(&B); + IteratorQMeta::dequant(scales, offsets, *ptr_B, dst_B); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +//#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/onnxruntime/core/util/matrix_layout.h b/onnxruntime/core/util/matrix_layout.h index a0405e32034ae..783a29d8a2055 100644 --- a/onnxruntime/core/util/matrix_layout.h +++ b/onnxruntime/core/util/matrix_layout.h @@ -17,7 +17,6 @@ #include #include "core/common/gsl.h" -// TODO!! Already have this in cuda, what about cpu code though? #if defined(_MSC_VER) #define ORT_FORCEINLINE __forceinline #else diff --git a/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h new file mode 100644 index 0000000000000..6ea8b55505214 --- /dev/null +++ b/onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h @@ -0,0 +1,203 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_quant_sm80.h + * + * Abstract: + * Oracle computation for blockwise 4b quantization for fp16 + * gemm kernel specifically for Ampere GPUs. 
This is used for + * testing the cuda kernel implementation in + * (test/providers/cuda/test_cases) + * and for testing the cuda op prepack code in (test/optimizer) + */ + +#pragma once + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" + +namespace onnxruntime { +namespace test { + +static inline void sm80_prepack_weights_ref( + int rows, + int columns, + const MatrixRef& tensor_weight, + const MatrixRef& tensor_weight_prepacked) { + ORT_ENFORCE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns, + "Unexpected tensor_weight shape! Expected: (", rows / 2, ", ", columns, "), Got: (", + tensor_weight.shape()[0], ", ", tensor_weight.shape()[1], ")."); + ORT_ENFORCE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2, + "tensor_weight_prepacked shape is not compatible with prepacked weight shape"); + + auto t0_base = make_Position(0, 0); + auto t1_base = make_Position(4, 0); + auto t2_base = make_Position(0, 8); + auto t3_base = make_Position(4, 8); + for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { + for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { + // Packing from a 8x16 tile to a 16x8 tile + auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); + auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); + for (int col = 0; col < 8; ++col) { + for (int row = 0; row < 4; ++row) { + auto cord = make_Position(row, col); + auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 + uint8_t buf[4]; + buf[0] = tensor_weight.at(dtile_base + t0_base + cord); + buf[1] = tensor_weight.at(dtile_base + t1_base + cord); + buf[2] = tensor_weight.at(dtile_base + t2_base + cord); + buf[3] = tensor_weight.at(dtile_base + t3_base + cord); + + // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights + // are in different b16 register at the same positions. This makes it easier to convert to + // fp16x2 format in a b32 register + + tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); + tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); + tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); + } + } + } + } +} + +template < + typename ScaleElementT, + typename Layout, + typename QuantBlocking> +inline void sm80_prepack_quant_scales_ref( + int rows, + int columns, + const MatrixRef& tensor_scale, + const MatrixRef& tensor_scale_prepacked) { + ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn), + "Unexpected tensor_scale shape! 
Expected: (", + rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")"); + ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape()); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values."); + } + + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + + for (int col = 0; col < tensor_scale.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); + tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); + tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); + tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); + } + } + } +} + +template +inline void sm80_prepack_quant_offsets_ref( + int rows, + int columns, + MatrixRef tensor_offset, + MatrixRef tensor_offset_prepacked) { + const auto meta_shape = make_Position(rows / QuantBlocking::kRow, columns / QuantBlocking::kColumn); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + ORT_ENFORCE(tensor_offset_prepacked.shape() == meta_shape, + "Unexpected tensor_offset_prepacked shape (", + tensor_offset_prepacked.shape()[0], ",", tensor_offset_prepacked.shape()[1], + ")! Expected: (", meta_shape[0], ", ", meta_shape[1], ")"); + ORT_ENFORCE(tensor_offset.shape() == zp_shape, + "Unexpected tensor_offset shape (", + tensor_offset.shape()[0], ",", tensor_offset.shape()[1], + ")! 
Expected: (", zp_shape[0], ", ", zp_shape[1], ")"); + + // Only prepacking scale and offset tensors for a often used special case: + // 16b gemm (2 elements per 32b register, operand tile shape 8x8) + // 2 B operand tiles per mma instruction stacked on k dimension + // (1,n) quantization blocking + if constexpr (QuantBlocking::kRow != 1) { + ORT_THROW("sm80_prepack_quant_offsets_ref should only be called for row-wise block quantization."); + } + // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread + // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use + // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, + // as shown below (T stands for thread): + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // T0, T4, T8, T12 + // T1, T5, T9, T13 + // T2, T6, T10, T14 + // T3, T7, T11, T15 + // + // We need to deliver quantization scale and offset elements to the corresponding threads, + // so we can perform dequantization efficiently. With a column major layout, each thread + // needs two separate loads for a mma instruction, due to the tile fragment layout shown + // above. To reduce the number of loads, we rearrange each column as below, so we can use + // a single load to load fragments for two tiles: + // T0 T0 + // T1 T0 + // T2 T1 + // T3 => T1 + // T0 T2 + // T1 T2 + // T2 T3 + // T3 T3 + if (tensor_offset_prepacked.good()) { + for (int col = 0; col < tensor_offset_prepacked.shape()[1]; ++col) { + for (int row_blk = 0; row_blk < tensor_offset_prepacked.shape()[0]; row_blk += 16) { + for (int thread_id = 0; thread_id < 4; thread_id++) { + const int dst_idx = row_blk + thread_id * 4; + const int src_idx = row_blk + thread_id * 2; + // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own + // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to + // convert to fp16x2 format in a b32 register + uint8_t pair01 = tensor_offset.at(src_idx / 2, col); + uint8_t pair89 = tensor_offset.at((src_idx + 8) / 2, col); + tensor_offset_prepacked.at(dst_idx + 0, col) = pair01 & 0xf; + tensor_offset_prepacked.at(dst_idx + 1, col) = pair89 & 0xf; + tensor_offset_prepacked.at(dst_idx + 2, col) = pair01 >> 4; + tensor_offset_prepacked.at(dst_idx + 3, col) = pair89 >> 4; + } + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h new file mode 100644 index 0000000000000..bbe370675fc48 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h @@ -0,0 +1,188 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80.h + * + * Abstract: + * Bridge between gtest code and gemm kernel implementation. + * Gemm kernel requires CUTLASS header files, which causes strange + * compilation errors with RE2 header files, which are required + * by gtest. 
+ */ + +#pragma once + +#include + +#include "core/util/matrix_layout.h" +#include "core/common/common.h" +#include "core/mickey/blk_q4/f16_prepack_sm80.h" +#include "test/cuda_host/blkq4_fp16_quant_sm80.h" + +namespace onnxruntime { +namespace cuda { +namespace test { + +Status sm80_supported(); + +/** + * @brief Generate a set of quantized weights, scales and offsets + * and dequantized weights for testing quantization and + * dequantization. All outputs are in column major layout. + * + * @tparam ElementT The type of the dequantized weights. + * @tparam block_size The block size of the quantization. + * @tparam col_blocking Whether to use column blocking (all elements of + * a block come from a single column) or row blocking + * @tparam has_offsets Whether to generate offsets. + * + * @param[in] rows The number of rows of the weight matrix. + * @param[in] columns The number of columns of the weight matrix. + * @param[out] dequants The dequantized weights, column major layout. + * @param[out] q_weights The quantized weights, column major layout. + * @param[out] q_scales The scales, column major layout. + * @param[out] q_zp The zero points, column major layout. + */ +template +inline void blkq4_weights_gen( + int rows, int columns, + std::vector& dequants, + std::vector& q_weights, + std::vector& q_scales, + std::vector& q_zp) { + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + + static_assert(std::is_same::value); + static_assert(std::is_same::value); + static_assert(std::is_same::value); + + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution dis(0, 8192); + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + // + // For testing quantization and dequantization, it is not straightforward + // to avoid flaky tests due to rounding errors. The way we + // try to achieve this is to: + // 1. Generate a set of quantized weights, scales and offsets + // 2. Dequantize the weights + // 3. Quantize the dequantized weights + // 4. Compare the dequantized-and-then-quantized weights with + // the original quantized weights + // + // Random filling of the initial values is key to get this right. + // For weights, we must ensure each block gets a full range of + // values, i.e. must contain 0 and 15. And for scales, they must + // all be positive. 
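+ //
+ // Concretely, the generator loop further down dequantizes each 4-bit weight w as
+ //   dequant(row, col) = scale(row / QuantBlocking::kRow, col / QuantBlocking::kColumn) * (w - offset)
+ // where offset comes from the packed zero points when has_offsets is true and falls back to 8 otherwise.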
+ // + + q_weights.resize(q_weight_shape.product()); + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + int v = 7; + for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { + for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { + uint8_t v0 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + uint8_t v1 = 0; + if (r + 1 < rows) { + v1 = static_cast(v); + v = (v + 5) % 16; + if (v == 11 || v == 7 || v == 3) { + // making the cycle 13 instead of 16, avoiding same values in a row + v = (v + 5) % 16; + } + } + + tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); + } + } + + q_scales.resize(meta_shape.product()); + for (size_t i = 0; i < q_scales.size(); i++) { + uint32_t v = dis(gen); + uint32_t m = (v % 63) + 1; + uint32_t e = (v >> 6) % 4; + q_scales[i] = ElementT(m / static_cast(1 << (2 + e))); + } + MatrixRef tensor_scale( + q_scales, meta_shape); + + MatrixRef tensor_offset; + if constexpr (has_offsets) { + q_zp.resize(zp_shape.product()); + tensor_offset = MatrixRef( + q_zp, zp_shape); + for (int c = 0; c < zp_shape[1]; c++) { + for (int r = 0; r < zp_shape[0]; ++r) { + uint8_t v0 = dis(gen) % 16; + uint8_t v1 = 8; + if (r * 2 + 1 < meta_shape[0]) { + v1 = dis(gen) % 16; + } + tensor_offset.at(r, c) = static_cast(v0 | (v1 << 4)); + } + } + } + + dequants.resize(rows * columns); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + + // Dequantize weights and save into matrix B + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + auto weight_cord = make_Position(row / 2, col); + auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); + uint8_t offset = 8; + if constexpr (has_offsets) { + if (scale_cord[0] % 2 == 0) { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) & 0x0f; + } else { + offset = tensor_offset.at(scale_cord[0] / 2, scale_cord[1]) >> 4; + } + } + int w = 0; + if (row % 2 == 0) { + w = int(tensor_q_weight.at(weight_cord) & 0x0f); + } else { + w = int(tensor_q_weight.at(weight_cord) >> 4); + } + float scale = float(tensor_scale.at(scale_cord)); + float dequant = scale * float(w - offset); + tensor_dequant.at(row, col) = ElementT(dequant); + // Prints for help debugging in case of test failure + // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); + } + } +} + +template < + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc new file mode 100644 index 0000000000000..e687ae73e66f2 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_test.cc @@ -0,0 +1,330 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_test.cc + * + * Abstract: + * Test code for block-wise quantized 4b GEMM kernels. + * This part requires gtest header files, which do not play + * well with CUTLASS headers. 
+ */ + +#include + +#include "core/framework/float16.h" +#include "core/mlas/inc/mlas_q4.h" + +#include "blkq4_fp16_gemm_sm80.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +void testPrepack(int rows, int columns) { + using ElementT = MLFloat16; + constexpr int block_size = 32; + using Base = onnxruntime::cuda::BlockwiseQuantization< + ElementT, + block_size, + 4, + col_blocking>; + + using QuantBlocking = typename Base::QuantBlocking; + using ElementW = typename Base::ElementW; + using LayoutWPack = typename Base::LayoutWPack; + using ElementQOffset = typename Base::ElementQOffset; + using LayoutQmeta = typename Base::LayoutQmeta; + + const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); + const auto meta_shape = Base::get_quant_meta_shape(rows, columns); + const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]); + + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + rows, columns, dequants, q_weights, q_scales, q_zp); + + // for quantization tool, the input is row major, all outputs are column major + MatrixRef tensor_q_weight( + q_weights, make_Position(rows / 2, columns)); + MatrixRef tensor_scale( + q_scales, meta_shape); + MatrixRef tensor_offset; + if constexpr (has_offset) { + tensor_offset = MatrixRef(q_zp, zp_shape); + } + + // for quantization tool, the input is row major, test weight gen output is column major + std::vector dequants_transposed(dequants.size()); + MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); + MatrixRef tensor_dequant_transposed(dequants_transposed, make_Position(rows, columns)); + for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { + for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { + tensor_dequant_transposed.at(row, col) = tensor_dequant.at(row, col); + } + } + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape( + block_size, col_blocking, rows, columns, q_rows, q_cols); + // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes + EXPECT_EQ(q_rows, q_weight_shape[0]); + EXPECT_EQ(q_cols, q_weight_shape[1]); + + // + // Quantization tool outputs: + // + std::vector o_elements(q_rows * q_cols); + MatrixRef tensor_o_elements(o_elements, q_weight_shape); + + std::vector o_scales(meta_shape.product()); + MatrixRef tensor_o_scales(o_scales, meta_shape); + + std::vector o_zp(zp_shape.product()); + MatrixRef tensor_o_zp(o_zp, zp_shape); + + MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, + dequants_transposed.data(), block_size, + col_blocking, rows, columns, columns, nullptr); + for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { + for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { + EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) + << "quantized value mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + if (has_offset) { + uint8_t pair01 = tensor_o_zp.at(row / 2, col); + uint8_t expected_pair01 = tensor_offset.at(row / 2, col); + EXPECT_EQ(expected_pair01 & 0xf, pair01 & 0xf) + << "quantized offset mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(expected_pair01 >> 4, pair01 >> 4) + << "quantized offset mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) + << "quantized scale mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + if (row + 1 < meta_shape[0]) { + EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) + << "quantized scale mismatch at [" << row + 1 << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } + + // + // Now we just setup quantized weights tensor_q_weight, quantization scale tensor_scale + // and quantization offset tensor_offset. The above tests just make sure our setup is + // consistent with quantization tool output. + // + // Next we test the prepack code + // + + std::vector packed_w_ref(q_weight_shape.product()); + MatrixRef tensor_packed_w_ref( + packed_w_ref, make_Position(rows, columns / 2)); + onnxruntime::test::sm80_prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); + + std::vector packed_w(q_weight_shape.product()); + MatrixRef tensor_packed_w( + packed_w, make_Position(rows, columns / 2)); + Base::prepack_weights(rows, columns, o_elements, packed_w); + + for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) + << "prepacked weights mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + std::vector packed_scales_ref(meta_shape.product()); + MatrixRef tensor_packed_s_ref = + make_MatrixRef(packed_scales_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_scales_ref( + rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); + } else { + for (int col = 0; col < tensor_packed_s_ref.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s_ref.shape()[0]; ++row) { + tensor_packed_s_ref.at(row, col) = tensor_scale.at(row, col); + } + } + } + + std::vector packed_scales(meta_shape.product()); + MatrixRef tensor_packed_s( + packed_scales, meta_shape); + Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); + + for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) + << "prepacked scales mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? "Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + + if (has_offset) { + std::vector packed_zp_ref(meta_shape.product()); + MatrixRef tensor_packed_zp_ref = + make_MatrixRef(packed_zp_ref, meta_shape); + if constexpr (Base::ShouldRearrangeMeta) { + onnxruntime::test::sm80_prepack_quant_offsets_ref( + rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); + } else { + for (int col = 0; col < meta_shape[1]; ++col) { + for (int row = 0; row < meta_shape[0]; row += 2) { + uint8_t pair01 = tensor_offset.at(row / 2, col); + tensor_packed_zp_ref.at(row, col) = pair01 & 0xf; + if (row + 1 < meta_shape[0]) { + tensor_packed_zp_ref.at(row + 1, col) = pair01 >> 4; + } + } + } + } + + std::vector packed_zp(meta_shape.product()); + MatrixRef tensor_packed_zp( + packed_zp, meta_shape); + Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); + + for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { + for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { + EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) + << "prepacked offsets mismatch at [" << row << "," << col << "]" + << " shape[" << rows << "," << columns << "]" + << (col_blocking ? 
"Column-wise-block" : "Row-wise-block") + << std::endl; + } + } + } +} + +// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 +TEST(BlkQ4_GEMM, PrepackSm80Test) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 32); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); + testPrepack(32, 64); + testPrepack(32, 128); + testPrepack(32, 256); + testPrepack(64, 32); + testPrepack(128, 32); + testPrepack(256, 32); + testPrepack(256, 256); + testPrepack(32, 128); + testPrepack(128, 32); + testPrepack(256, 256); +} + +TEST(BlkQ4_GEMM, Sm80RowBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 32, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 32, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 64); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 64); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(32, 96, 192); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(32, 96, 192); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + onnxruntime::cuda::test::run_blkq4_gemm<32, false, false, false>(512, 2048 + 32, 960); + + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, false>(256, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, false, true>(256, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, false, true>(256, 1024, 576); +} + +TEST(BlkQ4_GEMM, Sm80ColBlockingTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, false>(64, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, false, true>(64, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, false>(256, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, false, true>(256, 1024, 576); +} + +TEST(BlkQ4_GEMM, Sm80SmallMTest) { + Status status = onnxruntime::cuda::test::sm80_supported(); + if (!status.IsOK()) { + // skip the test if sm80 is not supported + return; + } + + // // small m + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, false>(16, 704, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, false, true, true>(16, 704, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, false>(16, 1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, false, true, true>(16, 1024, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, false>(16, 672, 576); + onnxruntime::cuda::test::run_blkq4_gemm<16, true, true, true>(16, 672, 576); + + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, false>(16, 
1024, 576); + onnxruntime::cuda::test::run_blkq4_gemm<64, true, true, true>(16, 1024, 576); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu new file mode 100644 index 0000000000000..69c929d446ce4 --- /dev/null +++ b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80_testcu.cu @@ -0,0 +1,344 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. + * + * Module Name: + * blkq4_fp16_gemm_sm80_testcu.cu + * + * Abstract: + * Test code for invoking block-wise quantized 4b GEMM kernels. + * This part requires CUTLASS header files, which do not play + * well with gtest headers. + */ + +#include +#include +#include + +#include "core/mickey/blk_q4/f16_gemm_sm80.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "core/common/common.h" + +#include "blkq4_fp16_gemm_sm80.h" + +namespace onnxruntime { +namespace cuda{ +namespace test{ + +Status sm80_supported(){ + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::ostringstream ss; + ss << "Unable to obtain GPU device properties: " << cudaGetErrorString(error); + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + + if (!((props.major * 10 + props.minor) >= 80)) { + std::ostringstream ss; + ss << "Device compute capability mismatch, desired 8.0, actual " << props.major << "." << props.minor; + return Status(common::ONNXRUNTIME, common::ENGINE_ERROR, ss.str()); + } + return Status::OK(); +} + +/** + * @brief Reference implementation of GEMM + * Copied directly from cutlass util/reference/device/gemm.h + * for the strange reason that compiler insists on asking + * for explicit stream argument in kernel launch. +*/ +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType +> +void compute_gemm_ref( + cutlass::gemm::GemmCoord problem_size, + ScalarType alpha, + cutlass::TensorRef tensor_a, + cutlass::TensorRef tensor_b, + ScalarType beta, + cutlass::TensorRef tensor_c, + cutlass::TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. 
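+ //
+ // Worked example with a problem size exercised by the tests above (m = 256, n = 672): each
+ // 16 x 8 thread block covers a (16 * 4) x (8 * 4) = 64 x 32 tile of the output via the 4 x 4
+ // OutputTile below, so the grid computed below is ceil(256 / 64) x ceil(672 / 32) = 4 x 21 blocks.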
+ using OutputTile = cutlass::MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + cutlass::reference::device::kernel::Gemm< + cutlass::TensorRef, + cutlass::TensorRef, + cutlass::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + cutlass::multiply_add, + cutlass::NumericConverter + ><<>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Converting cutlass tensor to MatrixRef +// + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_MatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + auto* ptr = const_cast::type *>(tensor.host_data()); + return MatrixRef(ptr, tensor.capacity(), shape); +} + +template < + typename Element, + typename LayoutCutlass, + typename Layout = std::conditional_t::value, ColumnMajorLayout, RowMajorLayout> + > +__forceinline__ +MatrixRef make_ConstMatrixRef(cutlass::HostTensor const& tensor) { + static_assert(std::is_same::value + || std::is_same::value); + auto shape = make_Position(tensor.extent().row(), tensor.extent().column()); + return MatrixRef(tensor.host_data(), tensor.capacity(), shape); +} + +// +// Invoking the kernel +// + +template< + int block_size, + bool column_wise_blocking, + bool small_m, + bool has_offsets> +void run_blkq4_gemm(int m, int n, int k) { + unsigned int seed = 28571; // Replace with desired seed value + std::seed_seq seq{seed}; + std::mt19937 gen(seq); + std::uniform_int_distribution<> dis(0, 8192); + + using ElementDequant = cutlass::half_t; + using QuantBlocking = + typename std::conditional, + cutlass::MatrixShape<1, block_size>>::type; + + using GemmRunner = BlkQ4F16GemmImpl; + + using ElementAccumulator = typename GemmRunner::ElementAccumulator; + using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue; + using ElementInputA = typename GemmRunner::ElementInputA; + using ElementOutput = typename GemmRunner::ElementOutput; + using ElementW = typename GemmRunner::ElementW; + using ElementWPack = typename GemmRunner::ElementWPack; + using ElementQScale = typename GemmRunner::ElementQScale; + using ElementQOffset = typename GemmRunner::ElementQOffset; + + using LayoutInputA = typename GemmRunner::LayoutInputA; + using LayoutOutput = typename GemmRunner::LayoutOutput; + using LayoutInputWPack = typename GemmRunner::LayoutInputWPack; + using LayoutInputQScale = typename GemmRunner::LayoutInputQScale; + + const cutlass::gemm::GemmCoord problem_size = {m, n, k}; + const auto q_weight_shape = cutlass::make_Coord(problem_size.k()/2, problem_size.n()); + const auto meta_shape = cutlass::make_Coord(problem_size.k()/QuantBlocking::kRow, problem_size.n()/QuantBlocking::kColumn); + + // + // Generate quantized and dequantizeed input matrix B [K, N] + // + static_assert(std::is_same::value); + std::vector q_weights; + std::vector q_scales; + std::vector q_zp; + std::vector dequants; + onnxruntime::cuda::test::blkq4_weights_gen( + problem_size.k(), problem_size.n(), dequants, 
q_weights, q_scales, q_zp); + + using PrepackT = onnxruntime::cuda::BlockwiseQuantization< + ElementDequant, + block_size, + 4, + column_wise_blocking>; + + std::vector packed_w(q_weight_shape.product()); + PrepackT::prepack_weights(problem_size.k(), problem_size.n(), q_weights, packed_w); + std::vector packed_scales(meta_shape.product()); + PrepackT::prepack_quant_scales(problem_size.k(), problem_size.n(), q_scales, packed_scales); + std::vector packed_zp; + if constexpr (has_offsets) { + packed_zp.resize(meta_shape.product()); + PrepackT::prepack_quant_offsets(problem_size.k(), problem_size.n(), q_zp, packed_zp); + } + + cutlass::HostTensor tensor_a( + problem_size.mk()); // <- Create matrix A with dimensions M x K + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(4), + ElementInputA(-4), + 2); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(4), + ElementOutput(-4), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + + // + // Copy data from host to GPU... + // + thrust::device_vector d_packed_w(packed_w); + cutlass::TensorRef ref_W( + reinterpret_cast(d_packed_w.data().get()), + LayoutInputWPack::packed({problem_size.k()/2, problem_size.n()/2})); + + thrust::device_vector d_packed_scales(packed_scales); + cutlass::TensorRef ref_scales( + d_packed_scales.data().get(), LayoutInputQScale::packed(meta_shape)); + + thrust::device_vector d_packed_zp(packed_zp); + cutlass::TensorRef ref_zp( + d_packed_zp.data().get(), LayoutInputQScale::packed(meta_shape)); + + tensor_a.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + + // run GEMM + cutlass::Status status; + if constexpr (has_offsets){ + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + ref_scales, ref_zp, + tensor_c.device_ref(), tensor_d.device_ref()); + } else { + status = GemmRunner::run( + nullptr, problem_size, tensor_a.device_ref(), ref_W, + ref_scales, + tensor_c.device_ref(), tensor_d.device_ref()); + } + ORT_ENFORCE(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status)); + + // Running reference kernel + using ElementInputB = ElementInputA; + using LayoutInputB = cutlass::layout::ColumnMajor; + thrust::device_vector d_dequants(dequants); + cutlass::TensorRef ref_B( + d_dequants.data().get(), LayoutInputB::packed(problem_size.kn())); + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + compute_gemm_ref( + problem_size, + alpha, + tensor_a.device_ref(), + ref_B, + beta, + tensor_c.device_ref(), + 
tensor_ref_d.device_ref()); + + // Wait for kernels to finish + cudaDeviceSynchronize(); + + // Copy output data from CUTLASS and reference kernel to host for comparison + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + ORT_ENFORCE(passed, "Gemm kernel result wrong!"); +} + +template void run_blkq4_gemm<16, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, false, false>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, true, true, false>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<16, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<32, false, true, false>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, true>(int m, int n, int k); +template void run_blkq4_gemm<64, false, true, false>(int m, int n, int k); + +} // namespace test +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc b/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc deleted file mode 100644 index aba2b0b2cb4a4..0000000000000 --- a/onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_sm80_prepack_test.cc +++ /dev/null @@ -1,507 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include - -#include "core/framework/float16.h" -#include "core/mickey/blk_q4/prepack_sm80.h" -#include "core/mlas/inc/mlas_q4.h" - -#include "gtest/gtest.h" - -namespace onnxruntime { -namespace test { - -void prepack_weights_ref( - int rows, - int columns, - const MatrixRef& tensor_weight, - const MatrixRef& tensor_weight_prepacked) { - EXPECT_TRUE(tensor_weight.shape()[0] == rows / 2 && tensor_weight.shape()[1] == columns); - EXPECT_TRUE(tensor_weight_prepacked.shape()[0] == rows && tensor_weight_prepacked.shape()[1] == columns / 2); - - auto t0_base = make_Position(0, 0); - auto t1_base = make_Position(4, 0); - auto t2_base = make_Position(0, 8); - auto t3_base = make_Position(4, 8); - for (int col_dtile = 0; col_dtile < columns / 16; ++col_dtile) { - for (int row_dtile = 0; row_dtile < rows / 16; ++row_dtile) { - // Packing from a 8x16 tile to a 16x8 tile - auto dtile_base = make_Position(row_dtile * 8, col_dtile * 16); - auto packed_tile_base = make_Position(row_dtile * 16, col_dtile * 8); - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto cord = make_Position(row, col); - auto packed_cord = packed_tile_base + make_Position(row * 4, col); // packed tile is 16x8 - uint8_t buf[4]; - buf[0] = tensor_weight.at(dtile_base + t0_base + cord); - buf[1] = tensor_weight.at(dtile_base + t1_base + cord); - buf[2] = tensor_weight.at(dtile_base + t2_base + cord); - buf[3] = tensor_weight.at(dtile_base + t3_base + cord); - - // [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] so that each pair of adjacent weights - // are in different b16 register at the same positions. This makes it easier to convert to - // fp16x2 format in a b32 register - - tensor_weight_prepacked.at(packed_cord) = (buf[0] & 0x0f) | ((buf[1] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(1, 0)) = (buf[2] & 0x0f) | ((buf[3] & 0x0f) << 4); - tensor_weight_prepacked.at(packed_cord + make_Position(2, 0)) = ((buf[0] & 0xf0) >> 4) | (buf[1] & 0xf0); - tensor_weight_prepacked.at(packed_cord + make_Position(3, 0)) = ((buf[2] & 0xf0) >> 4) | (buf[3] & 0xf0); - } - } - } - } -} - -template < - typename ScaleElementT, - typename Layout, - typename QuantBlocking> -void prepack_quant_scales_ref( - int rows, - int columns, - const MatrixRef& tensor_scale, - const MatrixRef& tensor_scale_prepacked) { - EXPECT_TRUE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_scale_prepacked.shape() == tensor_scale.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (sizeof(ScaleElementT) == 2 && QuantBlocking::kRow == 1) { - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. 
With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - - for (int col = 0; col < tensor_scale.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_scale.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - tensor_scale_prepacked.at(dst_idx + 0, col) = tensor_scale.at(src_idx + 0, col); - tensor_scale_prepacked.at(dst_idx + 1, col) = tensor_scale.at(src_idx + 1, col); - tensor_scale_prepacked.at(dst_idx + 2, col) = tensor_scale.at(src_idx + 8, col); - tensor_scale_prepacked.at(dst_idx + 3, col) = tensor_scale.at(src_idx + 9, col); - } - } - } - } else { - // In all other cases, we don't prepack scale or offset - FAIL() << "Scale prepack only supported for 16b gemm with (1,n) quantization blocking"; - } -} - -template -void prepack_quant_offsets_ref( - size_t rows, - size_t columns, - MatrixRef tensor_offset, - MatrixRef tensor_offset_prepacked) { - // EXPECT_TRUE(tensor_offset.shape()[0] == (rows / QuantBlocking::kRow) && tensor_offset.shape()[1] == (columns / QuantBlocking::kColumn)); - EXPECT_TRUE(tensor_offset_prepacked.shape() == tensor_offset.shape()); - - // Only prepacking scale and offset tensors for a often used special case: - // 16b gemm (2 elements per 32b register, operand tile shape 8x8) - // 2 B operand tiles per mma instruction stacked on k dimension - // (1,n) quantization blocking - if constexpr (QuantBlocking::kRow != 1) { - FAIL() << "Offsets prepack only supported for 16b gemm with (1,n) quantization blocking"; - } - // In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread - // holds a fragment of the tile containing 2 elements in the k dimension. Most often we use - // mma instruction shape of 16x8x16, which means 2 B tiles are stacked in the k dimension, - // as shown below (T stands for thread): - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // T0, T4, T8, T12 - // T1, T5, T9, T13 - // T2, T6, T10, T14 - // T3, T7, T11, T15 - // - // We need to deliver quantization scale and offset elements to the corresponding threads, - // so we can perform dequantization efficiently. With a column major layout, each thread - // needs two separate loads for a mma instruction, due to the tile fragment layout shown - // above. 
To reduce the number of loads, we rearrange each column as below, so we can use - // a single load to load fragments for two tiles: - // T0 T0 - // T1 T0 - // T2 T1 - // T3 => T1 - // T0 T2 - // T1 T2 - // T2 T3 - // T3 T3 - if (tensor_offset_prepacked.good()) { - for (int col = 0; col < tensor_offset.shape()[1]; ++col) { - for (int row_blk = 0; row_blk < tensor_offset.shape()[0]; row_blk += 16) { - for (int thread_id = 0; thread_id < 4; thread_id++) { - const int dst_idx = row_blk + thread_id * 4; - const int src_idx = row_blk + thread_id * 2; - // [a, b, c, d] => [a, c, b, d] so that adjacent weights are in their own - // 16b element: [a, x, b, x] and [x, c, x, d], which makes it easier to - // convert to fp16x2 format in a b32 register - tensor_offset_prepacked.at(dst_idx + 0, col) = tensor_offset.at(src_idx + 0, col); - tensor_offset_prepacked.at(dst_idx + 1, col) = tensor_offset.at(src_idx + 8, col); - tensor_offset_prepacked.at(dst_idx + 2, col) = tensor_offset.at(src_idx + 1, col); - tensor_offset_prepacked.at(dst_idx + 3, col) = tensor_offset.at(src_idx + 9, col); - } - } - } - } -} - -template -void testPrepack(int rows, int columns, bool has_offset = true) { - using ElementT = MLFloat16; - constexpr int block_size = 32; - using Base = onnxruntime::cuda::BlockwiseQuantization< - ElementT, - block_size, - 4, - ColumnMajorQuantBlocking>; - - using QuantBlocking = typename Base::QuantBlocking; - using ElementW = typename Base::ElementW; - using LayoutWPack = typename Base::LayoutWPack; - using ElementQOffset = typename Base::ElementQOffset; - using LayoutQmeta = typename Base::LayoutQmeta; - - unsigned int seed = 28571; // Replace with desired seed value - std::seed_seq seq{seed}; - std::mt19937 gen(seq); - std::uniform_int_distribution<> dis(0, 8192); - - const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns); - const auto meta_shape = Base::get_quant_meta_shape(rows, columns); - - // - // For testing quantization and dequantization, it is not straight - // forward to avoid flaky tests due to rounding errors. The way we - // try to achieve this is to: - // 1. Generate a set of quantized weights, scales and offsets - // 2. Dequantize the weights - // 3. Quantize the dequantized weights - // 4. Compare the dequantied-and-then-quantized weights with - // the original quantized weights - // - // Random filling of the initial values are key to get this right. - // For weights, we must ensure each block gets a full range of - // values, i.e. must contain 0 and 15. And for scales, they must - // all be positive. 
- // - - std::vector q_weights(q_weight_shape.product()); - MatrixRef tensor_q_weight( - q_weights, make_Position(rows / 2, columns)); - int v = 7; - for (int c = 0; c < tensor_q_weight.shape()[1]; c++) { - for (int r = 0; r < tensor_q_weight.shape()[0]; ++r) { - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - tensor_q_weight.at(r, c) = ElementW((v1 << 4) | v0); - } - } - - std::vector q_scales(meta_shape.product()); - for (size_t i = 0; i < q_scales.size(); i++) { - q_scales[i] = ElementT(((dis(gen) % 127) + 1) / 32.0f); - } - MatrixRef tensor_scale( - q_scales, meta_shape); - - std::vector q_zp(meta_shape.product()); - for (size_t i = 0; i < q_zp.size(); i++) { - q_zp[i] = dis(gen) % 16; - } - MatrixRef tensor_offset( - q_zp, meta_shape); - -#if 0 // debug - // Fill tensor_q_weight with the patterned data, easier to debug with print - int loop_val = 0; - int offset = 3; - for (int col_tile = 0; col_tile < tensor_q_weight.extent().column()/8; ++col_tile) { - for (int row_tile = 0; row_tile < tensor_q_weight.extent().row()/4; ++row_tile) { - for (int col = 0; col < 8; ++col) { - for (int row = 0; row < 4; ++row) { - auto weight_cord = cutlass::make_Coord(row_tile * 4 + row, col_tile * 8 + col); - auto val = (loop_val + offset) % 256; - tensor_q_weight.at(weight_cord) = ElementW(val); - loop_val++; - if (loop_val == 256) { - loop_val = 0; - offset += 11; - } - } - } - } - } - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - int c = col * QuantBlocking::kColumn; - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - int r = row * QuantBlocking::kRow; - auto weight_cord = cutlass::make_Coord(r/2, c); - int w = 0; - if (r % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - tensor_scale.at({row, col}) = w; - tensor_offset.at({row, col}) = ElementQOffset(w); - } - } - - int fill_val = -512; - int factor = 1; - for (int col = 0; col < tensor_scale.extent().column(); ++col){ - for (int row = 0; row < tensor_scale.extent().row(); ++row){ - tensor_scale.at({row, col}) = ElementQScale((float)fill_val * float(factor)); - fill_val++; - if (fill_val == 512) { - fill_val = -512; - factor += 1; - } - } - } - -#endif // debug - - std::vector dequants(rows * columns); - MatrixRef tensor_dequant(dequants, make_Position(rows, columns)); - - // Dequantize weights and save into matrix B for reference - for (int col = 0; col < tensor_dequant.shape()[1]; ++col) { - for (int row = 0; row < tensor_dequant.shape()[0]; ++row) { - auto weight_cord = make_Position(row / 2, col); - auto scale_cord = make_Position(row / QuantBlocking::kRow, col / QuantBlocking::kColumn); - const uint8_t offset = has_offset ? 
tensor_offset.at(scale_cord) : 8; - int w = 0; - if (row % 2 == 0) { - w = int(tensor_q_weight.at(weight_cord) & 0x0f); - } else { - w = int(tensor_q_weight.at(weight_cord) >> 4); - } - float scale = float(tensor_scale.at(scale_cord)); - float dequant = scale * float(w - offset); - tensor_dequant.at(row, col) = ElementT(dequant); - // Prints for help debugging in case of test failure - // fprintf(stderr, "(%2d,%2d)= %2d, %2d, %f, %f\n", row, col, w, offset, scale, dequant); - } - } - - int q_rows, q_cols; - MlasBlockwiseQuantizedShape( - block_size, ColumnMajorQuantBlocking, rows, columns, q_rows, q_cols); - // to be exact, q_rows are padded to multiple of block_size, deal with it when we care about strange shapes - EXPECT_EQ(q_rows, q_weight_shape[0]); - EXPECT_EQ(q_cols, q_weight_shape[1]); - - // - // Quantization tool outputs: - // - std::vector o_elements(q_rows * q_cols); - MatrixRef tensor_o_elements(o_elements, q_weight_shape); - - std::vector o_scales(meta_shape.product()); - MatrixRef tensor_o_scales(o_scales, meta_shape); - - std::vector o_zp(((meta_shape[0] + 1) / 2) * meta_shape[1], true); - MatrixRef tensor_o_zp( - o_zp, make_Position((meta_shape[0] + 1) / 2, meta_shape[1])); - - MlasQuantizeBlockwise(o_elements.data(), o_scales.data(), has_offset ? o_zp.data() : nullptr, - tensor_dequant.data().data(), block_size, - ColumnMajorQuantBlocking, rows, columns, columns, nullptr); - for (int col = 0; col < tensor_q_weight.shape()[1]; ++col) { - for (int row = 0; row < tensor_q_weight.shape()[0]; ++row) { - EXPECT_EQ(tensor_o_elements.at(row, col), tensor_q_weight.at(row, col)) - << "quantized value mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - for (int col = 0; col < meta_shape[1]; ++col) { - for (int row = 0; row < meta_shape[0]; row += 2) { - if (has_offset) { - uint8_t pair01 = tensor_o_zp.at(row / 2, col); - EXPECT_EQ(tensor_offset.at(row + 0, col), pair01 & 0xf) - << "quantized offset mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_offset.at(row + 1, col), pair01 >> 4) - << "quantized offset mismatch at [" << row + 1 << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - EXPECT_EQ(tensor_scale.at(row + 0, col), tensor_o_scales.at(row + 0, col)) - << "quantized scale mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - if (row + 1 < meta_shape[0]) { - EXPECT_EQ(tensor_scale.at(row + 1, col), tensor_o_scales.at(row + 1, col)) - << "quantized scale mismatch at [" << row + 1 << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - } - - // - // Now we just setup fp16 weights tensor_dequant, quantized weights tensor_q_weight, - // quantization scale tensor_scale and quantization offset tensor_offset. The above - // testing just make sure our test setup is consistent with quantization tool output. 
- // - // Next we test the prepack code - // - - std::vector packed_w_ref(q_weight_shape.product()); - MatrixRef tensor_packed_w_ref( - packed_w_ref, make_Position(rows, columns / 2)); - prepack_weights_ref(rows, columns, tensor_q_weight, tensor_packed_w_ref); - - std::vector packed_w(q_weight_shape.product()); - MatrixRef tensor_packed_w( - packed_w, make_Position(rows, columns / 2)); - Base::prepack_weights(rows, columns, o_elements, packed_w); - - for (int col = 0; col < tensor_packed_w.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_w.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_w_ref.at(row, col), tensor_packed_w.at(row, col)) - << "prepacked weights mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - std::vector packed_scales_ref(meta_shape.product()); - MatrixRef tensor_packed_s_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_scales_ref, meta_shape) - : tensor_scale; - if (Base::ShouldRearrangeMeta) { - prepack_quant_scales_ref( - rows, columns, tensor_scale.const_ref(), tensor_packed_s_ref); - } - - std::vector packed_scales(meta_shape.product()); - MatrixRef tensor_packed_s( - packed_scales, meta_shape); - Base::prepack_quant_scales(rows, columns, o_scales, packed_scales); - - for (int col = 0; col < tensor_packed_s.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_s.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_s_ref.at(row, col), tensor_packed_s.at(row, col)) - << "prepacked scales mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? "Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - - if (has_offset) { - std::vector packed_zp_ref(meta_shape.product()); - MatrixRef tensor_packed_zp_ref = - Base::ShouldRearrangeMeta ? make_MatrixRef(packed_zp_ref, meta_shape) - : tensor_offset; - if (Base::ShouldRearrangeMeta) { - prepack_quant_offsets_ref( - rows, columns, tensor_offset.const_ref(), tensor_packed_zp_ref); - } - - std::vector packed_zp(meta_shape.product()); - MatrixRef tensor_packed_zp( - packed_zp, meta_shape); - Base::prepack_quant_offsets(rows, columns, o_zp, packed_zp); - - for (int col = 0; col < tensor_packed_zp.shape()[1]; ++col) { - for (int row = 0; row < tensor_packed_zp.shape()[0]; ++row) { - EXPECT_EQ(tensor_packed_zp_ref.at(row, col), tensor_packed_zp.at(row, col)) - << "prepacked offsets mismatch at [" << row << "," << col << "]" - << " shape[" << rows << "," << columns << "]" - << (ColumnMajorQuantBlocking ? 
"Column-wise-block" : "Row-wise-block") - << std::endl; - } - } - } -} - -// TODO: code runs on CPU, but this is for sm80 only, maybe enable only when test on sm80 -TEST(BlkQ4_GEMM, PrepackSm80Test) { - testPrepack(32, 32); - testPrepack(32, 32, false); - testPrepack(32, 32); - testPrepack(32, 32, false); - testPrepack(32, 64); - testPrepack(32, 128); - testPrepack(32, 256); - testPrepack(64, 32); - testPrepack(128, 32); - testPrepack(256, 32); - testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); - testPrepack(32, 64); - testPrepack(32, 128); - testPrepack(32, 256); - testPrepack(64, 32); - testPrepack(128, 32); - testPrepack(256, 32); - testPrepack(256, 256); - testPrepack(32, 128, false); - testPrepack(128, 32, false); - testPrepack(256, 256, false); -} - -} // namespace test -} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc index 5505d689381c9..8dfaaedcbb378 100644 --- a/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/cuda_execution_provider_test.cc @@ -29,7 +29,7 @@ TEST(TestDeferredRelease, WithArena) { AllocatorPtr cpu_pinned_alloc = ep.CreatePreferredAllocators()[1]; // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cpu_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; @@ -71,7 +71,7 @@ TEST(TestDeferredRelease, WithoutArena) { // For details, see CUDAPinnedAllocator in cuda_allocator.cc. // let the CudaStream instance "own" the default stream, so we can avoid the // work to initialize cublas/cudnn/... It is ok since it is just a customized unit test. - CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr); + CudaStream stream(nullptr, gpu_alloctor->Info().device, cuda_pinned_alloc, false, true, nullptr, nullptr, info); // 10 MB const size_t n_bytes = 10 * 1000000; const int64_t n_allocs = 64; From 1e78bcea6011ac43093bb08a647cf3717d73047a Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Tue, 5 Mar 2024 13:33:01 -0800 Subject: [PATCH 196/207] Implement CUDA IsInf-10,20 (#19772) ### Description Implment IsInf-10,20 for CUDA. Add FP16 types also on CPU. ### Motivation and Context Certain models lag in performance due to IsInf not available on CUDA. 
--- docs/OperatorKernels.md | 4 +- .../core/framework/data_types_internal.h | 2 +- .../core/providers/cpu/tensor/isinf.cc | 64 ++++++++++--- .../core/providers/cuda/cu_inc/common.cuh | 94 +++++++++++++++++++ onnxruntime/core/providers/cuda/cuda_common.h | 18 ++++ .../providers/cuda/cuda_execution_provider.cc | 5 + .../cuda/math/unary_elementwise_ops.cc | 38 ++++++++ .../cuda/math/unary_elementwise_ops.h | 12 +++ .../cuda/math/unary_elementwise_ops_impl.cu | 38 ++++++++ .../cuda/math/unary_elementwise_ops_impl.h | 15 +++ .../core/providers/rocm/cu_inc/common.cuh | 94 +++++++++++++++++++ .../providers/rocm/rocm_execution_provider.cc | 9 ++ .../test/providers/cpu/tensor/isinf_test.cc | 42 +++++++++ 13 files changed, 420 insertions(+), 15 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 71b0def659741..4514a85531d6b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -160,7 +160,7 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| |IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| |||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| @@ -631,6 +631,8 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| |||[1, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |LSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|14+|**T** = tensor(double), tensor(float), tensor(float16)
**T1** = tensor(int32)| diff --git a/include/onnxruntime/core/framework/data_types_internal.h b/include/onnxruntime/core/framework/data_types_internal.h index fbeee8a2aedc5..3a3b5cb6888f2 100644 --- a/include/onnxruntime/core/framework/data_types_internal.h +++ b/include/onnxruntime/core/framework/data_types_internal.h @@ -305,7 +305,7 @@ class CallableDispatchableHelper { return 0; } - void CheckCalledOnce() { + void CheckCalledOnce() const { ORT_ENFORCE(called_ == 1, "Unsupported data type: ", dt_type_); } }; diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index 1b449f46927a2..9d18d1fa62288 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -23,7 +23,9 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( using IsInfTypesOpset20 = TypeList< float, - double + double, + MLFloat16, + BFloat16 #if !defined(DISABLE_FLOAT8_TYPES) , Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ @@ -76,10 +78,8 @@ ONNX_CPU_OPERATOR_KERNEL( IsInf); IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { - Status status = info.GetAttr("detect_positive", &detect_positive_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); - status = info.GetAttr("detect_negative", &detect_negative_); - ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + detect_positive_ = info.GetAttrOrDefault("detect_positive", 1); + detect_negative_ = info.GetAttrOrDefault("detect_negative", 1); opset_ = info.node().SinceVersion(); } @@ -87,29 +87,67 @@ namespace isinf_internal { template struct ComputeDispatchTarget { void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { - const auto total_items = X.Shape().Size(); + auto input_data = X.DataAsSpan(); auto output_data = Y.MutableData(); if (detect_positive && detect_negative) { EigenMap(Y) = EigenMap(X).array().isInf(); } else if (detect_positive) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == std::numeric_limits::infinity()); }); } else if (detect_negative) { - auto input_data = X.Data(); - auto end_data = input_data + total_items; std::transform( - input_data, end_data, output_data, [](T v) { + input_data.begin(), input_data.end(), output_data, [](T v) { return (v == -std::numeric_limits::infinity()); }); } else { // all false - memset(output_data, false, onnxruntime::narrow(total_items)); + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](MLFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool 
detect_negative) const { + auto output_data = Y.MutableData(); + auto input_data = X.DataAsSpan(); + if (detect_positive && detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsInfinity(); }); + } else if (detect_positive) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsPositiveInfinity(); }); + } else if (detect_negative) { + std::transform(input_data.begin(), input_data.end(), output_data, + [](BFloat16 v) { return v.IsNegativeInfinity(); }); + } else { + // all false + memset(output_data, false, input_data.size()); } } }; diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 66794f88d8670..bba9178348132 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -438,6 +438,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return 
isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, CUDA seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef CUDA_LONG diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 41c999bacee13..61da125b40953 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -70,6 +70,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E4M3FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + template <> class ToCudaType { public: @@ -79,6 +88,15 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef Float8E5M2FNUZ MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + #endif inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 8ba282031a5d4..3c0930638a205 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -830,6 +830,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress); @@ -1342,6 +1343,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, S class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, float, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, double, Gelu); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -1739,6 +1741,8 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2250,6 +2254,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #endif }; diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index fd8b69d7bd2f5..00de1b37f3302 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -71,6 +71,44 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa return Status::OK(); \ } +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + IsInf, + 
kOnnxDomain, + 10, + 19, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_OPERATOR_KERNEL_EX( + IsInf, + kOnnxDomain, + 20, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +IsInf::IsInf(const OpKernelInfo& info) : UnaryElementwise(info) { + detect_positive_ = static_cast(info.GetAttrOrDefault("detect_positive", 1)); + detect_negative_ = static_cast(info.GetAttrOrDefault("detect_negative", 1)); + opset_ = info.node().SinceVersion(); +} + +Status IsInf::ComputeInternal(OpKernelContext* context) const { + UnaryElementwisePreparation p; + ORT_RETURN_IF_ERROR(UnaryElementwise::Prepare(context, &p)); + + Explicit_Impl_IsInf(Stream(context), opset_, detect_positive_, detect_negative_, + p.input_tensor->GetElementType(), p.input_tensor->DataRaw(), + p.output_tensor->MutableData(), + p.input_tensor->Shape().Size()); + return Status::OK(); +} + #define UNARY_OP_VERSIONED_TYPED(name, startver, endver, T) \ UNARY_ELEMENTWISE_REGISTER_VERSIONED_KERNEL(name, startver, endver, T) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h index 775b78c43a736..3b7d6df7221b7 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once + #include "core/providers/cuda/cuda_kernel.h" namespace onnxruntime { @@ -119,5 +120,16 @@ class Sign final : public UnaryElementwise { Status ComputeInternal(OpKernelContext* context) const override; }; +class IsInf final : public UnaryElementwise { + public: + explicit IsInf(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + private: + bool detect_positive_{true}; + bool detect_negative_{true}; + int opset_; +}; + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 73c5ac80756be..fd8f7929d4426 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -11,6 +11,7 @@ #endif namespace onnxruntime { + namespace cuda { #define OP(name, expr) \ @@ -284,5 +285,42 @@ EXPLICIT_IMPL_CASTSAT(__nv_bfloat16, Float8E5M2) #endif +namespace isinf_details { +template +struct IsInf_DispFunc { + void operator()(cudaStream_t stream, const void* input_raw, bool* output_data, + bool detect_positive, bool detect_negative, size_t count) const { + using CudaType = typename ToCudaType::MappedType; + const auto* input_data = reinterpret_cast(input_raw); + if (detect_positive && detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_positive) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else if (detect_negative) { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } else { + UnaryElementWiseImpl(stream, input_data, output_data, _IsInf{}, count); + } + } +}; + +} // namespace isinf_details + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t 
input_data_type, + const void* input_raw, bool* output_data, + size_t count) { + if (op_set < 20) { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } else { + utils::MLTypeCallDispatcher dispatcher{input_data_type}; + dispatcher.Invoke(stream, input_raw, output_data, + detect_positive, detect_negative, count); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h index 608a81a24cf4f..a606d479bc79b 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.h @@ -137,5 +137,20 @@ void Impl_CastSat( #endif +// IsInf + +#if !defined(DISABLE_FLOAT8_TYPES) +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16, Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, \ + Float8E5M2FNUZ +#else +#define ISINF_OPSET20_ALL_FLOATS float, double, MLFloat16, BFloat16 +#endif + +void Explicit_Impl_IsInf(cudaStream_t stream, int op_set, + bool detect_positive, bool detect_negative, + int32_t input_data_type, + const void* input_raw, bool* output_data, + size_t count); } // namespace cuda + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 5f966ac746fcb..f3685606c17f5 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -335,6 +335,100 @@ __device__ __inline__ BFloat16 _Fmod(BFloat16 a, BFloat16 b) { return fmodf((float)a, (float)b); } +namespace isinf_details { +template +struct IsInfTyped { + static __device__ __inline__ bool IsInf(T a) { + // cast is needed because on non MS compilers, + // because there isinf() returns int + // and we want to avoid stupid warnings + return static_cast(isinf(a)); + } + static __device__ __inline__ bool IsInfPos(T a) { + return a == std::numeric_limits::infinity(); + } + static __device__ __inline__ bool IsInfNeg(T a) { + return a == -std::numeric_limits::infinity(); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(half a) { + return MLFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~MLFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(half a) { + return MLFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(half a) { + return MLFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == + static_cast(*reinterpret_cast(&a) & ~BFloat16::kSignMask); + } + static __device__ __inline__ bool IsInfPos(BFloat16 a) { + return BFloat16::kPositiveInfinityBits == *reinterpret_cast(&a); + } + static __device__ __inline__ bool IsInfNeg(BFloat16 a) { + return BFloat16::kNegativeInfinityBits == *reinterpret_cast(&a); + } +}; + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +struct ReturnFalse { + constexpr static bool __device__ __inline__ IsInf(T) { return false; } + constexpr static bool __device__ __inline__ IsInfPos(T) { return false; } + constexpr static bool __device__ __inline__ IsInfNeg(T) { return false; } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +template <> +struct IsInfTyped : 
ReturnFalse {}; + +template <> +struct IsInfTyped { + static __device__ __inline__ bool IsInf(Float8E5M2 a) { + return a.val == 0b01111100 || a.val == 0b11111100; + } + static __device__ __inline__ bool IsInfPos(Float8E5M2 a) { + return a.val == 0b01111100; + } + static __device__ __inline__ bool IsInfNeg(Float8E5M2 a) { + return a.val == 0b11111100; + } +}; + +template <> +struct IsInfTyped : ReturnFalse {}; + +#endif +} // namespace isinf_details + +template +struct _IsInf { + __device__ __inline__ bool operator()(T a) const { + if constexpr (detect_positive && detect_negative) { + return isinf_details::IsInfTyped::IsInf(a); + } else if constexpr (detect_positive) { + return isinf_details::IsInfTyped::IsInfPos(a); + } else if constexpr (detect_negative) { + return isinf_details::IsInfTyped::IsInfNeg(a); + } else { + return false; + } + } +}; + // We would like to use 64-bit integer to support large matrices. However, ROCM seems to support only 32-bit integer // For now, use int32_t to ensure that both Linux and Windows see this as 32 bit integer type. #ifndef HIP_LONG diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 0265c06b9a938..4a679b790ee40 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -793,6 +793,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, MLFloat16, ThresholdedRelu); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 10, TopK); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 12, Mod); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10, 19, IsInf); // opset 11 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 11, 11, float, ArgMax); @@ -1342,6 +1343,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, R class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 20, IsInf); + template <> KernelCreateInfo BuildKernelCreateInfo() { return {}; @@ -1738,6 +1742,8 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, // opset 11 BuildKernelCreateInfo, @@ -2294,6 +2300,9 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // opset 20 + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index 2e583c5d2547b..bd97306142f18 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -99,6 +99,48 @@ TEST(IsInfTest, test_isinf_negative_double20) { run_is_inf_test(20, 0, 1, input, output); } +TEST(IsInfTest, test_isinf_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + 
std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_mlfloat16) { + std::initializer_list input = {MLFloat16{-1.7f}, MLFloat16::NaN, MLFloat16::Infinity, 3.6_fp16, + MLFloat16::NegativeInfinity, MLFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + +TEST(IsInfTest, test_isinf_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} + +TEST(IsInfTest, test_isinf_positive_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_bfloat16) { + std::initializer_list input = {BFloat16{-1.7f}, BFloat16::NaN, BFloat16::Infinity, 3.6_bfp16, + BFloat16::NegativeInfinity, BFloat16::Infinity}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} + #if !defined(DISABLE_FLOAT8_TYPES) TEST(IsInfTest, test_Float8E4M3FN) { std::initializer_list input = { From d9730c7f43437070eba28d8dcdd9f94c102265ab Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Tue, 5 Mar 2024 14:39:36 -0800 Subject: [PATCH 197/207] [TensorRT EP] Fix bug for DDS output handling for empty tensor (#19575) When the DDS output is empty tensor (i.e. any of the dimension is 0), TRT EP won't perform either cudaMemcpyAsync() nor cuda::Impl_Cast(), to prevent accidentally overwriting other location that might belong to other tensors. This PR also refactors the code to only allocate single bytes for all empty tensors. 
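As a rough sketch of the guard pattern this change applies (illustrative only; `BindOrDummy` and the simplified scratch-buffer type are hypothetical, not part of the EP code):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// TensorRT still requires a non-null, distinct address for an empty-tensor binding,
// so bind the real buffer only when there is data; otherwise allocate one dummy byte.
// Any copy back to the ORT output is likewise skipped when the element count is 0.
void* BindOrDummy(void* tensor_data, size_t elem_cnt,
                  std::vector<std::vector<uint8_t>>& scratch_buffers) {
  if (tensor_data != nullptr && elem_cnt > 0) {
    return tensor_data;                    // use the tensor's own allocation
  }
  scratch_buffers.emplace_back(1);         // single dummy byte per empty binding
  return scratch_buffers.back().data();    // unique, non-null address
}
```

The actual change folds this logic into the new CASE_GET_*/CASE_COPY_*/CASE_CAST_* macros so that every supported tensor type shares the same empty-tensor handling.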
#TODO: add unit tests to cover the DDS code paths or doing more testing with concurrent,sequential, threaded faster-rcnn using onnx_test_runner and verifying outputs --------- Co-authored-by: Chi Lo --- cmake/deps.txt | 4 +- .../tensorrt/tensorrt_execution_provider.cc | 465 ++++++------------ 2 files changed, 160 insertions(+), 309 deletions(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index 9cba25b00157d..9630b6185fcf6 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -37,8 +37,8 @@ mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 -#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) -onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 +#use the commit of Final DDS removal. DDS output is now supported by ORT TRT. +onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/bacfaaa951653cd4e72efe727a543567cb38f7de.zip;26434329612e804164ab7baa6ae629ada56c1b26 protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa protoc_win64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win64.zip;b4521f7ada5b260380f94c4bd7f1b7684c76969a protoc_win32;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-win32.zip;3688010318192c46ce73213cdfb6b3e5656da874 diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 81346671f2aad..157cd0a200b35 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -717,6 +717,77 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + data = const_cast(input_tensor_ptr); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_CAST_INPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto input_tensor_ptr = input_tensor.GetTensorData(); \ + if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + data = scratch_buffers.back().get(); \ + cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + data = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + buffers[output_name] = output_tensor_ptr; \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + } \ + break; \ + } + +#define 
CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = static_cast(elem_cnt); \ + } else { \ + scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ + buffers[output_name] = scratch_buffers.back().get(); \ + output_dim_sizes[i] = 1; \ + } \ + break; \ + } + +#define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(DstT), cudaMemcpyDeviceToDevice, stream)); \ + } \ + break; \ + } + +#define CASE_CAST_TENSOR(DATA_TYPE, SrcT, DstT) \ + case DATA_TYPE: { \ + auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ + if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ + cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ + } \ + break; \ + } + /* * Set TensorRT execution context input. * @@ -737,6 +808,17 @@ Status BindContextInput(Ort::KernelContext& ctx, auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo(); const auto tensor_shapes = tensor_info.GetShape(); const auto tensor_type = tensor_info.GetElementType(); + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ + const auto elem_cnt = tensor_info.GetElementCount(); if (trt_engine->isShapeInferenceIO(input_name)) { // Get the shape value of "shape tensor" @@ -765,113 +847,24 @@ Status BindContextInput(Ort::KernelContext& ctx, ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to call nvinfer1::IExecutionContext::setInputShape() for input '" + error_input_name + "'")); } - // Bind "execution tensor" input buffers + + // Bind "execution tensor" input buffer + // + // Note: If an engine binding is an empty tensor, it still needs a non-null memory address, and different tensors should have different addresses. + // Therefore, in the case of empty tensor, TRT EP always allocates a dummy byte. + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#empty-tensors void* data = nullptr; switch (tensor_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - data = scratch_buffers.back().get(); - } else { - data = const_cast(input_tensor_ptr); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - data = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * 
sizeof(int32_t))); - data = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Cast DOUBLE input to FLOAT because TensorRT doesn't fully support INT64 - auto input_tensor_ptr = input_tensor.GetTensorData(); - if (input_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - data = scratch_buffers.back().get(); - } else { - SafeInt input_dim_size = 1; - for (int j = 0, end = nb_dims; j < end; ++j) { - if (tensor_shapes[j] == 0) { - input_dim_size = 1; - break; - } else { - input_dim_size *= tensor_shapes[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, input_dim_size * sizeof(float))); - data = scratch_buffers.back().get(); - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), input_dim_size); - } - break; - } + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // Cast int64 input to int32 input because TensorRT doesn't support int64 + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + // Cast double input to float because TensorRT doesn't support double + CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -884,7 +877,7 @@ Status BindContextInput(Ort::KernelContext& ctx, } /* - * Set TensorRT execution context output. + * Bind TensorRT execution context output. * * Please note that the "data-depedent shape" output needs corresponding allocator provided. * @@ -912,7 +905,6 @@ Status BindContextOutput(Ort::KernelContext& ctx, size_t i, std::unordered_map& output_tensors, std::unordered_map& output_dim_sizes, - std::unordered_set& dds_output_set, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, @@ -920,142 +912,47 @@ Status BindContextOutput(Ort::KernelContext& ctx, // Get output shape nvinfer1::Dims dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; - bool is_dds_output = false; + bool is_DDS = false; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { // data-dependent shape if (dims.d[j] == -1) { - is_dds_output = true; - dds_output_set.emplace(output_name); + is_DDS = true; break; } output_shapes[j] = dims.d[j]; } + auto known_DDS = dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end(); + // If the output tensor has data-dependent shape, TRT EP will provide an IOutputAllocator for enqueueV3 to dynamically allocate memory buffer. // Once enqueueV3 returns, TRT EP will then bind the output allocation to ORT kernel context output. 
// (Please note that we take strategy A mentioned in https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#dynamic-shaped-output, // which we defer allocation until the size is known and don't call IExecution::setTensorAddress) // // Otherwise, if the shape of the output tensor is known prior to the runtime, ORT will pre-allocate memory buffer for the output tensor for enqueueV3. - if (is_dds_output) { - if (dds_output_allocator_map.find(output_name) == dds_output_allocator_map.end()) { + if (is_DDS || known_DDS) { + if (!known_DDS) { auto allocatorPtr = std::make_unique(); trt_context->setOutputAllocator(output_name, allocatorPtr.get()); dds_output_allocator_map[output_name] = std::move(allocatorPtr); - } else { - trt_context->setOutputAllocator(output_name, dds_output_allocator_map[output_name].get()); } } else { output_tensors[i] = ctx.GetOutput(output_index, output_shapes); auto& output_tensor = output_tensors[i]; + const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint16_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(bool))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int8_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(uint8_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - } else { - buffers[output_name] = output_tensor_ptr; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // Allocate INT32 CUDA memory for INT64 output type because TensorRT doesn't fully support INT64 - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - 
scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(1); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dims.d[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(int32_t))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // Allocate FLOAT CUDA memory for DOUBLE output type because TensorRT doesn't fully support DOUBLE - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr == nullptr) { - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = 1; - } else { - SafeInt output_dim_size(1); - for (int j = 0, end = nb_dims; j < end; ++j) { - if (dims.d[j] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= dims.d[j]; - } - } - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, output_dim_size * sizeof(float))); - buffers[output_name] = scratch_buffers.back().get(); - output_dim_sizes[i] = output_dim_size; - } - break; - } + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // Allocate int32 CUDA memory for int64 output type because TensorRT doesn't support int64 + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t, int32_t) + // Allocate float CUDA memory for double output type because TensorRT doesn't support double + CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -1068,10 +965,13 @@ Status BindContextOutput(Ort::KernelContext& ctx, } /* - * Set ORT kernel context Output. + * Bind ORT kernel context Output. * - * Note: In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime. + * In the case of DDS (data-dependent shape) output, TRT requires a provided allocator to allocate memory during runtime. * Once the output has been put in the allocation buffer, ORT calls this function to bind the allocation to ORT kernel context output. + * + * Note: Current approach of setting the ORT kernel context output is copying the output data from allocation buffer to ORT context output address which is not optimal, + * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. 
*/ Status BindKernelOutput(Ort::KernelContext& ctx, OrtMemoryInfo* mem_info, @@ -1083,93 +983,46 @@ Status BindKernelOutput(Ort::KernelContext& ctx, auto allocator = allocator_map[output_name].get(); auto& shape = allocator->getOutputShape(); auto output_tensor = ctx.GetOutput(output_index, shape); + + /* + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ */ auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + /* + * Copy output data from allocation buffer to ORT kernel context output location or + * cast (int32 or float) -> (int64 or double) to ORT kernel context output location. + * + * Note: + * 1. If the output tensor is empty tensor (i.e. any of the dimension is 0) which means element count is 0, + * TRT EP does not perform cuda memory copy nor cuda cast to prevent overwriting other location that might belong to other tensors. + * 2. The cudaMemcpyAsync() and cuda::Impl_Cast() (implemented as _UnaryElementWise() in cuda ep) are all async, but we + * don't need to explicitly call cudaStreamSynchronize() after those APIs due to CUDA EP and TRT EP uses same stream, + * and within the same stream, operations are guaranteed to be executed in order. + */ switch (output_type) { - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(float), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint16_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(bool), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int8_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(uint8_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(output_tensor_ptr, allocator->getBuffer(), elem_cnt * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream)); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - // The allocation buffer holds the INT32 output data since TRT doesn't support INT64 but INT32. - // So, we need to cast the data from INT32 to INT64 and then set INT64 output data to kernel context. - SafeInt output_dim_size(1); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= shape[i]; - } - } - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), output_dim_size); - } - break; - } - case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { - // The allocation buffer holds the FLOAT output data since TRT doesn't support DOUBLE but FLOAT. 
- // So, we need to cast the data from FLOAT to DOUBEL and then set DOUBLE output data to kernel context. - SafeInt output_dim_size(1); - for (size_t i = 0; i < shape.size(); ++i) { - if (shape[i] == 0) { - output_dim_size = 1; - break; - } else { - output_dim_size *= shape[i]; - } - } - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), output_dim_size); - } - break; - } + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, uint16_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, bool) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, int8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) + CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) + // The allocation buffer holds the int32 output data since TRT doesn't support int64. So, we need to cast the data (int32 -> int64) for ORT kernel output. + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int32_t, int64_t) + // The allocation buffer holds the float output data since TRT doesn't support double. So, we need to cast the data (float -> double) for ORT kernel output. + CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); return Status::OK(); } @@ -3513,7 +3366,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView output_tensors.reserve(num_outputs); std::unordered_map output_dim_sizes; output_dim_sizes.reserve(num_outputs); - std::unordered_set dds_output_set; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3531,7 +3383,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView } Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers); + dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -3590,7 +3442,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView output_type = iter->second; } - if (dds_output_set.find(output_name) != dds_output_set.end()) { + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { size_t output_index = 0; const auto& index_iter = output_indexes.find(output_name); if (index_iter != output_indexes.end()) { @@ -3806,7 +3658,6 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_tensors.reserve(num_outputs); std::unordered_map output_dim_sizes; output_dim_sizes.reserve(num_outputs); - std::unordered_set dds_output_set; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3824,7 +3675,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con } Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers); + 
dds_output_allocator_map, scratch_buffers, alloc, buffers); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } @@ -3883,7 +3734,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con output_type = iter->second; } - if (dds_output_set.find(output_name) != dds_output_set.end()) { + if (dds_output_allocator_map.find(output_name) != dds_output_allocator_map.end()) { size_t output_index = 0; const auto& index_iter = output_indexes.find(output_name); if (index_iter != output_indexes.end()) { From d10256975527e8e041cedb19227cb5f207087c42 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Mar 2024 10:06:25 +0800 Subject: [PATCH 198/207] Fix seed for recomputed Dropout (#19715) ### Fix seed for recomputed Dropout If Dropout node is recomputed in the backward, we should make sure its execution is same as the run in the forward. If we don't set seed attribute, then this cannot be guaranteed. Add ` export ORTMODULE_MEMORY_OPT_LEVEL=2` to enabled per layer recompute with compromised recomputable subgraphs. --- docs/Memory_Optimizer.md | 1 + docs/ORTModule_Training_Guidelines.md | 5 ++- onnxruntime/core/common/string_utils.h | 12 +++++++ .../memory_optimizer/memory_insight.cc | 6 +++- .../memory_optimizer/memory_optimizer.cc | 34 +++++++++++++++++-- .../memory_optimizer/memory_optimizer.h | 1 + .../ortmodule/_graph_execution_manager.py | 7 +++- .../training/ortmodule/_runtime_inspector.py | 34 ++++++++++++++----- .../python/training/ortmodule/options.py | 18 ++++++++-- 9 files changed, 101 insertions(+), 17 deletions(-) diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md index 97f7e7ff2c14b..eaa48c9da0609 100644 --- a/docs/Memory_Optimizer.md +++ b/docs/Memory_Optimizer.md @@ -51,6 +51,7 @@ There are two modes to enable the memory optimizations: - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1 ``` 3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case. +4. By `export ORTMODULE_MEMORY_OPT_LEVEL=2`, all plans including compromised recomptable subgraphs will also be enabled. ### Mode 2 - Advanced Usage (User Selected Subgraph Recompute) diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md index 84631bd1f6555..54137937ad56d 100644 --- a/docs/ORTModule_Training_Guidelines.md +++ b/docs/ORTModule_Training_Guidelines.md @@ -287,7 +287,10 @@ A classical usage of disabling the deep copy: when the deep copy before module e #### ORTMODULE_MEMORY_OPT_LEVEL - **Feature Area**: *ORTMODULE/Optimizations* -- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. Setting the level to be 0 means all detected subgraphs with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When level is not 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. +- **Description**: By default, the level is 0. This env var can be used for enabling recomputation for reducing memory peak requirement. + - Setting the level to be 1 means all detected recomputable subgraphs (NOT including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. 
+ - Setting the level to be 2 means all detected recomputable subgraphs (including compromised recomputable graphs) with each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. + - When the level is 0, check Check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details. ```bash export ORTMODULE_MEMORY_OPT_LEVEL=0 diff --git a/onnxruntime/core/common/string_utils.h b/onnxruntime/core/common/string_utils.h index eca1221e84cb8..03e94cefd0564 100644 --- a/onnxruntime/core/common/string_utils.h +++ b/onnxruntime/core/common/string_utils.h @@ -65,5 +65,17 @@ inline std::string TrimString(std::string s) { return s; } +/** + * So use this simple hash to generate unique int by given string input. + */ +inline uint32_t GetHashFromString(const std::string& str_value) { + uint32_t hash = 0; + for (char const& c : str_value) { + hash = hash * 101 + c; + } + + return hash; +} + } // namespace utils } // namespace onnxruntime diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 3fbdd5da7b768..08c402bf669c8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -9,6 +9,8 @@ #include #include +#include "core/common/string_utils.h" +#include "core/framework/random_seed.h" #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "orttraining/core/optimizer/memory_optimizer/common.h" @@ -284,7 +286,9 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, memory_opt_planner.AddNodeOptimizationPlan(p_node, std::move(recompute_plan)); } - if (can_compromise_stashed_activation) { + // Only detect compromise recompute when recompute is not found, in case there are multiple recompute plans + // for the same named activations, then user might enable those conflicting recompute plans by mistakes. + if (recompute_plan == nullptr && can_compromise_stashed_activation) { MO_LOG_DEBUG_INFO(logger, "Searching Node " + p_node->Name() + "(" + p_node->OpType() + ") for compromised recompute"); // If the subgraph recompute can save memory by comprising the assumption - recompute graphs' input must exist diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 49e026ca86bd3..525e3b4b8de35 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -28,6 +28,29 @@ constexpr bool IsForwardPassOperator(ptrdiff_t op_order_in_topological_sort, return op_order_in_topological_sort <= boundary_op_order_in_topological_sort; } +// Reset seed attribute for the dropout node if the seed is not set. +bool SetSeedForDropoutNode(Node& node) { + // ONNX Dropout 1, 6, 7, 10 do not have seed attribute, so we remove them from the recompute support. + // TODO(pengwa): add the opset check in GetAllowedRecomputeOps. 
+ if (graph_utils::IsSupportedOptypeVersionAndDomain(node, "Dropout", {12, 13}, kOnnxDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskDropout", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasDropout", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BitmaskBiasDropout", {1}, kMSDomain) || + graph_utils::IsSupportedOptypeVersionAndDomain(node, "BiasSoftmaxDropout", {1}, kMSDomain)) { + auto& attrs = node.GetAttributes(); + if (attrs.count("seed")) { + return false; + } + + int64_t seed = static_cast(utils::GetHashFromString(node.OutputDefs()[0]->Name())) + + utils::GetRandomSeed(); + node.AddAttribute("seed", seed); + return true; + } + + return false; +} + } // namespace Status MemoryOptimizer::ParseOptimizationConfigFromString(const std::string& memory_optimizer_config, @@ -74,7 +97,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, optimizer::memory_optimizer::NodeRecomputePlan* recompute_plan = dynamic_cast(node_plan.get()); ORT_ENFORCE(recompute_plan != nullptr); - ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), replacement_node_ptr).IsOK()); + ORT_ENFORCE(CreateRecomputeGraph(graph, recompute_plan->GetNodesInTopoOrder(), logger, replacement_node_ptr).IsOK()); } else { ORT_THROW("unsupported optimization type found."); } @@ -93,7 +116,7 @@ bool MemoryOptimizer::ModifyGraph(Graph& graph, auto tid = node_index_to_its_order_in_topological_sort_map.find(it->GetNode().Index()); // It is possible the consumer node is newly added as the recompute node, so we need a check here. - // For those kind of ops, we can treat them as backward ops. + // For those kinds of ops, we can treat them as backward ops. if (tid == node_index_to_its_order_in_topological_sort_map.end() || !IsForwardPassOperator(node_index_to_its_order_in_topological_sort_map.at(tid->first), boundary_op_order_in_topological_sort)) { @@ -223,6 +246,7 @@ void MemoryOptimizer::PrintSummary(const optimizer::memory_optimizer::MemoryOpti Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, const InlinedVector& nodes_in_topological_order, + const logging::Logger& logger, Node*& new_output_node_ptr) const { InlinedHashMap self_contained_outputs_map; for (size_t i = 0; i < nodes_in_topological_order.size(); ++i) { @@ -236,6 +260,12 @@ Status MemoryOptimizer::CreateRecomputeGraph(Graph& graph, continue; } + bool seed_reset = SetSeedForDropoutNode(*node_to_duplicate); + if (seed_reset) { + LOGS(logger, VERBOSE) << "Set seed for Node " << node_to_duplicate->Name() << "(" << node_to_duplicate->OpType() + << ")."; + } + InlinedVector new_input_args; new_input_args.reserve(node_to_duplicate->MutableInputDefs().size()); for (NodeArg* input_arg : node_to_duplicate->MutableInputDefs()) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h index b3e05fd334e48..1d837038e76c1 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.h @@ -94,6 +94,7 @@ class MemoryOptimizer : public GraphTransformer { */ Status CreateRecomputeGraph(Graph& graph, const InlinedVector& nodes_in_topological_order, + const logging::Logger& logger, Node*& recompute_subgraph_output_node) const; /************************************************** diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py 
b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index e189ffff9cc7f..c67b05758c5aa 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -754,6 +754,11 @@ def _add_record(tbl, columns): if self._runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER" + elif ( + self._runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE + ): + opt_config_to_display = "ALL_RECOMPUTE_FOR_EACH_LAYER_WITH_COMPROMISE" else: opt_config_to_display = self._runtime_options.memory_optimizer_config @@ -766,7 +771,7 @@ def _add_record(tbl, columns): f"Memory Optimization Level: [{_MemoryOptimizationLevel.to_string(self._runtime_options.memory_optimization_level)}], " f"Optimization Config: [{opt_config_to_display}]" if len(self._runtime_options.memory_optimizer_config) > 0 - else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=,,..." + else "Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1/2 or ORTMODULE_MEMORY_OPT_CONFIG=,,..." ), ], ) diff --git a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py index 772b9bd9e31ae..22e31466887a6 100644 --- a/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py +++ b/orttraining/orttraining/python/training/ortmodule/_runtime_inspector.py @@ -545,7 +545,10 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r # If the memory optimization level is aggressive, we will first collect all # recompute subgraph by passing empty memory_optimizer_config to get_serialized_ortmodule_memory_stat. - if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + if runtime_options.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: memory_optimizer_config = "" ( @@ -581,16 +584,27 @@ def find_memory_optimization_opportunity(self, execution_agent: TrainingAgent, r self.cluster_id_combination_to_saving_symbolics_map[cluster_id] = values # For aggressive memory optimization, we update the memory_optimizer_config using all. - if runtime_options.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + if runtime_options.memory_optimization_level > 0: recompute_configs = [] for cluster_id in self.cluster_id_combination_to_saving_symbolics_map: config_values = cluster_id.split(":") opt_type = int(config_values[1]) - # TODO(pengwa): use enum instead of 1 here. 
- if opt_type != 1: - continue - - recompute_configs.append(cluster_id) + if ( + runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE + and opt_type == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE + ): + recompute_configs.append(cluster_id) + elif ( + runtime_options.memory_optimization_level + == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE + and opt_type + in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ] + ): + recompute_configs.append(cluster_id) runtime_options.memory_optimizer_config = ",".join(recompute_configs) @@ -699,14 +713,16 @@ def _get_user_config_without_freq(configs: str): notes = [] if details: notes.append( - "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1 to enable all recomputable subgraphs per transformer layer." + "[Memory Optimizer] Use ORTMODULE_MEMORY_OPT_LEVEL=1/2 to enable all recomputable subgraphs per transformer layer." ) saving_recommendation = "[Memory Optimizer] Or use comma as a delimiter to selectively enable multiple memory optimization plans:\n" saving_recommendation += " export ORTMODULE_MEMORY_OPT_CONFIG=,,..." notes.append(saving_recommendation) - saving_recommendation = "memory saving is calculated based on the 1st batch symbolic dim values:\n" + saving_recommendation = ( + "[Memory Optimizer] memory saving is calculated based on the 1st batch symbolic dim values:\n" + ) for dim_param, dim_value in self.symbolic_dim_name_to_value_map.items(): saving_recommendation += f" {dim_param}={dim_value}," notes.append(saving_recommendation) diff --git a/orttraining/orttraining/python/training/ortmodule/options.py b/orttraining/orttraining/python/training/ortmodule/options.py index 93d24a34df6bd..7263a5719e262 100644 --- a/orttraining/orttraining/python/training/ortmodule/options.py +++ b/orttraining/orttraining/python/training/ortmodule/options.py @@ -196,7 +196,10 @@ class _MemoryOptimizationLevel(IntFlag): """Enumeration to specify memory optimization level""" USER_SPECIFIED = 0 # Fully respect user-specified config - TRANSFORMER_LAYERWISE_RECOMPUTE = 1 # Enable all recomputable subgraphs per layer + TRANSFORMER_LAYERWISE_RECOMPUTE = ( + 1 # Enable all recomputable subgraphs (excluding compromised recomptable graphs) per layer + ) + TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE = 2 # Enable all recomputable subgraphs per layer @staticmethod def to_string(memory_optimization_level): @@ -206,6 +209,9 @@ def to_string(memory_optimization_level): if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: return "TRANSFORMER_LAYERWISE_RECOMPUTE" + if memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE: + return "TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE" + return "" @@ -344,7 +350,10 @@ def _override_from_env_vars(self): self.memory_optimization_level = int(os.getenv("ORTMODULE_MEMORY_OPT_LEVEL", self.memory_optimization_level)) user_given_memory_optimizer_config = os.getenv("ORTMODULE_MEMORY_OPT_CONFIG", self.memory_optimizer_config) self.memory_optimizer_config = ",".join([c for c in user_given_memory_optimizer_config.split(",") if c]) - if self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + if self.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + 
_MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: # For transformer layer-wise recompute, we enable layer boundary when detecting subgraphs. # Then all detected subgraphs will not cross different layers. self.recompute_probe_config = "1:1" @@ -419,7 +428,10 @@ def memory_optimizer_is_enabled(self) -> bool: """Check whether memory optimizer is enabled.""" if self.memory_optimization_level == _MemoryOptimizationLevel.USER_SPECIFIED: return len(self.memory_optimizer_config) > 0 - elif self.memory_optimization_level == _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE: + elif self.memory_optimization_level in [ + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE, + _MemoryOptimizationLevel.TRANSFORMER_LAYERWISE_RECOMPUTE_WITH_COMPROMISE, + ]: return True return False From 1bfc26685b51522395e136a606005a72997e6bff Mon Sep 17 00:00:00 2001 From: Vincent Wang Date: Wed, 6 Mar 2024 10:11:46 +0800 Subject: [PATCH 199/207] ATen Op Supports Int Return Type and CPU Tensor Arguments (#19773) This PR: - add support for int as return type, will create a CPU scalar tensor for it. - add attributes to specify which arguments or returns are CPU tensors. - adjust ATen efficient attn to match latest PyTorch native function. - a Triton codegen bugfix by the way. --- .../cpu/aten_ops/aten_op_executor.h | 16 +- onnxruntime/core/framework/utils.cc | 24 ++- .../core/graph/contrib_ops/contrib_defs.cc | 2 + .../python/onnxruntime_pybind_state.cc | 10 +- .../aten_op_executor/__init__.py | 2 +- .../aten_op_executor/aten_op_executor.cc | 62 ++++--- .../ort_torch_ext/__init__.py | 4 +- .../python/training/ort_triton/_ir.py | 3 + .../ortmodule/graph_optimizers/__init__.py | 2 +- .../ortmodule/graph_optimizers/_aten_attn.py | 169 +++--------------- 10 files changed, 96 insertions(+), 198 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h index d72868cd8fa9f..56c8e2911e280 100644 --- a/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h +++ b/onnxruntime/contrib_ops/cpu/aten_ops/aten_op_executor.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace aten_ops { -typedef bool (*IsCpuArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); +typedef bool (*IsTensorArgumentFunc)(const char* op_name, const char* overload_name, size_t index, bool is_input); typedef void (*ExecuteATenOperatorFunc)(const char* op_name, const char* overload_name, size_t input_size, DLManagedTensor** dlpack_inputs, size_t output_size, DLManagedTensor** dlpack_outputs); @@ -22,17 +22,17 @@ class ATenOperatorExecutor { return instance; } - void Initialize(void* p_is_cpu_argument_func_raw, void* p_execute_aten_op_func_raw) { - ORT_ENFORCE(p_is_cpu_argument_func_raw && p_execute_aten_op_func_raw); - p_is_cpu_argument_func_ = reinterpret_cast(p_is_cpu_argument_func_raw); + void Initialize(void* p_is_tensor_argument_func_raw, void* p_execute_aten_op_func_raw) { + ORT_ENFORCE(p_is_tensor_argument_func_raw && p_execute_aten_op_func_raw); + p_is_tensor_argument_func_ = reinterpret_cast(p_is_tensor_argument_func_raw); p_execute_aten_op_func_ = reinterpret_cast(p_execute_aten_op_func_raw); } bool IsInitialized() { return p_execute_aten_op_func_ != nullptr; } - bool IsCpuArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { - ORT_ENFORCE(p_is_cpu_argument_func_, "ATenOperatorExecutor is not initialized."); - return 
p_is_cpu_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); + bool IsTensorArgument(const std::string& op_name, const std::string& overload_name, size_t index, bool is_input) { + ORT_ENFORCE(p_is_tensor_argument_func_, "ATenOperatorExecutor is not initialized."); + return p_is_tensor_argument_func_(op_name.c_str(), overload_name.c_str(), index, is_input); } void operator()(const std::string& op_name, const std::string& overload_name, size_t input_size, @@ -43,7 +43,7 @@ class ATenOperatorExecutor { } private: - IsCpuArgumentFunc p_is_cpu_argument_func_ = nullptr; + IsTensorArgumentFunc p_is_tensor_argument_func_ = nullptr; ExecuteATenOperatorFunc p_execute_aten_op_func_ = nullptr; }; diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 23fe5e1cd3d96..b737d735b977b 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -1015,9 +1015,19 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor inputs are on device, all non-tensor inputs are on CPU, + // except those specified in attribute cpu_input_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_input_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1025,7 +1035,7 @@ bool IsInputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index) overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, true); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, true); } #else ORT_UNUSED_PARAMETER(node); @@ -1040,9 +1050,19 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index } #ifdef ENABLE_ATEN + // For ATen node, we assume that all tensor outputs are on device, all non-tensor outputs are on CPU, + // except those specified in attribute cpu_output_args; if (node.GetExecutionProviderType() == kCudaExecutionProvider && node.OpType() == "ATen" && node.Domain() == kPytorchAtenDomain) { const auto& attrs = node.GetAttributes(); + if (auto entry = attrs.find("cpu_output_args"); entry != attrs.end()) { + const auto& attr = entry->second; + if (utils::HasInts(attr) && std::any_of(attr.ints().cbegin(), attr.ints().cend(), + [index](int64_t arg) { return static_cast(index) == arg; })) { + return true; + } + } + ORT_ENFORCE(utils::HasString(attrs.at("operator"))); std::string op_name = attrs.at("operator").s(); std::string overload_name = ""; @@ -1050,7 +1070,7 @@ bool IsOutputOnCpu(const Node& node, const KernelCreateInfo* p_kci, size_t index overload_name = attrs.at("overload_name").s(); } - return contrib::aten_ops::ATenOperatorExecutor::Instance().IsCpuArgument(op_name, overload_name, index, false); + return !contrib::aten_ops::ATenOperatorExecutor::Instance().IsTensorArgument(op_name, overload_name, index, false); } #else ORT_UNUSED_PARAMETER(node); diff --git 
a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index f06a3785f362d..6709398c788f0 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3474,6 +3474,8 @@ MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 /*min_arity*/ 1) .Attr("operator", "Name of ATen operator.", AttributeProto::STRING) .Attr("overload_name", "Overload name of ATen operator.", AttributeProto::STRING, false) + .Attr("cpu_input_args", "CPU input argument indices.", AttributeProto::INTS, false) + .Attr("cpu_output_args", "CPU output argument indices.", AttributeProto::INTS, false) .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9c36eb635ffcf..e5e0e81cb7da8 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1327,14 +1327,14 @@ void addGlobalMethods(py::module& m) { #ifdef ENABLE_ATEN m.def("register_aten_op_executor", - [](const std::string& is_cpu_argument_address_str, const std::string& aten_op_executor_address_str) -> void { - size_t is_cpu_argument_address_int, aten_op_executor_address_int; + [](const std::string& is_tensor_argument_address_str, const std::string& aten_op_executor_address_str) -> void { + size_t is_tensor_argument_address_int, aten_op_executor_address_int; ORT_THROW_IF_ERROR( - ParseStringWithClassicLocale(is_cpu_argument_address_str, is_cpu_argument_address_int)); + ParseStringWithClassicLocale(is_tensor_argument_address_str, is_tensor_argument_address_int)); ORT_THROW_IF_ERROR(ParseStringWithClassicLocale(aten_op_executor_address_str, aten_op_executor_address_int)); - void* p_is_cpu_argument = reinterpret_cast(is_cpu_argument_address_int); + void* p_is_tensor_argument = reinterpret_cast(is_tensor_argument_address_int); void* p_aten_op_executor = reinterpret_cast(aten_op_executor_address_int); - contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_cpu_argument, p_aten_op_executor); + contrib::aten_ops::ATenOperatorExecutor::Instance().Initialize(p_is_tensor_argument, p_aten_op_executor); }); #endif } diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py index 8bf7cbf80eb37..9dee6564509d5 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/__init__.py @@ -29,5 +29,5 @@ def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor _C.register_aten_op_executor( - str(aten_op_executor.is_cpu_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) ) diff --git a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc index 903a394a06ef3..e8be98cbfc0e4 100644 --- a/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc +++ b/onnxruntime/python/torch_cpp_extensions/aten_op_executor/aten_op_executor.cc @@ -34,18 +34,23 @@ struct ATenOperator { std::vector is_optional_arguments; std::vector> 
default_values; size_t return_size; + std::vector ret_kinds; c10::IValue ToIValueArgument(const DLManagedTensor* dlpack, size_t index) const { TORCH_INTERNAL_ASSERT(index < argument_size); bool is_optional = is_optional_arguments[index]; - TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index]); + TORCH_INTERNAL_ASSERT(dlpack || is_optional || default_values[index] || + elem_kinds[index] == c10::TypeKind::TensorType); if (!dlpack) { if (is_optional) { // Optional argument always has no default value. return c10::IValue(c10::nullopt); } - - return *default_values[index]; + if (default_values[index]) { + return *default_values[index]; + } + // Fow bw func, it's possible that input is an undefined tensor from fw outputs, dlpack is nullptr for such case. + return c10::IValue(at::Tensor()); } bool is_list = is_list_arguments[index]; @@ -142,7 +147,10 @@ class ATenOperatorCache { } aten_op.return_size = schema.returns().size(); for (const auto& ret : schema.returns()) { - TORCH_INTERNAL_ASSERT(ret.type()->kind() == c10::TypeKind::TensorType); + c10::TypeKind ret_type = ret.type()->kind(); + // Support tensor or int only for now. + TORCH_INTERNAL_ASSERT(ret_type == c10::TypeKind::TensorType || ret_type == c10::TypeKind::IntType); + aten_op.ret_kinds.emplace_back(ret_type); } ops_.emplace(key, aten_op); } @@ -154,32 +162,15 @@ class ATenOperatorCache { std::unordered_map, ATenOperator, PairHash> ops_; }; -const std::unordered_map> kCpuTensorInputsMap = { - {"_efficient_attention_forward", {4, 5, 11, 12}}, {"_efficient_attention_backward", {6, 7, 12, 13}}}; - -const std::unordered_map> kCpuTensorOutputsMap = { - {"_efficient_attention_forward", {2, 3}}}; - -// Backend uses this function to check if an argument is CPU input or not. -bool IsCpuArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { +// Backend uses this function to check if an argument is tensor type or not. +bool IsTensorArgument(const char* op_name, const char* overload_name, size_t index, bool is_input) { + const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); if (is_input) { - // If the argument is non-tensor type, it's CPU argument. - const auto& aten_op = ATenOperatorCache::Instance().GetOperator(op_name, overload_name); TORCH_INTERNAL_ASSERT(index < aten_op.argument_size); - if (aten_op.elem_kinds[index] != c10::TypeKind::TensorType) { - return true; - } - } - - std::string full_name = std::string(op_name); - std::string overload_name_str = std::string(overload_name); - if (overload_name_str != "") { - full_name += ("." + overload_name_str); + return aten_op.elem_kinds[index] == c10::TypeKind::TensorType; } - - const auto& cpu_tensors_map = is_input ? kCpuTensorInputsMap : kCpuTensorOutputsMap; - return cpu_tensors_map.find(full_name) != cpu_tensors_map.end() && - cpu_tensors_map.at(full_name).find(index) != cpu_tensors_map.at(full_name).end(); + TORCH_INTERNAL_ASSERT(index < aten_op.return_size); + return aten_op.ret_kinds[index] == c10::TypeKind::TensorType; } void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t input_size, @@ -216,16 +207,23 @@ void ExecuteATenOperator(const char* op_name, const char* overload_name, size_t TORCH_INTERNAL_ASSERT(output_size == aten_op.return_size); size_t output_index = 0; for (const auto& ret : torch::jit::pop(stack, output_size)) { - const auto& tensor = ret.toTensor(); - dlpack_outputs[output_index++] = - tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? 
tensor : tensor.contiguous()) : nullptr; + if (ret.isTensor()) { + const auto& tensor = ret.toTensor(); + dlpack_outputs[output_index++] = + tensor.defined() ? at::toDLPack(tensor.is_contiguous() ? tensor : tensor.contiguous()) : nullptr; + } else if (ret.isInt()) { + at::Tensor scalar = at::scalar_to_tensor(at::Scalar(ret.toInt())); + dlpack_outputs[output_index++] = at::toDLPack(scalar); + } else { + TORCH_INTERNAL_ASSERT(false); + } } } -size_t is_cpu_argument_address() { return reinterpret_cast(&IsCpuArgument); } +size_t is_tensor_argument_address() { return reinterpret_cast(&IsTensorArgument); } size_t execute_aten_operator_address() { return reinterpret_cast(&ExecuteATenOperator); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("is_cpu_argument_address", &is_cpu_argument_address, "Address of tensor argument check."); + m.def("is_tensor_argument_address", &is_tensor_argument_address, "Address of tensor argument check."); m.def("execute_aten_operator_address", &execute_aten_operator_address, "Address of Aten operator executor"); } diff --git a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py index 329fba5aa670a..7d5716b85db30 100644 --- a/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py +++ b/onnxruntime/python/torch_cpp_extensions/ort_torch_ext/__init__.py @@ -5,7 +5,7 @@ from onnxruntime.capi import _pybind_state as _C -from .aten_op_executor import execute_aten_operator_address, is_cpu_argument_address +from .aten_op_executor import execute_aten_operator_address, is_tensor_argument_address def run_once_aten_op_executor(f): @@ -30,7 +30,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): - _C.register_aten_op_executor(str(is_cpu_argument_address()), str(execute_aten_operator_address())) + _C.register_aten_op_executor(str(is_tensor_argument_address()), str(execute_aten_operator_address())) def init_aten_op_executor(): diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index a2b8407645c46..a963d30a9e6e7 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -392,5 +392,8 @@ def __init__( for ir_node in kernel.sub_nodes: if isinstance(ir_node, DropoutNode): ir_node.global_offset = running_offset + kernel.offset_calc.symbolic_shape_variables.update( + [symbol.name for symbol in running_offset.free_symbols] + ) running_offset = running_offset + sympy.prod(ir_node.outputs[0].shape) self.has_dropout = True diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py index 3d3538a62da61..368d1b238fd9e 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/__init__.py @@ -13,7 +13,7 @@ if ( "ORTMODULE_USE_EFFICIENT_ATTENTION" in os.environ and int(os.getenv("ORTMODULE_USE_EFFICIENT_ATTENTION")) == 1 - and Version(torch.__version__) >= Version("2.1.1") + and Version(torch.__version__) >= Version("2.3.0") ): from ._aten_attn import optimize_graph_for_aten_efficient_attention # noqa: F401 diff --git a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py 
b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py index b1e8809f03fc0..c1fb6e68568f5 100644 --- a/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py +++ b/orttraining/orttraining/python/training/ortmodule/graph_optimizers/_aten_attn.py @@ -5,9 +5,12 @@ """ PyTorch's _efficient_attention_forward/_efficient_attention_backward APIs is keep changing. Current implementation -is tested well on version 2.2.0.dev20231010+cu121, and should be run well since official version 2.2.0. If may fail to +is tested well on version 2.3.0.dev20240221+cu118, and should be run well since official version 2.3.0. If may fail to run is you are using PyTorch with older versions. +This file is more like an example of how to add a new graph optimizer. Ideally user can add graph optimizer according +to the specific model they are using on their own instead of putting every possible graph optimizer here. + PyTorch also has API for flash attention (currently doesn't support random attention mask or Dropout), we can add support if we want to try in the future. """ @@ -40,13 +43,14 @@ def _make_efficient_attention_nodes( scale_node = make_constant_node("scale_" + str(idx), TensorProto.FLOAT, [], [scale]) dropout_ratio_node = make_constant_node("dropout_ratio_" + str(idx), TensorProto.FLOAT, [], [dropout_ratio]) causal_node = make_constant_node("causal_" + str(idx), TensorProto.INT64, [], [1 if causal else 0]) - int_zero_node = make_constant_node("int_zero_" + str(idx), TensorProto.INT64, [], [0]) - true_node = make_constant_node("true_" + str(idx), TensorProto.BOOL, [], [True]) - false_node = make_constant_node("false_" + str(idx), TensorProto.BOOL, [], [False]) + one_node = make_constant_node("one_" + str(idx), TensorProto.INT64, [], [1]) + zero_node = make_constant_node("zero_" + str(idx), TensorProto.INT64, [], [0]) logsumexp = helper.make_tensor_value_info("logsumexp" + str(idx), TensorProto.FLOAT, []) seed = helper.make_tensor_value_info("seed" + str(idx), TensorProto.INT64, []) offset = helper.make_tensor_value_info("offset" + str(idx), TensorProto.INT64, []) - new_value_infos = [logsumexp, seed, offset] + msb_q = helper.make_tensor_value_info("msb_q_" + str(idx), TensorProto.INT64, []) + msb_k = helper.make_tensor_value_info("msb_k_" + str(idx), TensorProto.INT64, []) + new_value_infos = [logsumexp, seed, offset, msb_q, msb_k] if expand_bias: shape_0 = helper.make_node("Shape", [q], ["shape_0_" + str(idx)], start=0, end=1) shape_1 = helper.make_node("Shape", [q], ["shape_1_" + str(idx)], start=2, end=3) @@ -54,13 +58,13 @@ def _make_efficient_attention_nodes( shape_3 = helper.make_node("Shape", [k], ["shape_3_" + str(idx)], start=1, end=2) concat = helper.make_node( "Concat", - ["shape_0_" + str(idx), "shape_1_" + str(idx), "shape_2_" + str(idx), "shape_3_" + str(idx)], + [shape_0.output[0], shape_1.output[0], shape_2.output[0], shape_3.output[0]], ["concated_shape_" + str(idx)], axis=0, ) - expand = helper.make_node("Expand", [bias, "concated_shape_" + str(idx)], ["expanded_bias_" + str(idx)]) + expand = helper.make_node("Expand", [bias, concat.output[0]], ["expanded_bias_" + str(idx)]) nodes_to_add.extend([shape_0, shape_1, shape_2, shape_3, concat, expand]) - bias = "expanded_bias_" + str(idx) + bias = expand.output[0] fwd_node = helper.make_node( "ATen", [ @@ -71,18 +75,21 @@ def _make_efficient_attention_nodes( "", "", "", + "", dropout_ratio_node.output[0], causal_node.output[0], - true_node.output[0], + one_node.output[0], scale_node.output[0], 
"", "", ], - [y, logsumexp.name, seed.name, offset.name], + [y, logsumexp.name, seed.name, offset.name, msb_q.name, msb_k.name], "efficient_attention_forward_" + str(idx), None, "org.pytorch.aten", operator="_efficient_attention_forward", + cpu_input_args=[4, 5, 12, 13], + cpu_output_args=[2, 3, 4, 5], ) bwd_node = helper.make_node( "ATen", @@ -95,14 +102,14 @@ def _make_efficient_attention_nodes( y, "", "", - int_zero_node.output[0], - int_zero_node.output[0], + msb_q.name, + msb_k.name, logsumexp.name, dropout_ratio_node.output[0], seed.name, offset.name, causal_node.output[0], - false_node.output[0], + zero_node.output[0], scale_node.output[0], "", ], @@ -111,10 +118,9 @@ def _make_efficient_attention_nodes( None, "org.pytorch.aten", operator="_efficient_attention_backward", + cpu_input_args=[6, 7, 12, 13], ) - nodes_to_add.extend( - [scale_node, dropout_ratio_node, causal_node, int_zero_node, true_node, false_node, fwd_node, bwd_node] - ) + nodes_to_add.extend([scale_node, dropout_ratio_node, causal_node, one_node, zero_node, fwd_node, bwd_node]) return nodes_to_add, new_value_infos @@ -240,140 +246,9 @@ def _optimize_for_pattern_1(matcher: GraphMatcher, idx: int, nodes: List[NodePro return nodes, nodes_to_add, new_value_infos -# No causal mask, no attention mask, without Dropout. -_PATTERN_2: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ - ("MatMul", False, []), # 0 - ("Mul", True, [(0, 0, 0)]), # 1 - ("Mul", True, [(0, 0, 1)]), # 2 - ("Transpose", True, [(1, 0, 0)]), # 3 - ("Transpose", True, [(2, 0, 0)]), # 4 - ("Softmax", False, [(0, 0, 0)]), # 5 - ("MatMul", False, [(5, 0, 0)]), # 6 - ("Transpose", True, [(6, 0, 1)]), # 7 - ("Transpose", False, [(6, 0, 0)]), # 8 - ("FusedMatMul", False, [(7, 0, 1)]), # 9 - ("SoftmaxGrad_13", False, [(9, 0, 0), (5, 0, 1)]), # 10 - ("FusedMatMul", False, [(2, 0, 1), (10, 0, 0)]), # 11 - ("FusedMatMul", False, [(1, 0, 0), (10, 0, 1)]), # 12 - ("Mul", False, [(11, 0, 0)]), # 13 - ("Mul", False, [(12, 0, 0)]), # 14 - ("Identity", False, [(13, 0, 0)]), # 15 - ("Identity", False, [(14, 0, 0)]), # 16 - ("Transpose", False, [(15, 0, 0)]), # 17 - ("Transpose", False, [(16, 0, 0)]), # 18 - ("FusedMatMul", False, [(5, 0, 0)]), # 19 - ("Transpose", True, [(19, 0, 1)]), # 20 - ("Transpose", False, [(19, 0, 0)]), # 21 -] - - -def _optimize_for_pattern_2(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): - # Check forward only as the backward is expected to be consistent if it's built correctly. - scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) - scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 - scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) - scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 - if not ( - check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[7], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[8], "perm", [0, 2, 1, 3]) - and scale_value_1 == scale_value_2 - ): - return [], [], [] - - nodes_to_add, new_value_infos = _make_efficient_attention_nodes( - idx, - nodes[3].input[0], - nodes[4].input[0], - nodes[7].input[0], - nodes[8].output[0], - nodes[20].input[0], - nodes[17].output[0], - nodes[18].output[0], - nodes[21].output[0], - "", - False, - scale_value_1, - 0.0, - False, - ) - return nodes, nodes_to_add, new_value_infos - - -# Has causal mask, no attention mask, without Dropout. 
-_PATTERN_3: List[Tuple[str, bool, List[Tuple[int, int, int]]]] = [ - ("MatMul", False, []), # 0 - ("Mul", True, [(0, 0, 0)]), # 1 - ("Mul", True, [(0, 0, 1)]), # 2 - ("Transpose", True, [(1, 0, 0)]), # 3 - ("Transpose", True, [(2, 0, 0)]), # 4 - ("Add", False, [(0, 0, 0)]), # 5 - ("Slice", True, [(5, 0, 1)]), # 6 - ("Slice", True, [(6, 0, 0)]), # 7 - ("Unsqueeze", True, [(6, 0, 2)]), # 8 - ("Gather", True, [(8, 0, 0)]), # 9 - ("Shape", True, [(9, 0, 0)]), # 10 - ("Softmax", False, [(5, 0, 0)]), # 11 - ("MatMul", False, [(11, 0, 0)]), # 12 - ("Transpose", True, [(12, 0, 1)]), # 13 - ("Transpose", False, [(12, 0, 0)]), # 14 - ("FusedMatMul", False, [(13, 0, 1)]), # 15 - ("SoftmaxGrad_13", False, [(15, 0, 0), (11, 0, 1)]), # 16 - ("Identity", False, [(16, 0, 0)]), # 17 - ("FusedMatMul", False, [(2, 0, 1), (17, 0, 0)]), # 18 - ("FusedMatMul", False, [(1, 0, 0), (17, 0, 1)]), # 19 - ("Mul", False, [(18, 0, 0)]), # 20 - ("Mul", False, [(19, 0, 0)]), # 21 - ("Identity", False, [(20, 0, 0)]), # 22 - ("Identity", False, [(21, 0, 0)]), # 23 - ("Transpose", False, [(22, 0, 0)]), # 24 - ("Transpose", False, [(23, 0, 0)]), # 25 - ("FusedMatMul", False, [(11, 0, 0)]), # 26 - ("Transpose", True, [(26, 0, 1)]), # 27 - ("Transpose", False, [(26, 0, 0)]), # 28 -] - - -def _optimize_for_pattern_3(matcher: GraphMatcher, idx: int, nodes: List[NodeProto]): - # Check forward only as the backward is expected to be consistent if it's built correctly. - scale_value_1 = matcher.get_constant_value(nodes[1].input[1]) - scale_value_1 = scale_value_1[0] if isinstance(scale_value_1, list) else scale_value_1 - scale_value_2 = matcher.get_constant_value(nodes[2].input[1]) - scale_value_2 = scale_value_2[0] if isinstance(scale_value_2, list) else scale_value_2 - if not ( - check_attribute_value(nodes[3], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[4], "perm", [0, 2, 3, 1]) - and check_attribute_value(nodes[13], "perm", [0, 2, 1, 3]) - and check_attribute_value(nodes[14], "perm", [0, 2, 1, 3]) - and scale_value_1 == scale_value_2 - ): - return [], [], [] - - nodes_to_add, new_value_infos = _make_efficient_attention_nodes( - idx, - nodes[3].input[0], - nodes[4].input[0], - nodes[13].input[0], - nodes[14].output[0], - nodes[27].input[0], - nodes[24].output[0], - nodes[25].output[0], - nodes[28].output[0], - "", - False, - scale_value_1, - 0.0, - True, - ) - return nodes, nodes_to_add, new_value_infos - - _PATTERNS = [ (_PATTERN_0, _optimize_for_pattern_0), (_PATTERN_1, _optimize_for_pattern_1), - (_PATTERN_2, _optimize_for_pattern_2), - (_PATTERN_3, _optimize_for_pattern_3), ] From a788514027c3a6ee5f284c965ccffcb8805302a5 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Tue, 5 Mar 2024 18:27:26 -0800 Subject: [PATCH 200/207] [js/web] dump debug logs for karma for diagnose purpose (#19785) ### Description dump debug logs for karma for diagnose purpose. This is for debugging the CI issue of Chrome launch failure and considered temporary. 
--- js/web/script/test-runner-cli.ts | 3 +++ .../github/azure-pipelines/templates/win-web-ci.yml | 12 ++++++------ 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index 59bd0d5f6313a..ace64e9532b12 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -569,6 +569,9 @@ async function main() { if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } + if (process.argv.includes('--karma-debug')) { + karmaArgs.push('--log-level debug'); + } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index b882d6fb167fd..9553bc1bc3547 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -153,31 +153,31 @@ jobs: errorActionPreference: stop displayName: 'Pack NPM packages' - script: | - npm test -- -e=chrome -b=webgl,wasm + npm test -- -e=chrome -b=webgl,wasm --karma-debug workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl backend)' condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=chrome -b=webgl,wasm,webgpu $(webgpuCommandlineExtraFlags) + npm test -- -e=chrome -b=webgl,wasm,webgpu --karma-debug $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-tensor --karma-debug $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + npm test -- suite1 -e=chrome -b=webgpu --io-binding=gpu-location --karma-debug $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | - npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome + npm test -- --webgl-texture-pack-mode -b=webgl -e=chrome --karma-debug workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebGL: packed mode' - script: | - npm test -- --wasm-enable-proxy -b=wasm -e=chrome + npm test -- --wasm-enable-proxy -b=wasm -e=chrome --karma-debug workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests - WebAssembly: proxy' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) From db59cec82f226dbba3ce7c5b03db35b0fe07fb60 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 6 Mar 2024 15:03:55 +1000 Subject: [PATCH 201/207] Don't reduce warning level for CUDA build on Windows (#19663) ### Description Address warnings so all the ORT projects build with /W4 on Windows. Mainly - unused parameters - variables shadowing other ones ### Motivation and Context #19588 started on this. 
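For reference, here is a minimal, self-contained sketch (not code from this patch; the `Cache`, `Run`, and `FitsInPointer` names are invented for illustration) of the warning-fix patterns the diff below applies repeatedly when building at /W4: commenting out unused parameter names, renaming parameters that shadow class members, marking conditionally-used arguments `[[maybe_unused]]`, and switching compile-time-constant conditions to `if constexpr`.

```cpp
#include <cstddef>
#include <cstdint>
#include <iostream>

struct Cache {
  int32_t sequence_length = 0;

  // The parameter is named seq_length rather than sequence_length so it does not
  // shadow the member of the same name (avoids the C4458-style shadowing warning).
  void Initialize(int32_t seq_length) {
    if (sequence_length != seq_length) {
      sequence_length = seq_length;
    }
  }
};

// Unused parameter: keep the type for API compatibility but comment out the name.
// [[maybe_unused]] is the alternative when the argument is only used in some
// build configurations or template instantiations.
void Run(const Cache& cache, std::size_t /*element_size*/, [[maybe_unused]] bool verbose) {
  if (verbose) {
    std::cout << "sequence_length=" << cache.sequence_length << '\n';
  }
}

// "Conditional expression is constant" (C4127): when the condition depends only on
// compile-time information, if constexpr expresses the intent and silences the warning.
template <typename T>
bool FitsInPointer() {
  if constexpr (sizeof(T) > sizeof(void*)) {
    return false;
  } else {
    return true;
  }
}

int main() {
  Cache cache;
  cache.Initialize(128);
  Run(cache, sizeof(float), /*verbose*/ true);
  std::cout << std::boolalpha << FitsInPointer<double>() << '\n';
  return 0;
}
```

The same trade-off applies where a warning cannot be fixed locally without touching too many call sites (e.g. the spurious C4127 cases in the CUDA code): the diff disables that single warning for the affected target instead of dropping the whole project back to /W3.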
--- cmake/CMakeLists.txt | 6 +-- cmake/onnxruntime_providers_cuda.cmake | 13 ++++- .../core/providers/cuda/cuda_context.h | 2 +- .../cuda/bert/add_bias_transpose.cu | 10 ++-- .../contrib_ops/cuda/bert/attention_impl.cu | 20 +++---- .../cuda/bert/attention_prepare_qkv.cu | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 8 +-- .../cuda/bert/decoder_attention_impl.cu | 2 +- .../cuda/bert/group_query_attention_impl.cu | 4 +- .../cuda/bert/packed_attention_impl.cu | 2 +- .../bert/packed_multihead_attention_impl.cu | 4 +- .../contrib_ops/cuda/bert/rotary_embedding.cc | 2 - .../cuda/bert/rotary_embedding_impl.cu | 2 +- .../mha_runner.cu | 54 +++++++++---------- .../cuda/diffusion/group_norm_common_base.h | 6 +-- onnxruntime/contrib_ops/cuda/inverse.cc | 8 +-- .../contrib_ops/cuda/math/complex_mul_impl.cu | 4 +- .../contrib_ops/cuda/math/gemm_float8.cu | 2 +- .../cuda/moe/ft_moe/moe_cutlass_kernel.h | 2 +- .../moe/ft_moe/moe_gemm_kernels_template.h | 29 ++++++---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 4 +- .../cuda/moe/ft_moe/moe_problem_visitor.h | 8 +-- .../quantization/attention_quantization.cc | 2 +- .../qordered_ops/qordered_attention.cc | 2 +- .../qordered_ops/qordered_attention_impl.cu | 2 +- .../qordered_ops/qordered_qdq_impl.cu | 2 +- .../cuda/transformers/generation_cuda_impl.cu | 17 ++++-- .../providers/cuda/cuda_execution_provider.h | 20 +++---- .../core/providers/cuda/cudnn_common.cc | 1 - .../cuda/math/unary_elementwise_ops_impl.cu | 7 +-- onnxruntime/core/providers/cuda/nn/conv.cc | 20 ++++--- onnxruntime/core/providers/cuda/nn/conv.h | 2 +- .../core/providers/cuda/nn/layer_norm.h | 2 - .../core/providers/cuda/nn/layer_norm_impl.cu | 2 - .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 1 - .../cuda/tensor/gelu_approximate_impl.cu | 6 +-- .../cuda/tensor/resize_antialias_impl.cu | 20 +++---- .../core/providers/cuda/tensor/resize_impl.cu | 2 +- .../providers/cuda/tensor/transpose_impl.cu | 6 +-- .../core/providers/cuda/triton_kernel.cu | 50 ++++++++++------- .../core/providers/tensorrt/nv_includes.h | 20 +++++++ .../tensorrt/onnx_ctx_model_helper.h | 2 +- .../tensorrt/tensorrt_execution_provider.cc | 48 ++++++++++------- .../tensorrt/tensorrt_execution_provider.h | 5 +- .../tensorrt_execution_provider_custom_ops.cc | 5 +- .../tensorrt_execution_provider_custom_ops.h | 23 +++++--- ...oder_masked_multihead_attention_op_test.cc | 12 ++--- .../providers/cpu/generator/random_test.cc | 4 +- onnxruntime/test/unittest_main/test_main.cc | 17 +++++- .../training_ops/cuda/cross_entropy_test.cc | 10 ++-- .../training_ops/cuda/nn/conv_shared.cc | 11 ++-- .../cuda/nn/conv_transpose_grad.cc | 2 - .../training_ops/cuda/nn/layer_norm_impl.cu | 2 - .../training_ops/cuda/optimizer/lamb_impl.cu | 2 +- .../templates/jobs/win-ci-prebuild-steps.yml | 11 +++- 55 files changed, 315 insertions(+), 219 deletions(-) create mode 100644 onnxruntime/core/providers/tensorrt/nv_includes.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0d55d4cab9826..3f919d7bf6e18 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1274,11 +1274,7 @@ endif() #Dependencies end. 
In the next we'll enable "treat warning as error" #Adjust warning flags -if (onnxruntime_USE_CUDA) - set_msvc_c_cpp_compiler_warning_level(3) -else() - set_msvc_c_cpp_compiler_warning_level(4) -endif() +set_msvc_c_cpp_compiler_warning_level(4) set(onnxruntime_DELAYLOAD_FLAGS "") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 7f295a59a0931..aeeac10ead27d 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -141,18 +141,22 @@ if (HAS_GUARD_CF) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") endif() + if (HAS_QSPECTRE) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") endif() + foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") endforeach() + # CUDA 11.3+ supports parallel compilation # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1) target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() + if (UNIX) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" "$<$>:-Wno-reorder>") @@ -162,6 +166,13 @@ #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") + if (MSVC) + # the VS warnings for 'Conditional Expression is Constant' are spurious as they don't handle multiple conditions + # e.g. `if (std::is_same_v && not_a_const)` will generate the warning even though constexpr cannot + # be used due to `&& not_a_const`. This affects too many places for it to be reasonable to disable at a finer + # granularity. 
+ target_compile_options(${target} PRIVATE "$<$:/wd4127>") + endif() endif() onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) @@ -187,7 +198,7 @@ target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) endif() endif() - + if (onnxruntime_USE_TRITON_KERNEL) # compile triton kernel, generate .a and .h files include(onnxruntime_compile_triton_kernel.cmake) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 108173474db46..7104e70c3a8a9 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -58,7 +58,7 @@ struct CudaContext : public CustomOpContext { template T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) { - if (sizeof(T) > sizeof(void*)) { + if constexpr (sizeof(T) > sizeof(void*)) { ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT); } const auto& ort_api = Ort::GetApi(); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 1ea2540db486f..9e6752b451868 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -843,11 +843,11 @@ void InvokeAddBiasTransposeTrt( template <> void LaunchAddBiasTransposeTrt( - cudaStream_t stream, const int max_threads_per_block, - const int batch_size, const int sequence_length, - const int num_heads, const int head_size, - const float* biases, const float* query, const float* key, const float* value, float* output, - bool is_cross_attention, int kv_sequence_length) { + cudaStream_t /*stream*/, const int /*max_threads_per_block*/, + const int /*batch_size*/, const int /*sequence_length*/, + const int /*num_heads*/, const int /*head_size*/, + const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/, + bool /*is_cross_attention*/, int /*kv_sequence_length*/) { ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index c20f42c4d06bc..a93fdf74dc28c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { - if (this->sequence_length != sequence_length) { +void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { + if (this->sequence_length != seq_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, sequence_length, stream); - this->sequence_length = sequence_length; + this->max_batch_size, seq_length, stream); + this->sequence_length = seq_length; } } @@ -213,9 +213,9 @@ Status FusedTrtCrossAttention( template <> Status FusedTrtCrossAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return 
ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused cross attention does not support float tensor"); } @@ -276,9 +276,9 @@ Status FusedTrtSelfAttention( // Template Specialization for float type template <> Status FusedTrtSelfAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused attention does not support float tensor"); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a513d9e8d2211..b843966d88e85 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index db78722cc0e4c..c12cb374d9adf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { using AlignedAK = AttentionKernel; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) -#pragma warning(disable : 6287) +#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect #endif // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. 
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && params.qk_head_size % AlignedAK::kAlignmentK == 0 && params.v_head_size % AlignedAK::kAlignmentV == 0; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { LaunchCutlassFmha(params); })); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index e24d9da94c964..c0b1996789183 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -17,7 +17,7 @@ Status DecoderQkvToContext( const cudaDeviceProp& device_prop, Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, + const size_t /*element_size*/, const int batch_size, const int sequence_length, const int kv_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index d88e9a49fb5ee..cb5631542c113 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int /*threads_per_block*/) { if (parameters.is_prompt) { return Status::OK(); } @@ -655,7 +655,7 @@ Status EfficientAttention( template Status QkvToContext( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, + cublasHandle_t& /*cublas*/, Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ce7ac3796dbe1..a84a310b46ca0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding( template Status FusedScaledDotProductAttention( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 49029da12a308..982c7eaa2cb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -381,7 +381,7 @@ void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, - AttentionQkvFormat source_format, AttentionQkvFormat target_format, + [[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format, const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { if (key != nullptr && value != nullptr) { @@ -551,7 +551,7 @@ void LaunchTranspose( template Status FusedAttentionTrt( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& 
/*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 9de7ba3885c3c..ab7479f2938fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -82,8 +82,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { interleaved, device_prop.maxThreadsPerBlock, parameters.transposed); - - return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c6637041f05bd..3a14161f29e9f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -93,7 +93,7 @@ Status LaunchRotaryEmbeddingKernel( const int num_heads, const int head_size, const int rotary_embedding_dim, - const int max_sequence_length, + const int /*max_sequence_length*/, const int position_ids_format, const bool interleaved, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 8fb6575d27cc0..4a4e3eeecf642 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl { ~mhaImpl() {} - void setup(const int S, const int B) { + void setup(const int seq_len, const int B) { // For bert and vit, use flash attention when sequence length is larger than the threshold. - use_flash_attention = is_flash_attention(S); + use_flash_attention = is_flash_attention(seq_len); params.force_unroll = use_flash_attention; @@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl { warps_n = 1; } else { if (sm == 70) { - if (S == 64 || S == 96) { + if (seq_len == 64 || seq_len == 96) { warps_m = 2; warps_n = 2; - } else if (S == 128) { + } else if (seq_len == 128) { warps_m = 1; warps_n = 4; - } else if (S == 256 || S == 384) { + } else if (seq_len == 256 || seq_len == 384) { warps_m = 1; warps_n = 8; } else { ORT_ENFORCE(false, "Unsupported sequence length"); } } else { - if (S == 32 || S == 64 || S == 96 || S == 128) { + if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) { warps_m = 2; warps_n = 2; - } else if (S == 192 || S == 256) { + } else if (seq_len == 192 || seq_len == 256) { warps_m = 1; warps_n = 4; - } else if (S == 384) { + } else if (seq_len == 384) { warps_m = 1; warps_n = 8; } else { @@ -99,7 +99,7 @@ class FusedMHARunnerFP16v2::mhaImpl { // The number of threads per CTA. threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. 
- xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m); + xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m); const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 @@ -111,7 +111,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl { has_causal_mask = false; } - void setup_causal_masked_fmha(const int S, const int B) { + void setup_causal_masked_fmha(const int seq_len, const int B) { const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -132,7 +132,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl { return max_seq_len; } - int S = max_seq_len; + int seq_len = max_seq_len; if (max_seq_len <= 32) { - S = (sm == 70) ? 64 : 32; + seq_len = (sm == 70) ? 64 : 32; } else if (max_seq_len <= 64) { - S = 64; + seq_len = 64; } else if (max_seq_len <= 96) { - S = 96; + seq_len = 96; } else if (max_seq_len <= 128) { - S = 128; + seq_len = 128; } else if (max_seq_len <= 192) { - S = (sm == 70) ? 256 : 192; + seq_len = (sm == 70) ? 256 : 192; } else if (max_seq_len <= 256) { - S = 256; + seq_len = 256; } else if (max_seq_len <= 384) { - S = 384; + seq_len = 384; } - return S; + return seq_len; } protected: - bool is_flash_attention(const int S) const { + bool is_flash_attention(const int seq_len) const { ORT_ENFORCE(interface->mHasCausalMask == false); - return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention; + return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention; } private: @@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, pimpl(new mhaImpl(this)) { } -void FusedMHARunnerFP16v2::setup(const int S, const int B) { - MHARunner::setup(S, B); +void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) { + MHARunner::setup(seq_len, B); if (mHasCausalMask) { - pimpl->setup_causal_masked_fmha(S, B); + pimpl->setup_causal_masked_fmha(seq_len, B); } else { - pimpl->setup(S, B); + pimpl->setup(seq_len, B); } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h index ea87d0c29111e..a80584d3293a0 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -136,10 +136,10 @@ struct GroupNormNHWCParams { bool use_silu, bool broadcast_skip, int channels_per_block) { - int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_group_in = num_channels / num_groups; // channels_per_block is computed in PrePack. // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. 
- if (channels_per_block < channels_per_group) { + if (channels_per_block < channels_per_group_in) { channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } @@ -167,7 +167,7 @@ struct GroupNormNHWCParams { this->hw_per_block = DivUp(this->hw, blocks_per_hw); this->channels_per_block = channels_per_block; - this->channels_per_group = channels_per_group; + this->channels_per_group = channels_per_group_in; this->hwc = this->hw * this->c; this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); this->groups_per_block = channels_per_block / this->channels_per_group; diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 81e161e60642c..9075dda26f86b 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -78,9 +78,9 @@ struct Inverse::ComputeImpl { cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // Make a copy of the input which will serve as a workspace as well. - if (std::is_same::value || std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(input_count, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { // Convert from MLFloat16(half) to float Impl_Cast(stream, reinterpret_cast(input.Data()), input_workspace.get(), input_count); } else { @@ -96,7 +96,7 @@ struct Inverse::ComputeImpl { // Need to compute ptrs for output buffers // Output for MLFloat IAllocatorUniquePtr output_ptrs = inst->GetScratchBuffer(n_batches, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { IAllocatorUniquePtr ml_float_output = inst->GetScratchBuffer(input_count, ort_stream); ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, ml_float_output.get(), num_batches, rows, output_ptrs)); // Do the inverse @@ -112,7 +112,7 @@ struct Inverse::ComputeImpl { ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches)); // We are done here } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(static_cast(input_count), ort_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data(), sizeof(double) * input_count, cudaMemcpyDeviceToDevice, stream)); diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu index ca94477114ee2..47a64502b3480 100644 --- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu @@ -97,8 +97,8 @@ void ComplexMul_Impl( const TArray* rhs_padded_strides, const T* rhs_data, const TArray* fdm_output_strides, - const onnxruntime::cuda::fast_divmod& fdm_H, - const onnxruntime::cuda::fast_divmod& fdm_C, + const onnxruntime::cuda::fast_divmod& /*fdm_H*/, + const onnxruntime::cuda::fast_divmod& /*fdm_C*/, T* output_data, int64_t count, int64_t lhs_size, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 064b6dd392437..28ab27ee33d10 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -174,7 +174,7 @@ Status GemmFloat8::ComputeGemm( int32_t dtype_A, int32_t dtype_B, int32_t dtype_C, int32_t dtype_Y, const TensorShape& shape_A, const TensorShape& shape_B, - const TensorShape& shape_C, const TensorShape& shape_Y, + 
const TensorShape& shape_C, const TensorShape& /*shape_Y*/, bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, const void* p_input_c, const void* p_scale_a, const void* p_scale_b, const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h index bfe30b71170d8..cfe306c2482a5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -202,7 +202,7 @@ struct MoeFCGemm { total_rows_before_expert(total_rows_before_expert), gemm_n(gemm_n), gemm_k(gemm_k), - host_problem_sizes(nullptr) { + host_problem_sizes(host_problem_sizes) { if (platform::is_same::value || platform::is_same::value) { assert(weight_scales); } diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 66950c9b65970..a3dcf0da16b98 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -20,6 +20,12 @@ #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" #include "cutlass/layout/matrix.h" @@ -36,6 +42,10 @@ #include "layout_traits_helper.h" #include "moe_cutlass_kernel.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif @@ -149,10 +159,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/, + T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, + int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, + cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { std::string err_msg = "Cutlass fpA_intB gemm. 
Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); @@ -221,9 +231,10 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, - int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + int64_t* total_rows_before_expert, int64_t /*total_rows*/, + int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, + int /*sm_version*/, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, @@ -300,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig template ::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index f4f2b49032d23..a5b47bcddefbc 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -370,7 +370,7 @@ struct TopkConstants { template void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, - int num_rows, int num_experts, int k, cudaStream_t stream) { + int num_rows, int /*num_experts*/, int k, cudaStream_t stream) { static constexpr unsigned long MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); @@ -599,7 +599,7 @@ void CutlassMoeFCRunner::run_moe_fc( static constexpr bool scales_required = std::is_same::value || std::is_same::value; - if (scales_required) { + if constexpr (scales_required) { if (fc1_scales == nullptr) { ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); } else if (fc2_scales == nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h index 00f977c615df6..1de8f6b69642c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -276,13 +276,13 @@ struct MoeProblemVisitor::ComputeInternal(OpKernelContext* context) const { CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); CudaT weight_scale = *(reinterpret_cast(weight_scale_tensor->Data())); - if (sizeof(T) == 2) { + if constexpr (sizeof(T) == 2) { 
dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale)); } else { dequant_scale = input_scale * weight_scale; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3cecebedae2f0..12835978536e1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -142,7 +142,7 @@ inline void debug_print([[maybe_unused]] const T* arr, std::cout << "========" << name << std::endl; for (size_t i = 0; i < sz; i++) { if (i % w == 0) std::cout << std::endl; - if (std::is_same().value) { + if constexpr (std::is_same::value) { std::cout << (int)buf[i] << ", "; } else { std::cout << buf[i] << ", "; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu index f4d5a7b404a62..fd4b51f40fb4f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu @@ -151,7 +151,7 @@ QOrderBatchInt8MatrixTransposeKernel(const int8_t* src, const int8_t* dst, const } } -Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int batch_size, const int rows, const int cols, const int8_t* input, int8_t* output) { ORT_ENFORCE(rows % 4 == 0 && cols % 4 == 0, "Matrix rows and cols must be divisible by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu index baff8e76ec73b..e6ac0bc8a5171 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu @@ -389,7 +389,7 @@ QOrderDequantizeKernel_Strict(const int8_t* __restrict__ src, const __half* __re } } -Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int8_t* src, __half* dst, float scale, size_t N) { ORT_RETURN_IF(N & 0x3LL, "N can not divide by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index a39abefed9cd0..eb1943b59d976 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1,11 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+ +// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4706) +#endif +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include -#include + #include "contrib_ops/cuda/bert/utils.cuh" #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 5f62f313b86a2..75fe1dff7c4a4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -131,41 +131,33 @@ class CUDAExecutionProvider : public IExecutionProvider { template const T* GetConstOnes(size_t count, cudaStream_t stream) { - constexpr bool is_float = std::is_same::value; - constexpr bool is_double = std::is_same::value; - constexpr bool is_half = std::is_same::value; - constexpr bool is_BFloat16 = std::is_same::value; -#if !defined(DISABLE_FLOAT8_TYPES) - constexpr bool is_Float8E4M3FN = std::is_same::value; - constexpr bool is_Float8E5M2 = std::is_same::value; -#endif - if (is_float) { + if constexpr (std::is_same::value) { if (!constant_ones_float_) { constant_ones_float_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (is_double) { + } else if constexpr (std::is_same::value) { if (!constant_ones_double_) { constant_ones_double_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (is_half) { + } else if constexpr (std::is_same::value) { if (!constant_ones_half_) { constant_ones_half_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); - } else if (is_BFloat16) { + } else if constexpr (std::is_same::value) { if (!constant_ones_bfloat16_) { constant_ones_bfloat16_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); #if !defined(DISABLE_FLOAT8_TYPES) - } else if (is_Float8E4M3FN) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e4m3fn_) { constant_ones_float8e4m3fn_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float8e4m3fn_->GetBuffer(stream, count)); - } else if (is_Float8E5M2) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e5m2_) { constant_ones_float8e5m2_ = cuda::CreateConstantOnes(); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index c850f7b583bfc..39b73163794f0 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -160,7 +160,6 @@ cudnnDataType_t CudnnTensor::GetDataType() { template <> cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN doesn't support BFloat16."); - return CUDNN_DATA_FLOAT; } template <> diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index fd8f7929d4426..554d5908cf854 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -127,9 +127,10 @@ struct OP_Cast { UnaryElementWiseImpl(stream, 
input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ - void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index a417be5a86c32..e05786248cbcf 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -97,11 +97,11 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if (NHWC && is_nhwc_domain_) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); @@ -123,6 +123,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; } + } else { + ORT_UNUSED_PARAMETER(tensor); + ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(alloc); } return Status::OK(); @@ -149,8 +153,11 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. 
constexpr bool channels_last = NHWC; - if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + if constexpr (channels_last) { + if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } } // set B @@ -403,7 +410,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) default: perf.algo = kDefaultConvAlgo; CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); - if (std::is_same::value) { + + if constexpr (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; } else if (std::is_same::value && !UseTF32()) { perf.mathType = CUDNN_FMA_MATH; diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 181fbc99fd8e9..3aec654224e39 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -195,7 +195,7 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; + bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.h b/onnxruntime/core/providers/cuda/nn/layer_norm.h index ff231f4f1ad5c..c021d3ffe63a2 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.h @@ -7,8 +7,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - // NOTE: This was originally a contrib op with 3 type constraints. The ONNX spec merges 'T' and 'V'. // the kernel is templatized on all three for backwards compatibility, but in ONNX usage T == V. 
template diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 679b8b6b78886..b9e8b45307079 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -29,8 +29,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - template __device__ void cuWelfordOnlineSum( const U curr, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index b61b104790fe5..6476364a211fd 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -305,7 +305,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); - const Tensor* B = ctx->Input(RNN_Input_Index::B); ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); } diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu index 3292650584de8..7a27b7af33137 100644 --- a/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu @@ -62,7 +62,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -73,7 +73,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const double* input, const double* bias, double* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -108,7 +108,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu index 56b7c3f499303..d56e4bc53874d 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -680,10 +680,10 @@ template void ResizeTrilinearUpsample( cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, + 
gsl::span /*input_shape*/, + gsl::span /*output_shape*/, int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, @@ -832,11 +832,11 @@ void ResizeTrilinearUpsample( template void ResizeBiLinearUpsample(cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, - int64_t batch_size, int64_t num_channels, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, + int64_t /*batch_size*/, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, std::tuple inferred_dim_rscales, @@ -959,10 +959,10 @@ void ResizeBiLinearUpsample(cudaStream_t stream, template void ResizeBicubicUpsample(cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 0cde0ed8e8681..e788f24052985 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -609,7 +609,7 @@ void ResizeNearestImpl( const size_t N, bool extrapolation_enabled, const T extrapolation_value, - float cubic_coeff_a, + float /*cubic_coeff_a*/, ResizeCoordinateTransformationMode transform_coordinate, ResizeNearestMode calc_nearest_pixel, int64_t* /* prefix_dim_sum */, diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 9f9c365d2a53d..6344845359b32 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -80,7 +80,7 @@ bool CanDoTranspose3D(const cudaDeviceProp& prop, size_t rank, const gsl::span& input_shape, - const TArray& input_strides, const void* input_data, void* output_data, int64_t N, + const TArray& input_strides, const void* input_data, void* output_data, int64_t /*N*/, const dim3& grid_size, const dim3& block_size) { switch (element_size) { HANDLE_TRANSPOSE_3D_TILE_DIM(int8_t); @@ -248,10 +248,10 @@ __global__ void Transpose4DKernelParallelizeOneElementPerThread( } bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop, - size_t element_size, + size_t /*element_size*/, int32_t rank, const gsl::span& input_dims, - const gsl::span& permutations, + const gsl::span& /*permutations*/, dim3& grid_size, dim3& block_size) { if (rank == 4) { // dims[3]: block.x diff --git a/onnxruntime/core/providers/cuda/triton_kernel.cu b/onnxruntime/core/providers/cuda/triton_kernel.cu index 6ffbf0420a15f..b42dbd0291b7a 100644 --- a/onnxruntime/core/providers/cuda/triton_kernel.cu +++ b/onnxruntime/core/providers/cuda/triton_kernel.cu @@ -130,27 +130,11 @@ void LoadOrtTritonKernel() { std::call_once(load_ort_triton_kernel_flag, TryToLoadKernel); } -Status LaunchTritonKernel(cudaStream_t stream, std::string fname, - int grid0, int grid1, int grid2, void* args, size_t args_size) { -#ifdef USE_TRITON_KERNEL - if (ort_triton_kernel_map.count(fname) == 0) { - // Return unsupported status if function name not found in registry. 
- // This error status will be used by TunableOp - std::ostringstream message_stream; - message_stream << "Can't find ort triton kernel name: " << fname; - std::string message = message_stream.str(); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); - } - auto idx = ort_triton_kernel_map[fname]; - return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); -#else - return Status::OK(); -#endif -} -Status LaunchTritonKernel(cudaStream_t stream, size_t idx, - int grid0, int grid1, int grid2, void* args, size_t args_size) { + #ifdef USE_TRITON_KERNEL +Status LaunchTritonKernel(cudaStream_t stream, size_t idx, int grid0, int grid1, int grid2, + void* args, size_t args_size) { if (idx >= ort_triton_kernel_metadata.size()) { // Return unsupported status when idx exceeds the size of ort_triton_kernel_metadata. // This error status will be used by TunableOp @@ -181,11 +165,37 @@ Status LaunchTritonKernel(cudaStream_t stream, size_t idx, nullptr, (void**)&config), "Launching kernel failed."); -#endif return Status::OK(); } +Status LaunchTritonKernel(cudaStream_t stream, std::string fname, int grid0, int grid1, int grid2, + void* args, size_t args_size) { + if (ort_triton_kernel_map.count(fname) == 0) { + // Return unsupported status if function name not found in registry. + // This error status will be used by TunableOp + std::ostringstream message_stream; + message_stream << "Can't find ort triton kernel name: " << fname; + std::string message = message_stream.str(); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); + } + auto idx = ort_triton_kernel_map[fname]; + return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); +} + +#else +Status LaunchTritonKernel(cudaStream_t /*stream*/, std::string /*fname*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} + +Status LaunchTritonKernel(cudaStream_t /*stream*/, size_t /*idx*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} +#endif + + const TritonKernelMetaData* GetOrtTritonKernelMetadata(size_t idx) { if (idx >= ort_triton_kernel_metadata.size()) { return nullptr; diff --git a/onnxruntime/core/providers/tensorrt/nv_includes.h b/onnxruntime/core/providers/tensorrt/nv_includes.h new file mode 100644 index 0000000000000..c3e9f7a3a2a77 --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/nv_includes.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +// File to include the required TRT headers with workarounds for warnings we can't fix. 
+ +// Ignore warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index bf3bf9e3495d7..9f1e5178428e7 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -6,7 +6,7 @@ #include #include -#include "NvInfer.h" +#include "core/providers/tensorrt/nv_includes.h" #include "core/providers/shared_library/provider_api.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 157cd0a200b35..e521640681a77 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -7,6 +7,7 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/common.h" +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" @@ -137,10 +138,10 @@ std::vector SplitToStringVec(std::string const& s, char separator) return splitted; } -nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) { +nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSources disabledTactics = 0; nvinfer1::TacticSources enabledTactics = 0; - std::vector tacticList = SplitToStringVec(tactic_sting, ','); + std::vector tacticList = SplitToStringVec(tactic_string, ','); for (auto& t : tacticList) { bool enable{false}; if (t.front() == '+') { @@ -151,8 +152,8 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) { t.erase(0, 1); const auto toUpper = [](std::string& sourceName) { - std::transform( - sourceName.begin(), sourceName.end(), sourceName.begin(), [](char c) { return std::toupper(c); }); + std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), + [](char c) { return onnxruntime::narrow(std::toupper(c)); }); return sourceName; }; @@ -288,7 +289,8 @@ void CudaCall(cudnnStatus_t retCode, const char* exprString return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } -void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept { +void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr // even for empty tensors, so allocate a dummy byte. 
size = std::max(size, static_cast(1)); @@ -304,7 +306,7 @@ void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMem return outputPtr; } -void OutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept { +void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { output_shapes.clear(); output_shapes.reserve(dims.nbDims); for (int i = 0; i < dims.nbDims; i++) { @@ -613,20 +615,22 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(shape_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); + auto input_shape = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData(), + shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); for (int j = 0; j < shape_size; ++j) { - tensor_shape_values[input_name][j] = input[j]; + tensor_shape_values[input_name][j] = input_shape[j]; } break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - auto input = std::make_unique(shape_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); + auto input_shape = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData(), + shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); for (int j = 0; j < shape_size; ++j) { - tensor_shape_values[input_name][j] = static_cast(input[j]); + tensor_shape_values[input_name][j] = static_cast(input_shape[j]); } break; } @@ -974,7 +978,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. 
*/ Status BindKernelOutput(Ort::KernelContext& ctx, - OrtMemoryInfo* mem_info, + OrtMemoryInfo* /*mem_info*/, DDSOutputAllocatorMap& allocator_map, char const* output_name, size_t output_index, @@ -1143,7 +1147,8 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh // get or create a context if (context_state_.retired_context_pool.empty()) { - context = std::make_shared(info_.device_id, info_.has_user_compute_stream, stream_); + context = std::make_shared(narrow(info_.device_id), + info_.has_user_compute_stream, stream_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1163,7 +1168,11 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh } TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info), device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + narrow(info.device_id))}, + info_(info), + device_id_(info.device_id) { InitProviderOrtApi(); CUDA_CALL_THROW(cudaSetDevice(device_id_)); @@ -1655,7 +1664,8 @@ void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { std::vector TensorrtExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, device_id_); + [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, + narrow(device_id_)); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { @@ -3036,7 +3046,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView std::unordered_set input_names; std::unordered_map> tensor_shape_values; - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -3603,7 +3614,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // int num_inputs = static_cast(input_indexes.size()); int num_outputs = static_cast(output_indexes.size()); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 26f6b2dcc3020..339c45a8742d2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -5,8 +5,9 @@ #include #include #include -#include "NvInfer.h" -#include "NvOnnxParser.h" + +#include 
"core/providers/tensorrt/nv_includes.h" + #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_graph.h" #include "tensorrt_execution_provider_info.h" diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index eb340ba1e64b6..b4f348159440f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -1,12 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/framework/provider_options.h" #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" -#include -#include -#include namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index b19d9ab0f66d0..54212d34aa2ce 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,8 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, + const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); @@ -23,16 +24,22 @@ struct TensorRTCustomKernel { : compute_stream_(compute_stream) { } - void Compute(OrtKernelContext* context){}; // The implementation is in TensorRT plugin. No need to implement it here. + void Compute(OrtKernelContext* /*context*/){ + // The implementation is in TensorRT plugin. No need to implement it here. 
+ }; private: void* compute_stream_; }; struct TensorRTCustomOp : Ort::CustomOpBase { - explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} + explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), + compute_stream_(compute_stream) { + } - void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); }; + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { + return new TensorRTCustomKernel(info, compute_stream_); + }; const char* GetName() const { return name_; }; @@ -46,7 +53,9 @@ struct TensorRTCustomOp : Ort::CustomOpBase QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_ // Softmax_QK_Transpose template -std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size); +std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int head_size); template <> -std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) { +std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int /*head_size*/) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } @@ -506,8 +506,8 @@ std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, } template <> -std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) { +std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int /*head_size*/) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index 16582696a81d4..532b98317405f 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -380,7 +380,7 @@ void RunRandomNormalGpuTest(const std::vector dims, const float mean, c test.AddOutput("Y", dims, fp16_data); } - auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + auto output_verifier = [&](const std::vector& fetches, const std::string& /*provider_type*/) { // Only one output, and mean of output values are near attribute mean. ASSERT_EQ(fetches.size(), 1u); const auto& output_tensor = fetches[0].Get(); @@ -472,7 +472,7 @@ void RunRandomUniformGpuTest(const std::vector dims, const float low, c test.AddOutput("Y", dims, fp16_data); } - auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + auto output_verifier = [&](const std::vector& fetches, const std::string& /*provider_type*/) { // Only one output. Each value in output tensoer is between low and high. // Mean of output values are near attribute mean of low and high. 
ASSERT_EQ(fetches.size(), 1u); diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 4c38c90c2b418..d7e8bf9063645 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -32,17 +32,30 @@ void ortenv_setup() { } #ifdef USE_TENSORRT + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) // Ignore warning C4100: unreferenced format parameter. +#endif + // TensorRT will load/unload libraries as builder objects are created and torn down. This will happen for // every single unit test, which leads to excessive test execution time due to that overhead. // Nvidia suggests to keep a placeholder builder object around to avoid this. #include "NvInfer.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + class DummyLogger : public nvinfer1::ILogger { public: - DummyLogger(Severity verbosity) {} - void log(Severity severity, const char* msg) noexcept override {} + DummyLogger(Severity /*verbosity*/) {} + void log(Severity /*severity*/, const char* /*msg*/) noexcept override {} }; DummyLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING); + auto const placeholder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + #endif #define TEST_MAIN main diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc index d9800ce0e0d3e..d36f9b307ec70 100644 --- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc @@ -311,11 +311,9 @@ template static std::vector RunSCELossWithEP(const char* op, int opset_version, const char* domain, - std::function()> - ep_creator, + std::function()> ep_creator, const std::string& reduction, const std::int64_t ignore_index, - const double error_tolerance, const std::vector* X_dims, const std::vector* index_dims, const std::vector* weight_dims, @@ -403,7 +401,7 @@ static void TestSCELoss(const char* op, int opset_version, cpu_fetches = RunSCELossWithEP( op, opset_version, domain, []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data_temp, index_data, weight_data_temp); @@ -411,7 +409,7 @@ static void TestSCELoss(const char* op, int opset_version, cpu_fetches = RunSCELossWithEP( op, opset_version, domain, []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data, index_data, weight_data); @@ -429,7 +427,7 @@ static void TestSCELoss(const char* op, int opset_version, return DefaultRocmExecutionProvider(); #endif }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data, index_data, weight_data); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index d23905496c9bb..9b30bd128b161 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -105,7 +105,8 @@ struct AlgoSearch { CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, 
CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward data algorithms."); int perf_count; std::unique_ptr candidates = std::make_unique(num_algos); if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { @@ -146,7 +147,9 @@ struct AlgoSearch { // NOTE: - 1 because ALGO_WINOGRAD is not implemented. static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); int perf_count; if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { @@ -188,7 +191,9 @@ struct AlgoSearch { }; static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); int perf_count; if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index d3f5a89434a48..5d12e0ac312c0 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -53,7 +53,6 @@ Status ConvTransposeGrad::ComputeInputGradient(onnxruntime::Stream* stream, c algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data)); return Status::OK(); }); - return Status::OK(); } template @@ -71,7 +70,6 @@ Status ConvTransposeGrad::ComputeWeightGradient(onnxruntime::Stream* stream, algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data)); return Status::OK(); }); - return Status::OK(); } template diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 2d89ed05712e0..ad577afa06c18 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -30,8 +30,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - namespace { // This is the un-specialized struct. Note that we prevent instantiation of this // struct by putting an undefined symbol in the function body so it won't compile. 
diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu index c90809eb2fdcc..fd55f7c30ff75 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu @@ -619,7 +619,7 @@ CudaKernel::CudaAsyncBuffer compute_tensor_rang template void LambMultiTensorReductionFunctor::operator()( - cudaStream_t stream, + cudaStream_t /*stream*/, ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index 9516753d50113..864513bc4d671 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -93,8 +93,17 @@ steps: $ccache_parent_dir = (Split-Path -parent $ccache_path) Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe" Get-ChildItem $ccache_parent_dir - ccache --version } + + "ccache info:" + ccache --version + ccache --show-config + + "cl.exe from path: $((Get-Command cl).Path). Version:" + (cl.exe -?) -match 'Compiler Version' + "C:\ProgramData\chocolatey\bin\cl.exe version:" + (C:\ProgramData\chocolatey\bin\cl.exe -?) -match 'Compiler Version' + displayName: Install ccache and update PATH to use linked versions of gcc, cc, etc - ${{ if eq(parameters.WITHCACHE, true) }}: From e93a860819545ea64acfe36e19e2b954389d48bf Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 5 Mar 2024 21:54:48 -0800 Subject: [PATCH 202/207] Remove arm build for training (#19788) We no longer support Win arm 32 so removing the associated build and packaging job. 
--- .../ondevice-training-cpu-packaging-pipeline.yml | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index cf39be23cbdaf..b3faaf2a7f1a6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -61,21 +61,6 @@ stages: buildJava: false buildNodejs: false -- template: win-ci.yml - parameters: - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: Training_CPU_arm_${{ parameters.BuildVariant }} - artifact_name_suffix: -training - buildArch: x64 - msbuildPlatform: arm - packageName: arm - buildparameter: --arm ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe - runTests: false - buildJava: false - buildNodejs: false - ort_build_pool_name: onnxruntime-Win-CPU-2022 - - template: win-ci.yml parameters: DoCompliance: ${{ parameters.DoCompliance }} @@ -127,7 +112,6 @@ stages: - Linux_C_API_Packaging_Training_CPU - Windows_Packaging_Training_CPU_x86_${{ parameters.BuildVariant }} - Windows_Packaging_Training_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_Training_CPU_arm_${{ parameters.BuildVariant }} - Windows_Packaging_Training_CPU_arm64_${{ parameters.BuildVariant }} - Android_Java_API_AAR_Packaging_Training_Full condition: succeeded() From d9bf85613d7171b54a6ece45fc0f241b008a1fd8 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Mar 2024 21:54:16 +0800 Subject: [PATCH 203/207] Adapt memory optimizer to fit PHI2 (#19757) ### Adapt memory optimizer to fit PHI2 A few improvements and bug fixes: 1. Fix a bug related to transformer layer detection. 2. Use the default topological order, reversed, to create recompute nodes, so that leaf nodes are not handled too late and left with the lowest execution priority. 3. Add an early stop when an activation's element count is constant and the total element count is < 1M; this avoids the overhead of searching such subgraphs. With export ORTMODULE_MEMORY_OPT_LEVEL=1 enabling layer-wise recompute on the given recipe, memory consumption dropped from ~22GB to ~13GB.
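For context on point 3, the check boils down to: stop exploring an input edge when every dimension of the candidate activation is statically known and the total element count is below 1M. Below is a minimal standalone sketch of that heuristic, not the code added by this patch; the function name, the plain std::vector<int64_t> shape representation, and the use of -1 for a symbolic dimension are illustrative assumptions.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// "Small" cutoff from the description above: fewer than 1M elements.
constexpr int64_t kSmallTensorElemThreshold = 1 * 1024 * 1024;

// Returns true when the shape is fully static and tiny enough that tracking the
// activation for recompute is not worth it; -1 marks a symbolic (unknown) dimension.
bool SkipSmallConstantActivation(const std::vector<int64_t>& shape) {
  int64_t num_elem = 1;
  for (int64_t dim : shape) {
    if (dim < 0) {
      return false;  // symbolic dimension: size unknown, keep searching this branch
    }
    num_elem *= dim;
  }
  return num_elem < kSmallTensorElemThreshold;  // constant and small: stop early
}

int main() {
  std::cout << SkipSmallConstantActivation({32, 128, 64}) << "\n";     // 1: 256K elements, skip
  std::cout << SkipSmallConstantActivation({-1, 4096, 4096}) << "\n";  // 0: symbolic batch, keep
  std::cout << SkipSmallConstantActivation({8, 2048, 2048}) << "\n";   // 0: 32M elements, keep
  return 0;
}
```

In the actual change below, the dimensions are read from the producing node's ONNX shape proto (has_dim_value() / dim_value()) rather than from a plain vector.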
--- .../memory_optimizer/memory_insight.cc | 3 +- .../memory_optimizer/memory_optimizer.cc | 37 +++++++++++++++- .../memory_optimizer/recompute_analysis.cc | 18 +++++++- .../memory_optimizer/transformer_specific.cc | 42 +++++++++++++++++-- .../memory_optimizer/transformer_specific.h | 3 ++ 5 files changed, 95 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 08c402bf669c8..54c49db0597c7 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -258,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, logger)); InlinedHashSet layer_boundary_ln_nodes; - FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes); + FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, layer_boundary_ln_nodes); // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 525e3b4b8de35..40fa2fc5cc737 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -190,11 +190,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve .IsOK()); // The second pass - apply the transformation. - // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. + // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + // + // Note 2: Here we use default typo order (which tries to BFS from the outputs, + // so the nearest node to graph output will be visited last). So in reversed default typo order, + // the neareast node to graph output will be visited first. + // Imagine there is a such subgraph + // input1 input2 input3 + // \ | / + // multiple layers + // | + // node M + // labels-------|----- + // \ | | + // node1 | | + // \ | | + // node2 / | + // \ / | + // node loss / + // | / + // YieldOp node1_recompute + // | / + // \ node2 recompute + // \ / + // node loss_grad + // | + // critical grad path + // + // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added + // at last because we do this following reversed topological order. Then node1_recompute node will have lowest + // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then + // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. + // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes + // in this case. 
+ + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 12c83591c0036..76b3325f36116 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50; static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); @@ -291,6 +291,22 @@ Status SelectRecomputeSubgraph(const Node& entry_node, const auto current_node_input_index = input_edge.GetDstArgIndex(); if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != input_arg_indices.end()) { + // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. + auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); + if (output_shape) { + bool all_constant_dim = true; + int64_t num_elem = 1; + for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) { + if (!output_shape->dim(k).has_dim_value()) { + all_constant_dim = false; + num_elem *= output_shape->dim(k).dim_value(); + } + } + if (all_constant_dim && num_elem < 1 * 1024 * 1024) { + // Skip this input index. + continue; + } + } NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " + diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 04f2679ac774f..c88a0f05d36b8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, const logging::Logger&, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, @@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes( std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); + // Ignore those nodes after YieldOp. 
+ if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + nodes_to_check.push_back(&(*node_it)); + } } + bool unexpected_failure = false; + bool found_softmax = false; + bool found_layernorm = false; + ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes( visited_nodes.insert(next_node); if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - layer_boundary_ln_nodes.insert(&node); - break; + found_softmax = true; } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - break; + if (found_layernorm) { + // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. + unexpected_failure = true; + break; + } + found_layernorm = true; // don't trace further + next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); + continue; } else { for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. + if (found_layernorm && + node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { + continue; + } nodes_to_check.push_back(&(*node_it)); } } } + + if (unexpected_failure) { + layer_boundary_ln_nodes.clear(); + break; + } + + if (found_softmax) { + layer_boundary_ln_nodes.insert(&node); + } else if (!found_layernorm) { + // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, + // we also consider it as boundary node. + layer_boundary_ln_nodes.insert(&node); + } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index f2cfd640b0840..b58d822124f43 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const logging::Logger& logger, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer From f9a92e589ad8588424725a91bbd0683a63bda950 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Wed, 6 Mar 2024 09:10:35 -0800 Subject: [PATCH 204/207] Upgrade the Windows SDK version that is used in WindowsAI Nuget Packaging pipeline (#19786) ### Description 1. Upgrade the version from 10.0.19041.0 to 10.0.22621.0. The old one misses some macros that are needed by PyTorch's CPUINFO 2. Also update cmake. ### Motivation and Context In PR #19655 I added CPUINFO to all Windows builds, but forgot to test this pipeline. 
--- .pipelines/windowsai-steps.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml index ff5179e6135c2..855573de753b0 100644 --- a/.pipelines/windowsai-steps.yml +++ b/.pipelines/windowsai-steps.yml @@ -80,11 +80,11 @@ jobs: # must call vsdevcmd first to add cmake to PATH - script: | - curl -O -L https://github.com/Kitware/CMake/releases/download/v3.26.3/cmake-3.26.3-windows-x86_64.zip - 7z x cmake-3.26.3-windows-x86_64.zip + curl -O -L https://github.com/Kitware/CMake/releases/download/v3.28.3/cmake-3.28.3-windows-x86_64.zip + 7z x cmake-3.28.3-windows-x86_64.zip set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools - $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe + $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --parallel --use_binskim_compliant_compile_flags --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos --windows_sdk_version "10.0.22621.0" $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" --cmake_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.28.3-windows-x86_64\bin\ctest.exe workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Generate cmake config' From db8d0c8e06fd030da6b7bf00cf3fb20661dd13b8 Mon Sep 17 00:00:00 2001 From: Hector Li Date: Wed, 6 Mar 2024 11:21:19 -0800 Subject: [PATCH 205/207] reset dcvsEnable for different HTP performance mode (#19728) reset dcvsEnable for different HTP performance mode --- .../qnn/builder/qnn_backend_manager.cc | 80 ++++++++++--------- 1 file changed, 44 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index e354bf6562722..6bb57b6a3e56c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -678,13 +678,13 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, dcvs_v3.setSleepDisable = 0; dcvs_v3.sleepDisable = 0; dcvs_v3.setDcvsEnable = 1; - dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; // choose performance mode switch (htp_performance_mode) { case 
HtpPerformanceMode::kHtpBurst: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMinLatency; + dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.setBusParams = 1; dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER; dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_MAX_VOLTAGE_CORNER; @@ -698,6 +698,7 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, case HtpPerformanceMode::kHtpHighPerformance: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepLowLatency; + dcvs_v3.dcvsEnable = kDcvsDisable; dcvs_v3.setBusParams = 1; dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_TURBO; dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO; @@ -707,33 +708,36 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO; dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO; break; - case HtpPerformanceMode::kHtpPowerSaver: + case HtpPerformanceMode::kHtpBalanced: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; break; - case HtpPerformanceMode::kHtpLowPowerSaver: + case HtpPerformanceMode::kHtpLowBalanced: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; break; case HtpPerformanceMode::kHtpHighPowerSaver: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS_PLUS; dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS_PLUS; @@ -743,41 +747,45 @@ Status QnnBackendManager::SetHtpPowerConfig(uint32_t htp_power_config_client_id, dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS_PLUS; dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS_PLUS; break; - case HtpPerformanceMode::kHtpExtremePowerSaver: + case 
HtpPerformanceMode::kHtpPowerSaver: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS; break; - case HtpPerformanceMode::kHtpLowBalanced: + case HtpPerformanceMode::kHtpLowPowerSaver: dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_SVS2; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_SVS2; break; - case HtpPerformanceMode::kHtpBalanced: + case HtpPerformanceMode::kHtpExtremePowerSaver: + dcvs_v3.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_POWER_SAVER_MODE; dcvs_v3.setSleepLatency = 1; // true dcvs_v3.sleepLatency = kSleepMediumLatency; + dcvs_v3.dcvsEnable = kDcvsEnable; dcvs_v3.setBusParams = 1; - dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.busVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.busVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.busVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; dcvs_v3.setCoreParams = 1; - dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM_PLUS; - dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_NOM_PLUS; + dcvs_v3.coreVoltageCornerMin = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.coreVoltageCornerTarget = DCVS_VOLTAGE_CORNER_DISABLE; + dcvs_v3.coreVoltageCornerMax = DCVS_VOLTAGE_CORNER_DISABLE; break; default: ORT_THROW("Invalid performance profile %d", static_cast(htp_performance_mode)); From 8bd1335d00375179fa9cdccf1c6fbda8c04304df Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Wed, 6 Mar 2024 12:34:33 -0800 Subject: [PATCH 206/207] Fix GQA Rotary Embedding sequence length (#19801) ### Description Previously, GQA incorrectly enforced rotary cos and sin cache to be of sequence length equal to present sequence length. 
Now it enforces that the cache length be greater than or equal to the present sequence length, since to match the Rotary Embedding op it should be of max_sequence_length. ### Motivation and Context Fixes an issue with fusing Rotary Embedding and GQA for certain models which prefer this optimization. --- .../contrib_ops/cuda/bert/group_query_attention_helper.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 853e1a710cb24..6fa11200fd5be 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -214,13 +214,13 @@ Status CheckInputs(const Tensor* query, "head_size shall be a multiple of 16. Got head_size % 16 == ", head_size % 16); } - if (cos_dims[0] != present_sequence_length) { + if (cos_dims[0] < present_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 must be of present_sequence_length."); + "cos_cache dimension 0 should be of max_sequence_length."); } - if (sin_dims[0] != present_sequence_length) { + if (sin_dims[0] < present_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 must be of present_sequence_length."); + "sin_cache dimension 0 should be of max_sequence_length."); } if (cos_dims[1] != (head_size / 16) * 8) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, From f2dc725b3355ec25e61d6970b6c030c68f9d3ac4 Mon Sep 17 00:00:00 2001 From: Markus Tavenrath Date: Wed, 6 Mar 2024 21:35:55 +0100 Subject: [PATCH 207/207] Add SpaceToDepth and DepthToSpace CUDA NHWC Ops (#19646) ### Description - Adding CUDA NHWC support for SpaceToDepth and DepthToSpace - Add a new test which verifies that SpaceToDepth swizzling for the H axis is correct. - If CUDA NHWC is enabled, run all tests on the CUDA EP with NHWC as well. ### Motivation and Context Adding more NHWC operations to avoid layout transformations when using the CUDA EP, for more efficiency.
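To make the NHWC swizzle easier to follow: SpaceToDepth in NHWC maps [N, H, W, C] to [N, H/b, W/b, C*b*b], and the CUDA kernel in this patch realizes it as a 6-D reshape followed by a {0, 1, 3, 2, 4, 5} transpose. The sketch below is a plain CPU reference of that index mapping, for illustration only; the function name and the flat std::vector tensor layout are assumptions, not part of the patch.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Reference NHWC SpaceToDepth: [N, H, W, C] -> [N, H/b, W/b, C*b*b].
// Within the output depth, the layout is (block_row, block_col, original_channel),
// i.e. the ordering produced by the reshape + {0,1,3,2,4,5} transpose.
std::vector<float> SpaceToDepthNHWC(const std::vector<float>& in,
                                    int64_t N, int64_t H, int64_t W, int64_t C,
                                    int64_t b) {
  const int64_t Ho = H / b, Wo = W / b, Co = C * b * b;
  std::vector<float> out(static_cast<size_t>(N * Ho * Wo * Co));
  for (int64_t n = 0; n < N; ++n)
    for (int64_t ho = 0; ho < Ho; ++ho)
      for (int64_t wo = 0; wo < Wo; ++wo)
        for (int64_t bh = 0; bh < b; ++bh)
          for (int64_t bw = 0; bw < b; ++bw)
            for (int64_t c = 0; c < C; ++c) {
              const int64_t in_idx = ((n * H + (ho * b + bh)) * W + (wo * b + bw)) * C + c;
              const int64_t out_idx = ((n * Ho + ho) * Wo + wo) * Co + (bh * b + bw) * C + c;
              out[static_cast<size_t>(out_idx)] = in[static_cast<size_t>(in_idx)];
            }
  return out;
}

int main() {
  // 1x4x4x1 input holding 0..15, blocksize 2 -> 1x2x2x4 output.
  std::vector<float> x(16);
  for (int i = 0; i < 16; ++i) x[i] = static_cast<float>(i);
  const auto y = SpaceToDepthNHWC(x, 1, 4, 4, 1, 2);
  for (float v : y) std::cout << v << " ";  // 0 1 4 5 2 3 6 7 8 9 12 13 10 11 14 15
  std::cout << "\n";
  return 0;
}
```

DepthToSpace is the inverse mapping; its DCR/CRD mode decides whether the block offsets or the original channels sit outermost within the input depth, which is why the patch picks different virtual shapes and permutations for the two cases.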
--- include/onnxruntime/core/graph/constants.h | 1 + .../contrib_ops/internal_nhwc_onnx_schemas.cc | 1 + .../layout_transformation.cc | 3 +- .../providers/cpu/tensor/space_depth_ops.h | 16 +- .../core/providers/cuda/cuda_nhwc_kernels.cc | 16 ++ .../providers/cuda/tensor/space_depth_ops.cc | 196 +++++++++++++----- .../providers/cuda/tensor/space_depth_ops.h | 2 + .../test/contrib_ops/gridsample_test.cc | 17 +- onnxruntime/test/providers/base_tester.cc | 7 + .../providers/cpu/generator/random_test.cc | 12 +- .../providers/cpu/nn/batch_norm_op_test.cc | 6 +- .../test/providers/cpu/nn/conv_op_test.cc | 2 + .../cpu/nn/conv_transpose_op_test.cc | 15 +- .../test/providers/cpu/nn/pool_op_test.cc | 86 ++++---- .../cpu/reduction/reduction_ops_test.cc | 3 + .../test/providers/cpu/rnn/rnn_op_test.cc | 7 +- .../cpu/tensor/gather_elements_op_test.cc | 2 +- .../providers/cpu/tensor/resize_op_test.cc | 22 +- .../providers/cpu/tensor/scatter_op_test.cc | 7 +- .../cpu/tensor/space_depth_ops_test.cc | 47 +++++ .../providers/cpu/tensor/upsample_op_test.cc | 6 +- 21 files changed, 345 insertions(+), 129 deletions(-) diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h index 9b26ba914c7dd..8e04050d089a0 100644 --- a/include/onnxruntime/core/graph/constants.h +++ b/include/onnxruntime/core/graph/constants.h @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30; constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider"; constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider"; +constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider"; constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kVitisAIExecutionProvider = "VitisAIExecutionProvider"; diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index c8960578f9e3d..6bf19654a3ce9 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -106,6 +106,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function& GetCUDALayoutSensitiveOps() { "GlobalAveragePool", "AveragePool", "GridSample", - }; + "DepthToSpace", + "SpaceToDepth"}; }(); return cuda_nhwc_ops; } diff --git a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h index 7d117317ba172..3218c8952d6ec 100644 --- a/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cpu/tensor/space_depth_ops.h @@ -14,6 +14,7 @@ class SpaceDepthBase { "Attribute blocksize is not set."); } + template Status InputValidationsAndOutputDimsCalc(const Tensor& input, int64_t& batch, int64_t& input_depth, int64_t& input_height, int64_t& input_width, @@ -27,9 +28,15 @@ class SpaceDepthBase { } batch = input_shape[0]; - input_depth = input_shape[1]; - input_height = input_shape[2]; - input_width = input_shape[3]; + if constexpr (IsNHWC) { + input_depth = input_shape[3]; + input_height = input_shape[1]; + input_width = input_shape[2]; + } else { + input_depth = input_shape[1]; + input_height = input_shape[2]; + input_width = input_shape[3]; + } if (is_space_to_depth) { // SpaceToDepth op if ((input_height % this->blocksize_) != 0) { @@ -46,7 +53,8 @@ class SpaceDepthBase { } else { // DepthToSpace op if ((input_depth % 
(blocksize_ * blocksize_) != 0)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "DepthToSpace requires input depth to be a multiple of (block_size * blok_size)"); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "DepthToSpace requires input depth to be a multiple of (block_size * block_size)"); } output_depth = input_depth / blocksize_ / blocksize_; diff --git a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc index 64edc319e15ac..da7802fe8d5dc 100644 --- a/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc +++ b/onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc @@ -86,6 +86,11 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalN BatchNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 15, MLFloat16, BatchNormalization); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 11, 12, DepthToSpace); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, DepthToSpace); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 12, SpaceToDepth); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 13, SpaceToDepth); Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn nhwc_function_table[] = { @@ -171,6 +176,17 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) { kCudaExecutionProvider, kMSInternalNHWCDomain, 1, 10, float, ConvTranspose)>, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : nhwc_function_table) { diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc index 407a2ef3981f1..aaaf3600b676e 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.cc @@ -20,7 +20,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - SpaceToDepth); + SpaceToDepth); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + SpaceToDepth, + kMSInternalNHWCDomain, + 1, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SpaceToDepth); +#endif ONNX_OPERATOR_KERNEL_EX( SpaceToDepth, @@ -32,7 +47,21 @@ ONNX_OPERATOR_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - SpaceToDepth); + SpaceToDepth); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_KERNEL_EX( + SpaceToDepth, + kMSInternalNHWCDomain, + 13, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + SpaceToDepth); +#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( DepthToSpace, @@ -45,7 +74,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef 
ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 1, + 10, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif ONNX_OPERATOR_VERSIONED_KERNEL_EX( DepthToSpace, @@ -58,7 +102,22 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 11, + 12, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif ONNX_OPERATOR_KERNEL_EX( DepthToSpace, @@ -70,23 +129,35 @@ ONNX_OPERATOR_KERNEL_EX( {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), - DepthToSpace); + DepthToSpace); + +#ifdef ENABLE_CUDA_NHWC_OPS +ONNX_OPERATOR_KERNEL_EX( + DepthToSpace, + kMSInternalNHWCDomain, + 13, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", + {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + DepthToSpace); +#endif static Status SpaceDepthOpCudaImpl(const cudaDeviceProp& prop, cudaStream_t stream, const cublasHandle_t cublas_handle, const Tensor& input, Tensor& output, const std::vector& permutation, - const int64_t batch_size, - const int64_t in_dim1, const int64_t in_dim2, const int64_t in_dim3, - const int64_t in_dim4, const int64_t in_dim5, + const TensorShape& virtual_input_shape, const TensorShape& virtual_output_shape) { - TensorShape virtual_input_shape{batch_size, in_dim1, in_dim2, in_dim3, in_dim4, in_dim5}; return Transpose::DoTranspose(prop, stream, cublas_handle, permutation, input, output, &virtual_input_shape, &virtual_output_shape); } -Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { +template +Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& input = *tensor_pointer; @@ -101,29 +172,44 @@ Status SpaceToDepth::ComputeInternal(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; - ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, - batch, - input_depth, input_height, input_width, - output_depth, output_height, output_width, - true)); + ORT_RETURN_IF_ERROR( + InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + true)); // We use the "actual" output shape to construct the output tensor - Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width}); + Tensor& output = (Layout == LAYOUT_NCHW) + ? *context->Output(0, {batch, output_depth, output_height, output_width}) + : *context->Output(0, {batch, output_height, output_width, output_depth}); + + TensorShape virtual_input_shape = (Layout == LAYOUT_NCHW) + ? 
TensorShape{batch, input_depth, input_height / blocksize_, + blocksize_, input_width / blocksize_, blocksize_} + : TensorShape{batch, input_height / blocksize_, blocksize_, + input_width / blocksize_, blocksize_, input_depth}; // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...) - TensorShape virtual_output_shape{batch, blocksize_, blocksize_, input_depth, - input_height / blocksize_, input_width / blocksize_}; + TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, blocksize_, blocksize_, input_depth, + input_height / blocksize_, input_width / blocksize_} + : TensorShape{batch, input_height / blocksize_, input_width / blocksize_, + blocksize_, blocksize_, input_depth}; - std::vector permutation = {0, 3, 5, 1, 2, 4}; + std::vector permutation = (Layout == LAYOUT_NCHW) + ? std::vector{0, 3, 5, 1, 2, 4} + : std::vector{0, 1, 3, 2, 4, 5}; - ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, batch, - input_depth, input_height / blocksize_, blocksize_, input_width / blocksize_, blocksize_, - virtual_output_shape)); + ORT_RETURN_IF_ERROR( + SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, permutation, + virtual_input_shape, virtual_output_shape)); return Status::OK(); } -Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { +template +Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { const auto* tensor_pointer = context->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); const Tensor& input = *tensor_pointer; @@ -138,46 +224,56 @@ Status DepthToSpace::ComputeInternal(OpKernelContext* context) const { int64_t output_height = -1; int64_t output_width = -1; - ORT_RETURN_IF_ERROR(InputValidationsAndOutputDimsCalc(input, - batch, - input_depth, input_height, input_width, - output_depth, output_height, output_width, - false)); + ORT_RETURN_IF_ERROR( + InputValidationsAndOutputDimsCalc(input, + batch, + input_depth, input_height, input_width, + output_depth, output_height, output_width, + false)); // We use the "actual" output shape to construct the output tensor - Tensor& output = *context->Output(0, {batch, output_depth, output_height, output_width}); + Tensor& output = (Layout == LAYOUT_NCHW) + ? *context->Output(0, {batch, output_depth, output_height, output_width}) + : *context->Output(0, {batch, output_height, output_width, output_depth}); + + int64_t virtual_input_depth = input_depth / blocksize_ / blocksize_; + TensorShape virtual_input_shape; + + // cdr only here! + if (is_dcr_) { + virtual_input_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, blocksize_, blocksize_, + virtual_input_depth, input_height, input_width} + : TensorShape{batch, input_height, input_width, + blocksize_, blocksize_, virtual_input_depth}; + } else { + virtual_input_shape = (Layout == LAYOUT_NCHW) + ? TensorShape{batch, virtual_input_depth, blocksize_, + blocksize_, input_height, input_width} + : TensorShape{batch, input_height, input_width, + virtual_input_depth, blocksize_, blocksize_}; + } // We will pass in the "virtual" output shape to be used by DoTranspose() in SpaceDepthOpCudaImpl(...) - TensorShape virtual_output_shape{batch, input_depth / blocksize_ / blocksize_, - input_height, blocksize_, input_width, blocksize_}; + TensorShape virtual_output_shape = (Layout == LAYOUT_NCHW) + ? 
TensorShape{batch, virtual_input_depth, input_height, + blocksize_, input_width, blocksize_} + : TensorShape{batch, input_height, blocksize_, + input_width, blocksize_, virtual_input_depth}; std::vector permutation; - permutation.reserve(6); - permutation.push_back(0); if (is_dcr_) { - permutation.push_back(3); - permutation.push_back(4); - permutation.push_back(1); - permutation.push_back(5); - permutation.push_back(2); + permutation = (Layout == LAYOUT_NCHW) + ? std::vector({0, 3, 4, 1, 5, 2}) + : std::vector({0, 1, 3, 2, 4, 5}); } else { - permutation.push_back(1); - permutation.push_back(4); - permutation.push_back(2); - permutation.push_back(5); - permutation.push_back(3); + permutation = std::vector({0, 1, 4, 2, 5, 3}); } - int64_t dim1 = is_dcr_ ? blocksize_ : input_depth / blocksize_ / blocksize_; - int64_t dim3 = is_dcr_ ? input_depth / blocksize_ / blocksize_ : blocksize_; - ORT_RETURN_IF_ERROR(SpaceDepthOpCudaImpl(GetDeviceProp(), Stream(context), GetCublasHandle(context), input, output, - permutation, - batch, - dim1, blocksize_, dim3, input_height, input_width, - virtual_output_shape)); + permutation, virtual_input_shape, virtual_output_shape)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h index 57b85556f1dbe..8780d9b365005 100644 --- a/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h +++ b/onnxruntime/core/providers/cuda/tensor/space_depth_ops.h @@ -9,6 +9,7 @@ namespace onnxruntime { namespace cuda { +template class SpaceToDepth final : public CudaKernel, SpaceDepthBase { public: explicit SpaceToDepth(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { @@ -17,6 +18,7 @@ class SpaceToDepth final : public CudaKernel, SpaceDepthBase { Status ComputeInternal(OpKernelContext* context) const override; }; +template class DepthToSpace final : public CudaKernel, SpaceDepthBase { public: explicit DepthToSpace(const OpKernelInfo& info) : CudaKernel(info), SpaceDepthBase(info) { diff --git a/onnxruntime/test/contrib_ops/gridsample_test.cc b/onnxruntime/test/contrib_ops/gridsample_test.cc index 1f31c2bd21f14..46ed04301a9e8 100644 --- a/onnxruntime/test/contrib_ops/gridsample_test.cc +++ b/onnxruntime/test/contrib_ops/gridsample_test.cc @@ -32,7 +32,7 @@ TEST(GridsampleContribOpTest, gridsample_default) { 3.8000f, 7.9000f, 8.7000f, 9.5000f, 10.3000f, 5.3000f, 5.4000f, 11.1000f, 11.9000f, 12.7000f, 13.5000f, 6.9000f, 3.0000f, 6.1500f, 6.5500f, 6.9500f, 7.3500f, 3.7500f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) { @@ -45,7 +45,7 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_zeros) { 5.0000f, 5.0000f, 10.0000f, 10.0000f}); test.AddAttribute("padding_mode", "zeros"); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f, 1.7000f, 0.0000f, 0.0000f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_paddingmode_border) { @@ -58,7 +58,7 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_border) { 5.0000f, 5.0000f, 10.0000f, 10.0000f}); test.AddAttribute("padding_mode", "border"); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.0000f, 1.7000f, 5.0000f, 5.0000f, 1.7000f, 5.0000f, 5.0000f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } 
TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) { @@ -71,7 +71,8 @@ TEST(GridsampleContribOpTest, gridsample_paddingmode_reflection) { 5.0000f, 5.0000f, 10.0000f, 10.0000f}); test.AddAttribute("padding_mode", "reflection"); test.AddOutput("Y", {1, 1, 2, 4}, {2.5000f, 0.0000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 5.0000f, 2.5000f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); // Accuracy issue for QNN + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kQnnExecutionProvider}); // Accuracy issue for QNN } TEST(GridsampleContribOpTest, gridsample_aligncorners_true) { @@ -86,7 +87,7 @@ TEST(GridsampleContribOpTest, gridsample_aligncorners_true) { test.AddAttribute("mode", "bilinear"); test.AddAttribute("align_corners", align_corners); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 1.2500f, 2.0000f, 2.5000f, 2.5000f, 2.0000f, 3.7500f, 5.0000f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_mode_bilinear) { @@ -99,7 +100,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_bilinear) { 0.5000f, 0.5000f, 1.0000f, 1.0000f}); test.AddAttribute("mode", "bilinear"); test.AddOutput("Y", {1, 1, 2, 4}, {0.0000f, 0.5000f, 1.7000f, 2.5000f, 2.5000f, 1.7000f, 4.5000f, 1.2500f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_mode_nearest) { @@ -112,7 +113,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_nearest) { 0.5000f, 0.5000f, 1.0000f, 1.0000f}); test.AddAttribute("mode", "nearest"); test.AddOutput("Y", {1, 1, 2, 4}, {0.f, 0.f, 2.f, 2.f, 2.f, 2.f, 5.f, 0.f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(GridsampleContribOpTest, gridsample_mode_bicubic) { @@ -125,7 +126,7 @@ TEST(GridsampleContribOpTest, gridsample_mode_bicubic) { 0.5000f, 0.5000f, 1.0000f, 1.0000f}); test.AddAttribute("mode", "bicubic"); test.AddOutput("Y", {1, 1, 2, 4}, {-0.1406f, 0.3828f, 1.7556f, 2.9688f, 2.9688f, 1.7556f, 5.1445f, 1.3906f}); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } } // namespace test diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 16cce85f7cb0a..84cb663a2984a 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -622,6 +622,9 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, static const std::string all_provider_types[] = { kCpuExecutionProvider, kCudaExecutionProvider, +#ifdef ENABLE_CUDA_NHWC_OPS + kCudaNHWCExecutionProvider, +#endif kDnnlExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, @@ -650,6 +653,10 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, execution_provider = DefaultCpuExecutionProvider(); else if (provider_type == onnxruntime::kCudaExecutionProvider) execution_provider = DefaultCudaExecutionProvider(); +#ifdef ENABLE_CUDA_NHWC_OPS + else if (provider_type == onnxruntime::kCudaNHWCExecutionProvider) + execution_provider = DefaultCudaNHWCExecutionProvider(); +#endif else if (provider_type == onnxruntime::kDnnlExecutionProvider) execution_provider = DefaultDnnlExecutionProvider(); else if (provider_type == onnxruntime::kOpenVINOExecutionProvider) diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc 
b/onnxruntime/test/providers/cpu/generator/random_test.cc index 532b98317405f..be049d1cf0ce3 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -36,7 +36,8 @@ TEST(Random, RandomNormal2DDouble) { // The expected_output is generated using std lib, which is used by CPU kernel only. // So we need to exclude other EPs here. Ditto for other places. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } void RunRandomNormalLike3DFloat(bool infer_dtype = false) { @@ -72,7 +73,8 @@ void RunRandomNormalLike3DFloat(bool infer_dtype = false) { test.AddOutput("Y", dims, expected_output); // TensorRT does not support manual seed overrides and there will be result mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); } TEST(Random, RandomNormalLike3DDouble) { @@ -109,7 +111,8 @@ TEST(Random, RandomUniform1DFloat) { test.AddOutput("Y", dims, expected_output); // TensorRT does not support manual seed overrides and there will be result mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); } void RunRandomUniformLikeTest(bool infer_dtype = false) { @@ -142,7 +145,8 @@ void RunRandomUniformLikeTest(bool infer_dtype = false) { test.AddOutput("Y", dims, expected_output); // TensorRT does not support seed parameter and there will be result mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider}); } TEST(Random, RandomUniformLike2DDouble) { diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index 54e5c71bd753a..3d30fc62a945d 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -917,7 +917,7 @@ TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", // TODO(mtavenrath) flakiness of running_mean for CUDA has been fixed, the delta of running_var is still ~0.1 - {kCudaExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } @@ -945,7 +945,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset14) { // exclude CUDA Execution Provider due to flakiness // exclude TRT and OpenVINO for same reasons as seen in TestBatchNorm() test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, 
kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } @@ -972,7 +972,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) { // Same exclusions as the opset 14 test test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } #endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index dede278b7274f..0efa78af2795c 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -59,6 +59,8 @@ void TestConvOp(const ConvOpAndTestAttributes& attributes, std::unordered_set excluded_providers(attributes.excluded_providers); // Disable TensorRT because weight as input is not supported excluded_providers.insert(kTensorrtExecutionProvider); + // Disable CUDA NHWC execution provider as it is currently flaky + excluded_providers.insert(kCudaNHWCExecutionProvider); // QNN SDK 2.10.0 has a bug that breaks support for dynamic bias inputs. excluded_providers.insert(kQnnExecutionProvider); diff --git a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc index 472f841aa8565..ec93dc249eeb2 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_transpose_op_test.cc @@ -75,7 +75,8 @@ void TestConvTransposeOp(const ConvTransposeOpAttributes& attributes, const vector& expected_output_shape, OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, const std::string& err_str = "", - const std::unordered_set& excluded_provider_types = {kTensorrtExecutionProvider, kQnnExecutionProvider}) { + const std::unordered_set& excluded_provider_types = + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}) { std::unordered_set extra_exclude_openvino_for_initializer_filter = excluded_provider_types; extra_exclude_openvino_for_initializer_filter.insert(kOpenVINOExecutionProvider); TestConvTransposeOpInitializer(attributes, inputs, input_shapes, expected_output, expected_output_shape, @@ -409,7 +410,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShape_2) { vector Y_shape = {1, 1, 1, 14}; auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, - OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider}); + OpTester::ExpectResult::kExpectSuccess, "", + {kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { @@ -434,7 +436,8 @@ TEST(ConvTransposeTest, ConvTranspose_2D_OutputShapeWithBatchSize) { auto expected_vals = {1.0f, 2.0f, 5.0f, 11.0f, 19.0f, 28.0f, 37.0f, 46.0f, 55.0f, 64.0f, 63.0f, 51.0f, 27.0f, 10.0f, 11.0f, 32.0f, 65.0f, 91.0f, 109.0f, 118.0f, 127.0f, 136.0f, 145.0f, 154.0f, 143.0f, 111.0f, 57.0f, 20.0f}; TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, - OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider, kQnnExecutionProvider}); + OpTester::ExpectResult::kExpectSuccess, "", + 
{kOpenVINOExecutionProvider, kCudaNHWCExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_InvalidKernelShape) { @@ -871,7 +874,8 @@ TEST(ConvTransposeTest, DimWithZero) { TestConvTransposeOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kAclExecutionProvider, kQnnExecutionProvider}); + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, + kAclExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_3D) { @@ -1005,7 +1009,8 @@ TEST(ConvTransposeTest, ConvTranspose_3D) { TestConvTransposeOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape, OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider, kCudaExecutionProvider, kQnnExecutionProvider}); + {kTensorrtExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kQnnExecutionProvider}); } TEST(ConvTransposeTest, ConvTranspose_1D_AsymmetricPads) { diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc index 4b194ec18b31b..e24cda17166ed 100644 --- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc @@ -57,7 +57,8 @@ TEST(PoolTest, MaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: result differs + // TensorRT: result differs + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } // Only CUDA kernel has float 16 support @@ -115,7 +116,8 @@ TEST(PoolTest, MaxPool_F16) { test.AddInput("X", x_dims, f_X); test.AddOutput("Y", expected_dims, f_Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Assertion `!attrs.count("pads")' failed + // TensorRT: Assertion `!attrs.count("pads")' failed + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } #endif @@ -167,7 +169,9 @@ static void MaxPool_8_WithIndexTest(bool has_index, int64_t storage_order = 0) { storage_order == 0 ? 
test.AddOutput("Indices", expected_dims, expected_indices_row) : test.AddOutput("Indices", expected_dims, expected_indices_col); } - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kDnnlExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kDnnlExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, + kAclExecutionProvider, kArmNNExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, MaxPool_8_With_Index) { @@ -196,7 +200,7 @@ TEST(PoolTest, MaxPool1D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { @@ -217,7 +221,8 @@ static void MaxPool1D_8_WithIndexTest(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_8_With_Index) { @@ -243,7 +248,8 @@ static void MaxPool1D_12_WithIndexTest_int8(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { @@ -264,7 +270,8 @@ static void MaxPool1D_12_WithIndexTest_uint8(int64_t storage_order) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.AddOutput("Indices", expected_dims, expected_indices); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool1D_12_With_Index_8bits) { @@ -302,9 +309,9 @@ TEST(PoolTest, MaxPool2D_uint8) { test.AddOutput("Output", output_shape, output); #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider}); #else - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); #endif } @@ -330,7 +337,7 @@ TEST(PoolTest, MaxPool_10_Dilation_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations) { @@ -350,7 +357,7 @@ TEST(PoolTest, MaxPool_DefaultDilations) { test.AddInput("X", 
x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations_int8) { @@ -370,7 +377,7 @@ TEST(PoolTest, MaxPool_DefaultDilations_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_DefaultDilations_uint8) { @@ -390,7 +397,7 @@ TEST(PoolTest, MaxPool_DefaultDilations_uint8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_1d) { @@ -416,7 +423,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_1d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_2d) { @@ -444,7 +451,7 @@ TEST(PoolTest, MaxPool_10_Dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_2d_int8) { @@ -472,7 +479,7 @@ TEST(PoolTest, MaxPool_10_Dilation_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_2d) { @@ -500,7 +507,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { @@ -528,7 +535,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil0_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { @@ -556,7 +564,8 @@ TEST(PoolTest, MaxPool_12_Dilation_Ceil0_2d_int8) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { @@ -585,7 +594,8 @@ TEST(PoolTest, MaxPool_10_Dilation_Ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, MaxPool_10_DilationPadding_3d) { @@ -621,7 +631,7 @@ TEST(PoolTest, MaxPool_10_DilationPadding_3d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(PoolTest, GlobalMaxPool) { @@ -697,7 +707,7 @@ TEST(PoolTest, GlobalMaxPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, GlobalMaxPool3D) { @@ -773,7 +783,7 @@ TEST(PoolTest, GlobalMaxPool3D) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool) { @@ -854,7 +864,7 @@ TEST(PoolTest, AveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool_IncludePadPixel) { @@ -878,7 +888,7 @@ TEST(PoolTest, AveragePool_IncludePadPixel) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } // test 'strides' attribute not specified @@ -897,7 +907,7 @@ TEST(PoolTest, AveragePool_DefaultStrides) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, AveragePool_10_ceil1_2d) { @@ -920,7 +930,8 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kAclExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider}); } TEST(PoolTest, AveragePool_19_dilation_2d) { @@ -944,7 +955,7 @@ TEST(PoolTest, AveragePool_19_dilation_2d) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", 
{kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kAclExecutionProvider, kOpenVINOExecutionProvider}); } TEST(PoolTest, GlobalAveragePool) { @@ -1020,7 +1031,7 @@ TEST(PoolTest, GlobalAveragePool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, GlobalAveragePool_Large_128) { @@ -1033,7 +1044,7 @@ TEST(PoolTest, GlobalAveragePool_Large_128) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, GlobalAveragePool_Large_256) { @@ -1046,7 +1057,7 @@ TEST(PoolTest, GlobalAveragePool_Large_256) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals, /*sort_output=*/false, /*rel_error=*/1e-3f, /*abs_error=*/1e-2f); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, LpPool) { @@ -1353,7 +1364,7 @@ TEST(PoolTest, LpPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } // test data generated with lp_pool_test_generator.py @@ -1385,7 +1396,7 @@ TEST(PoolTest, LpPool1d) { // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060 // TensorRT does not support 1d pooling - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); y_count++; } } @@ -1417,7 +1428,7 @@ TEST(PoolTest, LpPool2d) { test.AddAttribute("kernel_shape", kernel_sizes[kernel_size_count]); test.AddOutput("Y", y_sizes[y_count], ys[y_count]); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); y_count++; } } @@ -1435,7 +1446,7 @@ TEST(PoolTest, LpPoolCeilMode) { // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_network_definition.html#a94f434942252e6d98ac17705c06ce060 // TensorRT does not support 1d pooling - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(PoolTest, GlobalLpPool) { @@ -1690,7 +1701,7 @@ TEST(PoolTest, GlobalLpPool) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaNHWCExecutionProvider}); } TEST(PoolTest, MaxPoolDimWithZeroForN) { @@ -1707,7 +1718,8 @@ TEST(PoolTest, MaxPoolDimWithZeroForN) { test.AddInput("X", x_dims, x_vals); test.AddOutput("Y", expected_dims, expected_vals); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kQnnExecutionProvider}); } } // namespace test diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc 
b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc index b0e0a0dd0d564..2902995df1e71 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc +++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc @@ -3541,6 +3541,7 @@ TEST(ReductionOpTest, ReduceDimWithZero1) { { kCoreMLExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kDnnlExecutionProvider, kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, @@ -3591,6 +3592,7 @@ TEST(ReductionOpTest, ReduceDimWithZero2) { { kCoreMLExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kDnnlExecutionProvider, kMIGraphXExecutionProvider, kOpenVINOExecutionProvider, @@ -5779,6 +5781,7 @@ void test_empty_set(const std::string& op, int opset, bool axes_as_input, float { kCoreMLExecutionProvider, kCudaExecutionProvider, + kCudaNHWCExecutionProvider, kDmlExecutionProvider, kDnnlExecutionProvider, kMIGraphXExecutionProvider, diff --git a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc index 1a31743e2f7e7..38734ab9f668f 100644 --- a/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc +++ b/onnxruntime/test/providers/cpu/rnn/rnn_op_test.cc @@ -744,7 +744,9 @@ TEST(RNNTest, RNN_invalid_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); // the CUDA RNN version allows the invalid sequence lengths, so disable testing on CUDA and TensorRT - test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectFailure, error_msg, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, + kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); }; // should batch batch_size to be valid @@ -842,7 +844,8 @@ TEST(RNNTest, RNN_bidirectional_with_sequence_lens) { test.AddOutput("Y_h", Y_h_dims, Y_h_data); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } TEST(RNNTest, RNN_with_invalid_activation_load_failure) { diff --git a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc index 8a8bc5560c084..b4bd3fca7b712 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc @@ -383,7 +383,7 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) { // skip openvino which will not throw error message but will ensure no out-of-bound access // skip TensorRT because it doesn't support out of bounds indices test.Run(OpTester::ExpectResult::kExpectFailure, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, kTensorrtExecutionProvider, kDmlExecutionProvider}); } diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 5addb5dd9ce46..062f25b989a70 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -102,7 +102,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr // TensorRT: results mismatch // ROCm: results mismatch 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) { @@ -132,7 +132,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) { @@ -192,7 +193,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e // DML: results mismatch test.Run( OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -267,7 +268,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { @@ -291,7 +292,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { @@ -439,7 +441,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -539,7 +542,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe // ROCm: results mismatch // DML: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kDmlExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -650,7 +653,8 @@ TEST(ResizeOpTest, 
NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -1913,6 +1917,8 @@ void TestAntialiasing(std::map attributes, }); // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accuracy issue. excluded_eps.insert(kTensorrtExecutionProvider); + // Test is flaky on kCudaNHWCExecutionProvider + excluded_eps.insert(kCudaNHWCExecutionProvider); test.Run(OpTester::ExpectResult::kExpectSuccess, "", excluded_eps); } diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc index 30e27bb15fa57..b1dfec7951338 100644 --- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc @@ -268,7 +268,7 @@ static void scatter_invalid_index(const char* op_name, int op_version) { test.AddOutput("y", {4, 2, 1}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 5.0f, 0.0f}); test.Run(OpTester::ExpectResult::kExpectFailure, "indices element out of data bounds, idx=4 must be within the inclusive range [-4,3]", - {kCudaExecutionProvider, kTensorrtExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider}); } TEST(Scatter, InvalidIndex) { @@ -291,9 +291,10 @@ static void scatter_bool_with_axis_tests(const char* op_name, int op_version) { test.AddOutput("y", {1, 5}, {false, true, false, false, false}); #if defined(OPENVINO_CONFIG_GPU_FP32) || defined(OPENVINO_CONFIG_GPU_FP16) test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kOpenVINOExecutionProvider}); // OpenVINO: Disabled due to failure for GPU + {kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider}); // OpenVINO: Disabled due to failure for GPU #else - test.Run(); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kCudaNHWCExecutionProvider}); // OpenVINO: Disabled due to failure for GPU #endif } diff --git a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc index 63b92cfc187bd..5222380d9ca56 100644 --- a/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/space_depth_ops_test.cc @@ -108,6 +108,53 @@ TEST(TensorOpTest, SpaceToDepthTest_2) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); } +TEST(TensorOpTest, SpaceToDepthTest_3) { + // Test swizzling with H_output > 1 + OpTester test("SpaceToDepth"); + constexpr int64_t blocksize = 2; + test.AddAttribute("blocksize", blocksize); + constexpr int64_t N = 1, C = 2, H = 4, W = 8; + + const std::vector X = { + 0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, + 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, + + 2.0f, 2.1f, 2.2f, 2.3f, 2.4f, 2.5f, 2.6f, 2.7f, + 3.0f, 3.1f, 3.2f, 3.3f, 3.4f, 3.5f, 3.6f, 3.7f, + + 4.0f, 4.1f, 4.2f, 4.3f, 4.4f, 4.5f, 4.6f, 4.7f, + 5.0f, 5.1f, 5.2f, 5.3f, 5.4f, 5.5f, 5.6f, 5.7f, + 6.0f, 6.1f, 6.2f, 6.3f, 6.4f, 6.5f, 6.6f, 6.7f, + 7.0f, 7.1f, 7.2f, 7.3f, 7.4f, 7.5f, 7.6f, 7.7f}; + + test.AddInput("input", {N, C, H, W}, X); + + const std::vector result = { + 0.0f, 0.2f, 0.4f, 0.6f, + 2.0f, 2.2f, 2.4f, 2.6f, + 4.0f, 4.2f, 4.4f, 4.6f, + 6.0f, 6.2f, 
6.4f, 6.6f, + + 0.1f, 0.3f, 0.5f, 0.7f, + 2.1f, 2.3f, 2.5f, 2.7f, + 4.1f, 4.3f, 4.5f, 4.7f, + 6.1f, 6.3f, 6.5f, 6.7f, + + 1.0f, 1.2f, 1.4f, 1.6f, + 3.0f, 3.2f, 3.4f, 3.6f, + 5.0f, 5.2f, 5.4f, 5.6f, + 7.0f, 7.2f, 7.4f, 7.6f, + + 1.1f, 1.3f, 1.5f, 1.7f, + 3.1f, 3.3f, 3.5f, 3.7f, + 5.1f, 5.3f, 5.5f, 5.7f, + 7.1f, 7.3f, 7.5f, 7.7f}; + + test.AddOutput("output", {N, C * blocksize * blocksize, H / blocksize, W / blocksize}, result); + + test.Run(); +} + TEST(TensorOpTest, DepthToSpaceTest_1) { OpTester test("DepthToSpace", 7); // create an opset 7 model constexpr int64_t blocksize = 2; diff --git a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc index 72cb84d50f078..188532cfa350a 100644 --- a/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/upsample_op_test.cc @@ -692,7 +692,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4D1CBilinearTest) { // TensorRT: results mismatch // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) { @@ -766,7 +766,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest) { // TensorRT: results mismatch // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOp2DBilinearTest) { @@ -886,7 +886,7 @@ TEST(UpsampleOpTest, NhwcUpsampleOp4DBilinearTest_int32) { // TensorRT: results mismatch // ROCm: results mismatch test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(UpsampleOpTest, UpsampleOpNearestTest_1D) {