From e525ea2295c455cbb14840b075a7935a089ecb94 Mon Sep 17 00:00:00 2001 From: derdeljan-msft Date: Thu, 28 Aug 2025 09:05:04 +0200 Subject: [PATCH 01/23] [CPU] Optimize GQA attention bias application for FP16 (#25871) ### Description When using attention bias input for GQA op with FP16, on the platforms that don't natively support FP16 math a cast to fp32 needs to be performed, and thus a temporary buffer needs to be created to store the fp32 values. The issue is that this temporary buffer was being allocated / deallocated inside of a loop for every token being processed. Refactored the implementation so that the allocation takes place only once. Phi model throughput increased by 15%. --- .../contrib_ops/cpu/bert/gqa_attention_base.h | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 0d5117709c18a..bfa450f4287f8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -280,6 +280,18 @@ class GQAAttentionBase { output, static_cast(present_buffer_sequence_length), nullptr); } + // Pre-allocate buffer for attention mask to avoid allocating it for every processed token + float* attention_bias_thread_fp32 = nullptr; + if (attention_bias_thread != nullptr) { + if constexpr (!std::is_same_v) { + static_assert(std::is_same_v && std::is_same_v); + + size_t bytes = attention_total_seqlen * sizeof(float); + attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); + } + } + BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); + // compute Softmax U* output_softmax = output; for (size_t seq = 0; seq < sequence_length; seq++) { @@ -316,9 +328,6 @@ class GQAAttentionBase { static_cast(window_size)); } else { static_assert(std::is_same_v && std::is_same_v); - size_t bytes = window_size * sizeof(float); - auto attention_bias_thread_fp32 = static_cast(allocator->Alloc(bytes)); - BufferUniquePtr scratch_buffer(attention_bias_thread_fp32, BufferDeleter(allocator)); MlasConvertHalfToFloatBuffer(attention_bias_thread + start_offset, attention_bias_thread_fp32, window_size); ApplyAttentionBias(output_softmax + start_offset, attention_bias_thread_fp32, static_cast(window_size)); From 574806bf9105d64f5f3d61cfb1e1793be38d05cf Mon Sep 17 00:00:00 2001 From: Jonathan Clohessy Date: Thu, 28 Aug 2025 17:14:31 +0100 Subject: [PATCH 02/23] Fixes for DynamicQuantizeMatMul and Attention3D tests (#25814) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This change fixes correctness issues in two areas that were causing failures in onnxruntime_test_all: - DynamicQuantizeMatMul.WithConstantBInputs - AttentionTest.Attention3DDefault - AttentionTest.Attention3DWithPastAndPresentQkMatmul What was wrong and how it’s fixed 1) DynamicQuantizeMatMul.WithConstantBInputs - Root cause: The Kleidi dynamic quantization GEMM path could be selected even when the B scales contained values such as (zero, negative, or non-finite). That violates kernel assumptions and can lead to incorrect results. - Fix: In `onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc`, we now explicitly validate that all B scales are finite and strictly positive before enabling the Kleidi/MLAS dynamic path. If any scale is invalid, we disable that path. 2) Attention tests (Attention3DDefault, Attention3DWithPastAndPresentQkMatmul) - Root causes in `onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp`: - Incorrect handling of GEMM corner cases for alpha/beta and K==0 (e.g., not respecting C = beta*C when alpha==0 or K==0). - Unnecessary or premature fallbacks for small shapes. - Fixes: - Add early-outs for degenerate sizes: if M==0 or N==0, return handled. - Correctly implement alpha/beta semantics: --------- Signed-off-by: Jonathan Clohessy --- .../quantization/dynamic_quantize_matmul.cc | 17 ++++- .../core/mlas/lib/kleidiai/sgemm_kleidiai.cpp | 73 +++++++++++-------- 2 files changed, 56 insertions(+), 34 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 85a2cbaea0e44..36a6f70cc69d9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -200,6 +200,19 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + // Kleidi dynamic path requires strictly positive, finite scales. + // Disable if any invalid scale is detected. + if (can_use_dynamic_quant_mlas_) { + const auto bs = b_scale_tensor->DataAsSpan(); + const bool has_invalid = + std::any_of(bs.begin(), bs.end(), + [](float s) { return !std::isfinite(s) || s <= 0.0f; }); + + if (has_invalid) { + can_use_dynamic_quant_mlas_ = false; + } + } + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. // We check that here too before attempting to use them. if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { @@ -379,7 +392,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { if (y->Shape().Size() == 0) return Status::OK(); - auto a_data = static_cast(ctx->Input(IN_A)->DataRaw()); + const float* a_data = ctx->Input(IN_A)->Data(); auto* y_data = y->MutableData(); // batch gemm @@ -393,7 +406,7 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { for (size_t gemm_idx = 0; gemm_idx < num_gemms; gemm_idx++) { auto& params = gemm_data_vec[gemm_idx]; - params.A = reinterpret_cast(a_data + helper.LeftOffsets()[gemm_idx]); + params.A = a_data + helper.LeftOffsets()[gemm_idx]; params.lda = gemm_shape.K; params.PackedB = packed_b_.get(); params.C = y_data + helper.OutputOffsets()[gemm_idx]; diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index caa445b71e2a5..c579ff1542eb9 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -153,28 +153,23 @@ ArmKleidiAI::MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - if(TransA == CblasTrans) - { - return false; + if (M == 0 || N == 0) { + return true; } - if (TransA == CblasNoTrans && K == 0) { - if (Data->beta != 1.0f) { + + if (Data->alpha == 0.0f || K == 0) { + if (Data->beta == 0.0f) { + for (size_t i = 0; i < M; ++i) { + std::fill_n(Data->C + i * Data->ldc, N, 0.0f); + } + } else if (Data->beta != 1.0f) { for (size_t i = 0; i < M; ++i) { for (size_t j = 0; j < N; ++j) { Data->C[i * Data->ldc + j] *= Data->beta; } } } - } - if (Data->beta == 0.0f){ - std::fill_n(Data->C, M * Data->ldc, 0.0f); - } - //Fallback in the case of unsupported cases - if (M == 0 || N == 0 || K == 0 || - TransA != CblasNoTrans || - (TransB != CblasNoTrans && !Data[0].BIsPacked)) - { - return false; + return true; } if (TransA == CblasNoTrans) { @@ -185,11 +180,9 @@ ArmKleidiAI::MlasGemmBatch( auto m_step = kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); auto n_step = kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(); - if (M < m_step || N < n_step) { - if (GetMlasPlatform().MlasGemmBatchOverride != ArmKleidiAI::MlasGemmBatch){ - //Fallback to MLAS - return false; - } + if (M < m_step && N < n_step && !Data->BIsPacked) { + // Fallback to MLAS + return false; } std::vector KaiPackedData; @@ -316,7 +309,7 @@ ArmKleidiAI::MlasGemmBatch( float* dst_tile = reinterpret_cast(CTile); // quick copy of data in cases where we are not scaling or accumulating anything - // with bounds checking on tile sizing to ensure the data fits in the memory block + // with bounds checking on tile sizing to ensure the data fits in the memory block bool can_memcpy = ( Data[BIdx].alpha == 1.0f && Data[BIdx].beta == 0.0f && @@ -328,21 +321,37 @@ ArmKleidiAI::MlasGemmBatch( if (can_memcpy) { std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float)); - }else { - // apply alpha scaling and beta to output files - for (size_t i = 0; i < TileSizeM; ++i) { - for (size_t j = 0; j < TileSizeN; ++j) { - const size_t idx = i * TileSizeN + j; - const size_t dst_idx = i * Data[BIdx].ldc + j; - - float ab = temp_tile[idx]; - float c_orig = dst_tile[dst_idx]; + return; + } - dst_tile[dst_idx] = Data[BIdx].alpha * ab + Data[BIdx].beta * c_orig; + float alpha = Data[BIdx].alpha; + float beta = Data[BIdx].beta; + size_t ldc = Data[BIdx].ldc; + + for (size_t i = 0; i < TileSizeM; ++i) { + for (size_t j = 0; j < TileSizeN; ++j) { + const size_t temp_idx = i * TileSizeN + j; + const size_t dst_idx = i * ldc + j; + + float ab = temp_tile[temp_idx]; + float c_orig = dst_tile[dst_idx]; + + if (alpha == 1.0f && beta == 0.0f) { + dst_tile[dst_idx] = ab; + } else if (alpha == 1.0f) { + dst_tile[dst_idx] = ab + beta * c_orig; + } else if (beta == 0.0f) { + dst_tile[dst_idx] = alpha * ab; + } else { + dst_tile[dst_idx] = alpha * ab + beta * c_orig; } } } + return; }); + return true; + } + else { + return false; } - return true; } From abe485ee02f6432cbc608c3b0f765e86df36e467 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane <111780983+apsonawane@users.noreply.github.com> Date: Thu, 28 Aug 2025 10:01:56 -0700 Subject: [PATCH 03/23] Fix MoE CPP tests (#25877) This change adds skip test for QMoE CPU tests when running on TensorRT or CUDA EP. In the QMoE kernel there was a memory overwrite bug in the accumulate part, updated that and this fixed the python tests back --- .../cpu/moe/moe_quantization_cpu.cc | 11 +++- onnxruntime/test/contrib_ops/moe_test.cc | 55 +++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 9b35a40f64f2a..5c6c3b919b572 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -331,7 +331,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const int64_t token_idx = route_idx / k_; const float weight = route_scale[route_idx]; - float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + token_idx * hidden_size; + const size_t buffer_offset = static_cast(token_idx) * static_cast(hidden_size); + if (buffer_offset + static_cast(hidden_size) > output_buffer_size) { + // Skip this token to prevent buffer overflow + continue; + } + + float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; for (int64_t j = 0; j < hidden_size; ++j) { dest[j] += weight * (src[j] + (B2_bias ? bias2_float[j] : 0.0f)); @@ -344,8 +350,9 @@ Status QMoECPU::Compute(OpKernelContext* context) const { auto accumulate = [&](float* buffer) { memset(buffer, 0, output_buffer_size * sizeof(float)); for (int i = 0; i < num_expert_threads; ++i) { + const size_t thread_offset = static_cast(i) * output_buffer_size; for (size_t j = 0; j < output_buffer_size; ++j) { - buffer[j] += thread_local_outputs[static_cast(i) * output_buffer_size + j]; + buffer[j] += thread_local_outputs[thread_offset + j]; } } }; diff --git a/onnxruntime/test/contrib_ops/moe_test.cc b/onnxruntime/test/contrib_ops/moe_test.cc index ed7ca998e0b86..0690b8894eb7a 100644 --- a/onnxruntime/test/contrib_ops/moe_test.cc +++ b/onnxruntime/test/contrib_ops/moe_test.cc @@ -144,6 +144,12 @@ static void RunQMoETest(const std::vector& input, const std::vector("k", static_cast(top_k)); cpu_tester.AddAttribute("activation_type", activation_type); @@ -1323,6 +1329,13 @@ TEST(MoETest, QMoETest_Mixtral_Int4) { // CPU-specific QMoE tests TEST(MoETest, QMoETest_CPU_Int4_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + int num_rows = 2; int num_experts = 2; int hidden_size = 32; @@ -1387,9 +1400,19 @@ TEST(MoETest, QMoETest_CPU_Int4_MLAS) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_Int8_MLAS) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test CPU implementation with 8-bit quantization - CPU ONLY int num_rows = 1; int num_experts = 2; @@ -1446,9 +1469,19 @@ TEST(MoETest, QMoETest_CPU_Int8_MLAS) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_FC3_Error) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test that CPU throws error when FC3 gating is provided - CPU ONLY int num_rows = 1; int num_experts = 2; @@ -1506,9 +1539,19 @@ TEST(MoETest, QMoETest_CPU_FC3_Error) { // Expect this to fail with FC3 not implemented error cpu_tester.Run(OpTester::ExpectResult::kExpectFailure, "FC3 gating is not yet implemented", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } + // Test CPU implementation with 4-bit quantization and SwiGLU activation int num_rows = 2; int num_experts = 2; @@ -1573,9 +1616,18 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int4) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { +#ifdef USE_MLAS + // Skip this test if we're not testing CPU execution provider + auto cpu_ep = DefaultCpuExecutionProvider(); + if (!cpu_ep) { + GTEST_SKIP() << "CPU execution provider not available"; + } // Test CPU implementation with 8-bit quantization and SwiGLU activation int num_rows = 1; int num_experts = 2; @@ -1633,6 +1685,9 @@ TEST(MoETest, QMoETest_CPU_SwiGLU_Int8) { std::vector> cpu_execution_providers; cpu_execution_providers.push_back(DefaultCpuExecutionProvider()); cpu_tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &cpu_execution_providers); +#else + GTEST_SKIP() << "Skipping CPU QMoE test"; +#endif } #endif From 179f371f8ffaafa64aade11e04a8471653859dd2 Mon Sep 17 00:00:00 2001 From: Christopher Warrington Date: Thu, 28 Aug 2025 13:35:48 -0400 Subject: [PATCH 04/23] [c++] Eliminate dynamic initialization of static Ort::Global::api_ (#25741) ### Description Delay the call to `OrtGetApiBase()` until the first call to `Ort::GetApi()` so that `OrtGetApiBase()` is typically called after dynamic library loading. ### Motivation and Context When ORT_API_MANUAL_INIT is not defined (which is the default), the static `Ort::Global::api_` has a dynamic initializer that calls `OrtGetApiBase()->GetApi(ORT_API_VERSION)` This dynamic initialization can cause problems when it interacts with other global/static initialization. On Windows in particular, it can also cause deadlocks when used in a dynamic library if OrtGetApiBase()->GetApi() attempts to load any other libraries. * Replace the templated `Global::api_` with an inline static initialized to nullptr. * `Ort::GetApi()` now calls `detail::Global::GetApi()` which calls `detail::Global::DefaultInit()` if initialization is needed. * When `ORT_API_MANUAL_INIT` is defined, `DefaultInit()` returns nullptr, which will eventually cause the program to crash. The callers have violated the initialization contract by not calling one of the `Ort::InitApi` overloads. * When `ORT_API_MANUAL_INIT` is not defined, `DefaultInit()` uses a function-level static to compute the result of `OrtGetApiBase()->GetApi(ORT_API_VERSION)` once and return it. * `Ort::Global` has been replaced with a non-templated type and moved inside a `detail` namespace. Since the `Global` object was documented as being used internally, it is believed that these changes here are non-breaking, as they do not impact a public API. The public APIs, `Ort::InitApi()` and `Ort::InitApi(const OrtApi*)` remain unchanged. * Add `#pragma detect_mismatch` to surface issues with compilation units that disagree on how ORT_API_MANUAL_INIT is defined. (MSVC only.) --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../core/session/onnxruntime_cxx_api.h | 115 ++++++++++++++---- js/node/src/inference_session_wrap.cc | 2 +- .../shared_library/provider_ort_api_init.cc | 4 +- .../core/providers/vitisai/imp/global_api.cc | 6 +- onnxruntime/test/autoep/library/ep_arena.h | 3 + .../custom_op_library/custom_op_library.cc | 2 +- 6 files changed, 99 insertions(+), 33 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index c39e27088e8bc..b460d753dd542 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -79,22 +79,19 @@ struct Exception : std::exception { throw Ort::Exception(string, code) #endif -// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, -// it's in a template so that we can define a global variable in a header and make -// it transparent to the users of the API. -template -struct Global { - static const OrtApi* api_; -}; - -// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. -template #ifdef ORT_API_MANUAL_INIT -const OrtApi* Global::api_{}; -inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } - -// Used by custom operator libraries that are not linked to onnxruntime. Sets the global API object, which is -// required by C++ APIs. +// If the macro ORT_API_MANUAL_INIT is defined, no static initialization +// will be performed. Instead, users must call InitApi() before using the +// ORT C++ APIs.. +// +// InitApi() sets the global API object using the default initialization +// logic. Users call this to initialize the ORT C++ APIs at a time that +// makes sense in their program. +inline void InitApi() noexcept; + +// InitApi(const OrtApi*) is used by custom operator libraries that are not +// linked to onnxruntime. It sets the global API object, which is required +// by the ORT C++ APIs. // // Example mycustomop.cc: // @@ -107,22 +104,88 @@ inline void InitApi() noexcept { Global::api_ = OrtGetApiBase()->GetApi(OR // // ... // } // -inline void InitApi(const OrtApi* api) noexcept { Global::api_ = api; } -#else -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// "Global initializer calls a non-constexpr function." Therefore you can't use ORT APIs in the other global initializers. -// Please define ORT_API_MANUAL_INIT if it conerns you. -#pragma warning(disable : 26426) +inline void InitApi(const OrtApi* api) noexcept; #endif -const OrtApi* Global::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) + +namespace detail { +// This is used internally by the C++ API. This class holds the global +// variable that points to the OrtApi. +struct Global { + static const OrtApi* Api(const OrtApi* newValue = nullptr) noexcept { + // This block-level static will be initialized once when this function is + // first executed, delaying the call to DefaultInit() until it is first needed. + // + // When ORT_API_MANUAL_INIT is not defined, DefaultInit() calls + // OrtGetApiBase()->GetApi(), which may result in a shared library being + // loaded. + // + // Using a block-level static instead of a class-level static helps + // avoid issues with static initialization order and dynamic libraries + // loading other dynamic libraries. + // + // This makes it safe to include the C++ API headers in a shared library + // that is delay loaded or delay loads its dependencies. + // + // This DOES NOT make it safe to _use_ arbitrary ORT C++ APIs when + // initializing static members, however. + static const OrtApi* api = DefaultInit(); + + if (newValue) { + api = newValue; + } + + return api; + } + + private: + // Has different definitions based on ORT_API_MANUAL_INIT + static const OrtApi* DefaultInit() noexcept; + +#ifdef ORT_API_MANUAL_INIT + // Public APIs to set the OrtApi* to use. + friend void ::Ort::InitApi() noexcept; + friend void ::Ort::InitApi(const OrtApi*) noexcept; #endif +}; +} // namespace detail + +#ifdef ORT_API_MANUAL_INIT + +// See comments on declaration above for usage. +inline void InitApi(const OrtApi* api) noexcept { detail::Global::Api(api); } +inline void InitApi() noexcept { InitApi(OrtGetApiBase()->GetApi(ORT_API_VERSION)); } + +#ifdef _MSC_VER +// If you get a linker error about a mismatch here, you are trying to +// link two compilation units that have different definitions for +// ORT_API_MANUAL_INIT together. All compilation units must agree on the +// definition of ORT_API_MANUAL_INIT. +#pragma detect_mismatch("ORT_API_MANUAL_INIT", "enabled") +#endif + +inline const OrtApi* detail::Global::DefaultInit() noexcept { + // When ORT_API_MANUAL_INIT is defined, there's no default init that can + // be done. + return nullptr; +} + +#else // ORT_API_MANUAL_INIT + +#ifdef _MSC_VER +// If you get a linker error about a mismatch here, you are trying to link +// two compilation units that have different definitions for +// ORT_API_MANUAL_INIT together. All compilation units must agree on the +// definition of ORT_API_MANUAL_INIT. +#pragma detect_mismatch("ORT_API_MANUAL_INIT", "disabled") #endif +inline const OrtApi* detail::Global::DefaultInit() noexcept { + return OrtGetApiBase()->GetApi(ORT_API_VERSION); +} +#endif // ORT_API_MANUAL_INIT + /// This returns a reference to the ORT C API. -inline const OrtApi& GetApi() noexcept { return *Global::api_; } +inline const OrtApi& GetApi() noexcept { return *detail::Global::Api(); } /// /// This function returns the onnxruntime version string diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 84ed3457a488b..8db91f792cb06 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -15,7 +15,7 @@ Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { // create ONNX runtime env Ort::InitApi(); ORT_NAPI_THROW_ERROR_IF( - Ort::Global::api_ == nullptr, env, + &Ort::GetApi() == nullptr, env, "Failed to initialize ONNX Runtime API. It could happen when this nodejs binding was built with a higher version " "ONNX Runtime but now runs with a lower version ONNX Runtime DLL(or shared library)."); diff --git a/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc b/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc index 9fa2551e53c23..f8d88b07f6dd5 100644 --- a/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc +++ b/onnxruntime/core/providers/shared_library/provider_ort_api_init.cc @@ -24,7 +24,7 @@ std::once_flag init; } // namespace void InitProviderOrtApi() { - std::call_once(init, []() { Ort::Global::api_ = Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION); }); + std::call_once(init, []() { Ort::InitApi(Provider_GetHost()->OrtGetApiBase()->GetApi(ORT_API_VERSION)); }); } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 5fc0b8900730b..580fbfbdba0b0 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -229,7 +229,7 @@ int vitisai_ep_set_ep_dynamic_options( struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = - op_.CreateKernel(&op_, Ort::Global::api_, reinterpret_cast(&info)); + op_.CreateKernel(&op_, &Ort::GetApi(), reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -332,8 +332,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { InitProviderOrtApi(); set_version_info(the_global_api); the_global_api.host_ = Provider_GetHost(); - assert(Ort::Global::api_ != nullptr); - the_global_api.ort_api_ = Ort::Global::api_; + assert(&Ort::GetApi() != nullptr); + the_global_api.ort_api_ = &Ort::GetApi(); the_global_api.model_load = [](const std::string& filename) -> Model* { auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); auto& logger = logging::LoggingManager::DefaultLogger(); diff --git a/onnxruntime/test/autoep/library/ep_arena.h b/onnxruntime/test/autoep/library/ep_arena.h index 641f3ce3f7b17..caa2c61db835f 100644 --- a/onnxruntime/test/autoep/library/ep_arena.h +++ b/onnxruntime/test/autoep/library/ep_arena.h @@ -21,7 +21,10 @@ limitations under the License. #include #include +#define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + #include "ep_allocator.h" #include "example_plugin_ep_utils.h" diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 8ab58adbeeb74..bc22864304567 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -26,7 +26,7 @@ static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) { } OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { - Ort::Global::api_ = api->GetApi(ORT_API_VERSION); + Ort::InitApi(api->GetApi(ORT_API_VERSION)); OrtStatus* result = nullptr; ORT_TRY { From 3563f2e52775c13ee94abb09c69a4c5084d33a3c Mon Sep 17 00:00:00 2001 From: Ishwar Raut Date: Thu, 28 Aug 2025 23:06:41 +0530 Subject: [PATCH 05/23] python GPU IO Bindings for NVIDIA (#25776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description 1. A Small change to use the shared allocator in Python binding. 2. Remove the FP64 support from the EP. ### Motivation and Context The Python GPU IO binding is necessary for performance. The change will enable the shared allocator for GPU allocation. The FP64 was using the FP32 inference—aligned WRT TRT RTX support. --------- Co-authored-by: Gaurav Garg --- .../nv_tensorrt_rtx/nv_execution_provider.cc | 151 +----- .../nv_tensorrt_rtx/nv_execution_provider.h | 2 - .../python/onnxruntime_pybind_state.cc | 2 +- ...me_test_python_nv_tensorrt_rtx_ep_tests.py | 468 ++++++++++++++++++ 4 files changed, 482 insertions(+), 141 deletions(-) create mode 100644 onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index b7997ce86737a..93b673f2df5bd 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -20,7 +20,6 @@ #include "onnx_ctx_model_helper.h" #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_graph.h" -#include "core/providers/cuda/math/unary_elementwise_ops_impl.h" #include "core/session/allocator_adapters.h" #include "cuda_runtime_api.h" #include "core/common/parse_string.h" @@ -85,40 +84,6 @@ struct ShutdownProtobuf { namespace onnxruntime { -namespace cuda { -template <> -void Impl_Cast( - cudaStream_t stream, - const int64_t* input_data, int32_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const int32_t* input_data, int64_t* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const double* input_data, float* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} - -template <> -void Impl_Cast( - cudaStream_t stream, - const float* input_data, double* output_data, - size_t count) { - return g_host->cuda__Impl_Cast(static_cast(stream), input_data, output_data, count); -} -} // namespace cuda - void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr @@ -372,51 +337,19 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ - skip_input_binding_allowed = false; \ - if (input_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - data = scratch_buffers.back().get(); \ - cuda::Impl_Cast(stream, input_tensor_ptr, reinterpret_cast(data), elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - data = scratch_buffers.back().get(); \ - } \ - break; \ - } - #define CASE_GET_OUTPUT_TENSOR(DATA_TYPE, SrcT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ data_ptr = output_tensor_ptr; \ if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - buffers[output_name] = output_tensor_ptr; \ + buffer = output_tensor_ptr; \ } else { \ scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ + buffer = scratch_buffers.back().get(); \ } \ break; \ } -#define CASE_GET_CAST_OUTPUT_TENSOR(DATA_TYPE, SrcT, DstT) \ - case DATA_TYPE: { \ - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ - data_ptr = output_tensor_ptr; \ - skip_output_binding_allowed = false; \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, elem_cnt * sizeof(DstT))); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = static_cast(elem_cnt); \ - } else { \ - scratch_buffers.push_back(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, 1)); \ - buffers[output_name] = scratch_buffers.back().get(); \ - output_dim_sizes[i] = 1; \ - } \ - break; \ - } - #define CASE_COPY_TENSOR(DATA_TYPE, DstT) \ case DATA_TYPE: { \ auto output_tensor_ptr = output_tensor.GetTensorMutableData(); \ @@ -426,15 +359,6 @@ bool ApplyProfileShapesFromProviderOptions(std::vector(); \ - if (output_tensor_ptr != nullptr && elem_cnt > 0) { \ - cuda::Impl_Cast(stream, reinterpret_cast(allocator->getBuffer()), reinterpret_cast(output_tensor_ptr), elem_cnt); \ - } \ - break; \ - } - /* * Set Nv executio context input. * @@ -557,7 +481,6 @@ Status BindContextInput(Ort::KernelContext& ctx, CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_INPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP input onnx tensor data type: " + std::to_string(tensor_type) + " not supported."); @@ -582,8 +505,6 @@ Status BindContextInput(Ort::KernelContext& ctx, * param output_type - Data type of the output * param i - Output iteration index * param output_tensors - Output iteration index to output's ORT value - * param output_dim_sizes - Output iteration index to the multiplocation of its shape's dimensions - * param dds_output_set - DDS output set * param dds_output_allocator_map - DDS output to its allocator * param scratch_buffer - The allocation buffer created by TRT EP * param allocator - ORT allocator @@ -595,16 +516,11 @@ Status BindContextOutput(Ort::KernelContext& ctx, const char* output_name, size_t output_index, size_t output_type, - size_t i, - std::unordered_map& output_tensors, - std::unordered_map& output_dim_sizes, DDSOutputAllocatorMap& dds_output_allocator_map, std::vector>& scratch_buffers, OrtAllocator* alloc, - std::unordered_map& buffers, nvinfer1::Dims& dims, - void*& data_ptr, - bool& skip_output_binding_allowed) { + void*& data_ptr) { // Get output shape dims = trt_context->getTensorShape(output_name); int nb_dims = dims.nbDims; @@ -634,10 +550,11 @@ Status BindContextOutput(Ort::KernelContext& ctx, data_ptr = nullptr; // Set data_ptr to nullptr for DDS output binding. } } else { - output_tensors[i] = ctx.GetOutput(output_index, dims.d, nb_dims); - auto& output_tensor = output_tensors[i]; + auto output_tensor = ctx.GetOutput(output_index, dims.d, nb_dims); const auto elem_cnt = output_tensor.GetTensorTypeAndShapeInfo().GetElementCount(); + void* buffer = nullptr; + switch (output_type) { // below macros set data_ptr and skip_output_binding_allowed variables CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, float) @@ -648,13 +565,12 @@ Status BindContextOutput(Ort::KernelContext& ctx, CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_GET_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_GET_CAST_OUTPUT_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, double, float) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); } } - trt_context->setTensorAddress(output_name, buffers[output_name]); + trt_context->setTensorAddress(output_name, buffer); } return Status::OK(); @@ -711,7 +627,6 @@ Status BindKernelOutput(Ort::KernelContext& ctx, CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, uint8_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, int32_t) CASE_COPY_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, int64_t) - CASE_CAST_TENSOR(ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, float, double) default: { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP output tensor data type: " + std::to_string(output_type) + " not supported."); @@ -2837,7 +2752,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr } // Save TRT engine, other TRT objects and input/output info to map - parsers_.emplace(fused_node.Name(), std::move(trt_parser)); engines_.emplace(fused_node.Name(), std::move(trt_engine)); contexts_.emplace(fused_node.Name(), std::move(trt_context)); networks_.emplace(fused_node.Name(), std::move(trt_network)); @@ -2853,7 +2767,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) { std::unique_ptr p = std::make_unique(); *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(), - &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], + &engines_[context->node_name], &contexts_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, engine_cache_enable_, cache_path_, @@ -2891,7 +2805,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; - int num_outputs = static_cast(output_indexes.size()); std::unordered_set input_names; if (alloc_ == nullptr) { @@ -2966,16 +2879,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - if (require_io_binding) { - bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -2993,16 +2897,15 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); + + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } - - trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -3082,14 +2985,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } } } @@ -3213,7 +3108,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name]; auto trt_engine = trt_state->engine->get(); auto trt_context = trt_state->context->get(); - int num_outputs = static_cast(output_indexes.size()); std::unordered_map> shape_tensor_values; // This map holds "shape tensor -> shape values" for the shape tensor input across this inference run std::unordered_map> shape_tensor_values_int64; // same as above but for int64 shape tensor input @@ -3283,16 +3177,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra /* * Set output shapes and bind output buffers */ - std::unordered_map buffers; - buffers.reserve(num_outputs); - using OutputOrtValue = Ort::UnownedValue; - std::unordered_map output_tensors; - output_tensors.reserve(num_outputs); - std::unordered_map output_dim_sizes; - output_dim_sizes.reserve(num_outputs); - if (require_io_binding) { - bool skip_output_binding_allowed = true; for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) { char const* output_name = output_binding_names[i]; @@ -3311,16 +3196,14 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra nvinfer1::Dims dims; void* data_ptr = nullptr; - Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes, - dds_output_allocator_map, scratch_buffers, alloc, buffers, dims, data_ptr, skip_output_binding_allowed); + Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, + dds_output_allocator_map, scratch_buffers, alloc, dims, data_ptr); if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } trt_state->output_tensors[output_index] = TensorParams{data_ptr, dims}; } - - trt_state->skip_io_binding_allowed = trt_state->skip_io_binding_allowed | skip_output_binding_allowed; } // Set execution context memory @@ -3401,14 +3284,6 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra if (status != Status::OK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage()); } - } else { - auto& output_tensor = output_tensors[i]; - if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) { - auto output_tensor_ptr = output_tensor.GetTensorMutableData(); - if (output_tensor_ptr != nullptr) { - cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]); - } - } } } diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 22b8314649757..9e5fd03756f02 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -195,7 +195,6 @@ struct TensorrtFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; nvinfer1::IBuilder* builder; - tensorrt_ptr::unique_pointer* parser = nullptr; std::unique_ptr* engine = nullptr; std::unique_ptr* context = nullptr; std::unique_ptr* network = nullptr; @@ -386,7 +385,6 @@ class NvExecutionProvider : public IExecutionProvider { // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client. // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. - std::unordered_map> parsers_; std::unordered_map> engines_; std::unordered_map> contexts_; std::unordered_map> builders_; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 24554560b4dde..9679da7cea2ff 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1782,7 +1782,7 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra type = OrtDevice::GPU; vendor = OrtDevice::VendorIds::MICROSOFT; } else if (type == OrtDevice::GPU) { -#if USE_CUDA +#if USE_CUDA || USE_NV || USE_NV_PROVIDER_INTERFACE || USE_CUDA_PROVIDER_INTERFACE vendor = OrtDevice::VendorIds::NVIDIA; #elif USE_ROCM || USE_MIGRAPHX vendor = OrtDevice::VendorIds::AMD; diff --git a/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py new file mode 100644 index 0000000000000..d5c80a4a1f4ba --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_nv_tensorrt_rtx_ep_tests.py @@ -0,0 +1,468 @@ +# Copyright (c) NVIDIA Corporation. All rights reserved. +# Licensed under the MIT License. +from __future__ import annotations + +import sys +import unittest +from collections.abc import Sequence + +import numpy as np +import torch +from autoep_helper import AutoEpTestCase +from helper import get_name +from numpy.testing import assert_almost_equal +from onnx import TensorProto, helper +from onnx.defs import onnx_opset_version + +import onnxruntime as onnxrt +from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice +from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue +from onnxruntime.capi._pybind_state import OrtValueVector, SessionIOBinding + + +class TestNvTensorRTRTXAutoEP(AutoEpTestCase): + """ + Test suite for the NvTensorRTRTX Execution Provider. + + This class contains tests for registering the NvTensorRTRTX EP, + selecting it using different policies, and running inference with various + I/O binding configurations. + """ + + ep_lib_path = "onnxruntime_providers_nv_tensorrt_rtx.dll" + ep_name = "NvTensorRTRTXExecutionProvider" + + def setUp(self): + if sys.platform != "win32": + self.skipTest("Skipping test because device discovery is only supported on Windows") + self.register_execution_provider_library(self.ep_name, self.ep_lib_path) + + def tearDown(self): + self.unregister_execution_provider_library(self.ep_name) + + def _create_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), device, 0 + ) + + def _create_ortvalue_alternate_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), + device, + 0, + ) + + def _create_uninitialized_ortvalue_input_on_gpu(self, device): + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, device, 0) + + def _create_numpy_input(self): + return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + + def _create_expected_output(self): + return np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + + def _create_expected_output_alternate(self): + return np.array([[2.0, 8.0], [18.0, 32.0], [50.0, 72.0]], dtype=np.float32) + + def torch_to_onnx_type(self, torch_dtype): + if torch_dtype == torch.float32: + return TensorProto.FLOAT + elif torch_dtype == torch.float16: + return TensorProto.FLOAT16 + elif torch_dtype == torch.bfloat16: + return TensorProto.BFLOAT16 + elif torch_dtype == torch.int8: + return TensorProto.int8 + elif torch_dtype == torch.int32: + return TensorProto.INT32 + elif torch_dtype == torch.int64: + return TensorProto.INT64 + else: + raise TypeError(f"Unsupported dtype: {torch_dtype}") + + def test_nv_tensorrt_rtx_ep_register_and_inference(self): + """ + Test registration of NvTensorRTRTX EP, adding its OrtDevice to the SessionOptions, and running inference. + """ + ep_devices = onnxrt.get_ep_devices() + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + self.assertEqual(nv_tensorrt_rtx_ep_device.ep_vendor, "NVIDIA") + + hw_device = nv_tensorrt_rtx_ep_device.device + self.assertEqual(hw_device.type, onnxrt.OrtHardwareDeviceType.GPU) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx")) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_prefer_gpu_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the PREFER_GPU policy and running inference. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_nv_tensorrt_rtx_ep_selection_delegate_and_inference(self): + """ + Test selecting NvTensorRTRTX EP via the custom EP selection delegate function and then run inference. + """ + + # User's custom EP selection function. + def my_delegate( + ep_devices: Sequence[onnxrt.OrtEpDevice], + model_metadata: dict[str, str], + runtime_metadata: dict[str, str], + max_selections: int, + ) -> Sequence[onnxrt.OrtEpDevice]: + self.assertGreater(len(model_metadata), 0) + self.assertGreaterEqual(len(ep_devices), 1) + self.assertGreaterEqual(max_selections, 2) + + nv_tensorrt_rtx_ep_device = next((d for d in ep_devices if d.ep_name == self.ep_name), None) + self.assertIsNotNone(nv_tensorrt_rtx_ep_device) + + # Select the NvTensorRTRTX device + return [nv_tensorrt_rtx_ep_device] + + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy_delegate(my_delegate) + self.assertTrue(sess_options.has_providers()) + + # Run sample model and check output + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + input_name = sess.get_inputs()[0].name + res = sess.run([], {input_name: x}) + output_expected = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) + np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) + + def test_bind_input_only(self): + """ + Test I/O binding with input data only. + """ + # Set a policy to prefer GPU. NvTensorRTRTX should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Bind output to CPU + io_binding.bind_output("Y") + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output = io_binding.copy_outputs_to_cpu()[0] + + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output)) + + def test_bind_input_and_bind_output_with_ortvalues(self): + """ + Test I/O binding with OrtValues for both input and output. + """ + # Set a policy to prefer GPU. NvTensorRTRTX EP should be selected. + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind ortvalue as input + input_ortvalue = self._create_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue) + + # Bind ortvalue as output + output_ortvalue = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_ortvalue_output("Y", output_ortvalue) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output(), output_ortvalue.numpy())) + + # Bind another ortvalue as input + input_ortvalue_2 = self._create_ortvalue_alternate_input_on_gpu("cuda") + io_binding.bind_ortvalue_input("X", input_ortvalue_2) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Inspect contents of output_ortvalue and make sure that it has the right contents + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), output_ortvalue.numpy())) + + def test_bind_input_and_non_preallocated_output(self): + """ + Test I/O binding with non-preallocated output. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + input = self._create_ortvalue_input_on_gpu("cuda") + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + io_binding.bind_output("Y", "cuda") + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + # We should be able to repeat the above process as many times as we want - try once more + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output(), ort_outputs[0].numpy())) + + input = self._create_ortvalue_alternate_input_on_gpu("cuda") + + # Change the bound input and validate the results in the same bound OrtValue + # Bind alternate input to the GPU + io_binding.bind_input( + "X", + "cuda", + 0, + np.float32, + [3, 2], + input.data_ptr(), + ) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # This call returns an OrtValue which has data allocated by ORT on the GPU + ort_outputs = io_binding.get_outputs() + self.assertEqual(len(ort_outputs), 1) + self.assertEqual(ort_outputs[0].device_name(), "cuda") + # Validate results (by copying results to CPU by creating a Numpy object) + self.assertTrue(np.array_equal(self._create_expected_output_alternate(), ort_outputs[0].numpy())) + + def test_bind_input_and_preallocated_output(self): + """ + Test I/O binding with preallocated output. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + + input = self._create_ortvalue_input_on_gpu("cuda") + + session = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=sess_options) + io_binding = session.io_binding() + + # Bind input to the GPU + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) + + # Bind output to the GPU + output = self._create_uninitialized_ortvalue_input_on_gpu("cuda") + io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) + + # Sync if different streams + io_binding.synchronize_inputs() + + # Invoke Run + session.run_with_iobinding(io_binding) + + # Sync if different streams + io_binding.synchronize_outputs() + + # Get outputs over to CPU (the outputs which were bound to the GPU will get copied over to the host + # here) + ort_output_vals = io_binding.copy_outputs_to_cpu()[0] + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals)) + + # Validate if ORT actually wrote to pre-allocated buffer by copying the allocated buffer + # to the host and validating its contents + ort_output_vals_in_cpu = output.numpy() + # Validate results + self.assertTrue(np.array_equal(self._create_expected_output(), ort_output_vals_in_cpu)) + + def test_bind_input_types(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + device = C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + + for dtype in [ + np.float32, + # np.float64, + np.int32, + # np.uint32, + np.int64, + # np.uint64, + # np.int16, + # np.uint16, + # np.int8, + np.uint8, + np.float16, + np.bool_, + ]: + with self.subTest(dtype=dtype, inner_device=str(device)): + x = np.arange(8).reshape((-1, 2)).astype(dtype) + proto_dtype = helper.np_dtype_to_tensor_dtype(x.dtype) + + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) # noqa: N806 + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # noqa: N806 + + # inference + node_add = helper.make_node("Identity", ["X"], ["Y"]) + + # graph + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=7, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + bind = SessionIOBinding(sess._sess) + ort_value = C_OrtValue.ortvalue_from_numpy(x, device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvaluevector = bind.get_outputs() + self.assertIsInstance(ortvaluevector, OrtValueVector) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", device) + sess._sess.run_with_iobinding(bind, None) + ortvalue = bind.get_outputs()[0] + y = ortvalue.numpy() + assert_almost_equal(x, y) + + def test_bind_onnx_types_from_torch(self): + """ + Test I/O binding with various input data types. + """ + sess_options = onnxrt.SessionOptions() + sess_options.set_provider_selection_policy(onnxrt.OrtExecutionProviderDevicePolicy.PREFER_GPU) + self.assertTrue(sess_options.has_providers()) + opset = onnx_opset_version() + + for dtype in [ + torch.float32, + torch.float16, + torch.bfloat16, + torch.int32, + torch.int64, + ]: + with self.subTest(dtype=dtype): + proto_dtype = self.torch_to_onnx_type(dtype) + + x_ = helper.make_tensor_value_info("X", proto_dtype, [None]) + y_ = helper.make_tensor_value_info("Y", proto_dtype, [None]) + node_add = helper.make_node("Identity", ["X"], ["Y"]) + graph_def = helper.make_graph([node_add], "lr", [x_], [y_], []) + model_def = helper.make_model( + graph_def, + producer_name="dummy", + ir_version=10, + producer_version="0", + opset_imports=[helper.make_operatorsetid("", opset)], + ) + sess = onnxrt.InferenceSession(model_def.SerializeToString(), sess_options=sess_options) + + dev = "cuda" if torch.cuda.is_available() else "cpu" + device = ( + C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0) + if dev == "cuda" + else C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0) + ) + + x = torch.arange(8, dtype=dtype, device=dev) + y = torch.empty(8, dtype=dtype, device=dev) + + bind = SessionIOBinding(sess._sess) + bind.bind_input("X", device, proto_dtype, x.shape, x.data_ptr()) + bind.bind_output("Y", device, proto_dtype, y.shape, y.data_ptr()) + sess._sess.run_with_iobinding(bind, None) + self.assertTrue(torch.equal(x, y)) + + +if __name__ == "__main__": + unittest.main(verbosity=1) From 47f355aa5c67ff78bc3c99b950b362028f8bc3ed Mon Sep 17 00:00:00 2001 From: Xinpeng Dou <15529241576@163.com> Date: Fri, 29 Aug 2025 03:07:15 +0800 Subject: [PATCH 06/23] [CANN] Add a `enable_cann_subgraph` feature parameter (#25867) ### Description Add a `enable_cann_subgraph` feature parameter. this parameter controls whether graph splitting is performed and can help quickly identify issues in certain scenarios. --- .../core/providers/cann/cann_provider_options.h | 2 ++ .../core/providers/cann/cann_execution_provider.cc | 9 ++++----- .../core/providers/cann/cann_execution_provider_info.cc | 4 ++++ .../core/providers/cann/cann_execution_provider_info.h | 1 + onnxruntime/core/providers/cann/cann_provider_factory.cc | 2 ++ onnxruntime/core/session/provider_bridge_ort.cc | 1 + 6 files changed, 14 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/providers/cann/cann_provider_options.h b/include/onnxruntime/core/providers/cann/cann_provider_options.h index 51b423e68110a..4b33ee77a892e 100644 --- a/include/onnxruntime/core/providers/cann/cann_provider_options.h +++ b/include/onnxruntime/core/providers/cann/cann_provider_options.h @@ -15,6 +15,8 @@ struct OrtCANNProviderOptions { onnxruntime::ArenaExtendStrategy arena_extend_strategy; // Strategy used to grow the memory arena int enable_cann_graph; // Flag indicating if prioritizing the use of // CANN's graph-running capabilities + int enable_cann_subgraph; // Flag indicating whether to generate subgraph + // automaticly int dump_graphs; // Flag indicating if dumping graphs int dump_om_model; // Flag indicating if dumping om model std::string precision_mode; // Operator Precision Mode diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 4bcf71335d15e..06c3628eb301d 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1266,17 +1266,16 @@ CANNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewe // the single operator operation mode of CANN if (info_.enable_cann_graph) { std::vector&& unsupported_nodes = SupportONNXModel(graph_viewer); - - if (unsupported_nodes.empty()) { - auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); - result.push_back(ComputeCapability::Create(std::move(sub_graph))); - } else { + if (info_.enable_cann_subgraph && !unsupported_nodes.empty()) { auto partitions = GetSubGraphPartition(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); for (const auto& partition : partitions) { auto sub_graph = GetSubGraph(partition, graph_viewer); result.push_back(ComputeCapability::Create(std::move(sub_graph))); } + } else { + auto sub_graph = GetSubGraph(graph_viewer.GetNodesInTopologicalOrder(), graph_viewer); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); } } else { InlinedVector candidates; diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc index d1ba7544bc09e..d6cf9fad70ae5 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.cc @@ -20,6 +20,7 @@ constexpr const char* kDeviceId = "device_id"; constexpr const char* kMemLimit = "npu_mem_limit"; constexpr const char* kArenaExtendStrategy = "arena_extend_strategy"; constexpr const char* kEnableCannGraph = "enable_cann_graph"; +constexpr const char* kEnableCannSubGraph = "enable_cann_subgraph"; constexpr const char* kDumpGraphs = "dump_graphs"; constexpr const char* kDumpOmModel = "dump_om_model"; constexpr const char* kPrecisionMode = "precision_mode"; @@ -58,6 +59,7 @@ CANNExecutionProviderInfo CANNExecutionProviderInfo::FromProviderOptions(const P cann::provider_option_names::kArenaExtendStrategy, arena_extend_strategy_mapping, info.arena_extend_strategy) .AddAssignmentToReference(cann::provider_option_names::kEnableCannGraph, info.enable_cann_graph) + .AddAssignmentToReference(cann::provider_option_names::kEnableCannSubGraph, info.enable_cann_subgraph) .AddAssignmentToReference(cann::provider_option_names::kDumpGraphs, info.dump_graphs) .AddAssignmentToReference(cann::provider_option_names::kDumpOmModel, info.dump_om_model) .AddAssignmentToReference(cann::provider_option_names::kPrecisionMode, info.precision_mode) @@ -74,6 +76,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const CANNExecution {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, info.arena_extend_strategy)}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, @@ -89,6 +92,7 @@ ProviderOptions CANNExecutionProviderInfo::ToProviderOptions(const OrtCANNProvid {cann::provider_option_names::kArenaExtendStrategy, EnumToName(arena_extend_strategy_mapping, ArenaExtendStrategy(info.arena_extend_strategy))}, {cann::provider_option_names::kEnableCannGraph, MakeStringWithClassicLocale(info.enable_cann_graph)}, + {cann::provider_option_names::kEnableCannSubGraph, MakeStringWithClassicLocale(info.enable_cann_subgraph)}, {cann::provider_option_names::kDumpGraphs, MakeStringWithClassicLocale(info.dump_graphs)}, {cann::provider_option_names::kDumpOmModel, MakeStringWithClassicLocale(info.dump_om_model)}, {cann::provider_option_names::kPrecisionMode, MakeStringWithClassicLocale(info.precision_mode)}, diff --git a/onnxruntime/core/providers/cann/cann_execution_provider_info.h b/onnxruntime/core/providers/cann/cann_execution_provider_info.h index 7ac43e9a8ed6f..9c1f9eb03b67e 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider_info.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider_info.h @@ -18,6 +18,7 @@ struct CANNExecutionProviderInfo { size_t npu_mem_limit{std::numeric_limits::max()}; ArenaExtendStrategy arena_extend_strategy{ArenaExtendStrategy::kNextPowerOfTwo}; bool enable_cann_graph{true}; + bool enable_cann_subgraph{false}; bool dump_graphs{false}; bool dump_om_model{true}; std::string precision_mode; diff --git a/onnxruntime/core/providers/cann/cann_provider_factory.cc b/onnxruntime/core/providers/cann/cann_provider_factory.cc index 4a130b9b0ca20..d3dc86f588f1d 100644 --- a/onnxruntime/core/providers/cann/cann_provider_factory.cc +++ b/onnxruntime/core/providers/cann/cann_provider_factory.cc @@ -76,6 +76,7 @@ struct CANN_Provider : Provider { info.npu_mem_limit = params->npu_mem_limit; info.arena_extend_strategy = params->arena_extend_strategy; info.enable_cann_graph = params->enable_cann_graph != 0; + info.enable_cann_subgraph = params->enable_cann_subgraph != 0; info.dump_graphs = params->dump_graphs != 0; info.dump_om_model = params->dump_om_model != 0; info.precision_mode = params->precision_mode; @@ -94,6 +95,7 @@ struct CANN_Provider : Provider { cann_options.npu_mem_limit = internal_options.npu_mem_limit; cann_options.arena_extend_strategy = internal_options.arena_extend_strategy; cann_options.enable_cann_graph = internal_options.enable_cann_graph; + cann_options.enable_cann_subgraph = internal_options.enable_cann_subgraph; cann_options.dump_graphs = internal_options.dump_graphs; cann_options.dump_om_model = internal_options.dump_om_model; cann_options.precision_mode = internal_options.precision_mode; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 41cf8be1d1412..f82cbcf63ca62 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2902,6 +2902,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateCANNProviderOptions, _Outptr_ OrtCANNProvider options->npu_mem_limit = SIZE_MAX; options->arena_extend_strategy = static_cast(0); options->enable_cann_graph = 1; + options->enable_cann_subgraph = 0; options->dump_graphs = 0; options->dump_om_model = 1; options->default_memory_arena_cfg = nullptr; From 1eb18f189a7e1d7776a2d7e59e8e083dd79a040d Mon Sep 17 00:00:00 2001 From: Chi Lo <54722500+chilo-ms@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:21:20 -0700 Subject: [PATCH 07/23] [EP ABI] Add OpAttr_GetTensorAttributeAsOrtValue and replace the existing Node_GetTensorAttributeAsOrtValue (#25886) ### Description Replace `Node_GetTensorAttributeAsOrtValue` with `OpAttr_GetTensorAttributeAsOrtValue`. Change the API signature to make it one of the `OpAttr` interfaces instead of the `OrtNode` interface. The original API was added [here](https://github.com/microsoft/onnxruntime/pull/25566). --- .../core/providers/utils/ort_graph_to_proto.h | 8 ++-- .../core/session/onnxruntime_c_api.h | 3 +- onnxruntime/core/graph/abi_graph_types.h | 10 ----- onnxruntime/core/graph/ep_api_types.cc | 26 ------------- onnxruntime/core/graph/ep_api_types.h | 3 -- .../core/graph/model_editor_api_types.h | 5 --- onnxruntime/core/session/onnxruntime_c_api.cc | 38 +++++++++++++++++-- onnxruntime/core/session/ort_apis.h | 2 +- 8 files changed, 41 insertions(+), 54 deletions(-) diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 21aa797ce16eb..28ce4439fdc7e 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -766,7 +766,7 @@ static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& or // TensorProto as an attribute value doesn't require a name. OrtValue* ort_value = nullptr; - ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.OpAttr_GetTensorAttributeAsOrtValue(&ort_attr, &ort_value)); Ort::Value tensor(ort_value); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9ae6174817b7c..f137d88e5fb8a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6079,7 +6079,6 @@ struct OrtApi { /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. * - * \param[in] node The OrtNode instance. * \param[in] attribute The OrtOpAttr instance. * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. Must be freed with OrtApi::ReleaseValue. @@ -6088,7 +6087,7 @@ struct OrtApi { * * \since Version 1.23. */ - ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + ORT_API2_STATUS(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index b99c22edb36c8..2ef7c4a9091f3 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -252,16 +252,6 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; - /// - /// Gets the node's 'TENSOR' attribute as an OrtValue. - /// - /// Node's 'TENSOR' attribute. - /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, - /// only if the attribute is of type 'TENSOR' - /// A status indicating success or an error. - virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, - OrtValue*& value) const = 0; - /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 759a2998ace3a..0d9b93631ee8a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -249,32 +249,6 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } -Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { - const auto* attr_proto = reinterpret_cast(attribute); - - if (attr_proto->type() != onnx::AttributeProto::TENSOR) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); - } - - const auto& graph_viewer = ep_graph_->GetGraphViewer(); - const auto& tensor_proto = attr_proto->t(); - - // Check that TensorProto is valid. - ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); - ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); - ORT_ENFORCE(!utils::HasExternalData(tensor_proto), - "Tensor proto with external data for value attribute is not supported."); - - // Initialize OrtValue for tensor attribute. - auto tensor_attribute_value = std::make_unique(); - AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); - ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, - tensor_attribute_allocator, *tensor_attribute_value)); - - result = tensor_attribute_value.release(); - return Status::OK(); -} - Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 7f22e265129f7..e003f02a79a2d 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,9 +183,6 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, - OrtValue*& attr_tensor) const override; - // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index e7ffcbc7e4c90..2c0f6d6174303 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -138,11 +138,6 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } - Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, - "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); - } - Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index ad0a1ad137f06..f3e2a8ce7ba7b 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3036,7 +3036,7 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { +ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { API_IMPL_BEGIN if (attr_tensor == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); @@ -3045,7 +3045,39 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNo return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } - ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + if (!utils::HasDataType(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto doesn't have data type."); + } + + if (!ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type())) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Tensor proto has invalid data type."); + } + + if (utils::HasExternalData(tensor_proto)) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "Tensor proto with external data for value attribute is not supported."); + } + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + // The tensor in the 'Tensor' attribute's TensorProto is stored inline, not in an external file. + // Therefore, the 'model_path' passed to TensorProtoToOrtValue() may be an empty path. + std::filesystem::path model_path; + ORT_API_RETURN_IF_STATUS_NOT_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path, tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + *attr_tensor = tensor_attribute_value.release(); + return nullptr; API_IMPL_END } @@ -4134,7 +4166,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, - &OrtApis::Node_GetTensorAttributeAsOrtValue, + &OrtApis::OpAttr_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index e62149d04a16c..6dc4cf9d195cc 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -687,7 +687,7 @@ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_result_maybenull_ const OrtOpAttr** attribute); -ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, +ORT_API_STATUS_IMPL(OpAttr_GetTensorAttributeAsOrtValue, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); From 820554e29a4479f259464a5878ce4c9d225b2358 Mon Sep 17 00:00:00 2001 From: adrastogi Date: Thu, 28 Aug 2025 17:42:06 -0700 Subject: [PATCH 08/23] Language bindings for model compatibility API (#25878) ### Description This change builds on top of #25841 , and adds the scaffolding necessary to call into this API from C++ / C# / Python. ### Motivation and Context #25454 talks more about the broader notion of precompiled model compatibility. This change is directed at app developers whose apps may want to determine if a particular precompiled model (e.g. on a server somewhere) is compatible with the device where the application is running. There is functionality in `OrtEpFactory` for making this determination, which was exposed as a C API in #25841, and this change makes the API more broadly available in other languages. ### Testing and Validation Introduced new unit test cases across each language, and verified that the API was being called and returned the correct result for the default CPU EP. --------- Co-authored-by: Aditya Rastogi --- .../NativeMethods.shared.cs | 98 +++++++++++++++++++ .../Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs | 40 ++++++++ .../EpCompatibilityTests.cs | 49 ++++++++++ .../core/session/onnxruntime_cxx_api.h | 10 ++ .../core/session/onnxruntime_cxx_inline.h | 20 ++++ .../python/onnxruntime_pybind_state.cc | 17 ++++ .../test/framework/ep_compatibility_test.cc | 29 ++++++ ...nnxruntime_test_python_ep_compatibility.py | 46 +++++++++ 8 files changed, 309 insertions(+) create mode 100644 csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs create mode 100644 onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 8cca2b42e987a..3c92400715740 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -368,6 +368,88 @@ public struct OrtApi public IntPtr EpDevice_Device; public IntPtr GetEpApi; public IntPtr GetTensorSizeInBytes; + + public IntPtr AllocatorGetStats; + + public IntPtr CreateMemoryInfo_V2; + public IntPtr MemoryInfoGetDeviceMemType; + public IntPtr MemoryInfoGetVendorId; + + public IntPtr ValueInfo_GetValueProducer; + public IntPtr ValueInfo_GetValueNumConsumers; + public IntPtr ValueInfo_GetValueConsumers; + public IntPtr ValueInfo_GetInitializerValue; + public IntPtr ValueInfo_GetExternalInitializerInfo; + public IntPtr ValueInfo_IsRequiredGraphInput; + public IntPtr ValueInfo_IsOptionalGraphInput; + public IntPtr ValueInfo_IsGraphOutput; + public IntPtr ValueInfo_IsConstantInitializer; + public IntPtr ValueInfo_IsFromOuterScope; + public IntPtr Graph_GetName; + public IntPtr Graph_GetModelPath; + public IntPtr Graph_GetOnnxIRVersion; + public IntPtr Graph_GetNumOperatorSets; + public IntPtr Graph_GetOperatorSets; + public IntPtr Graph_GetNumInputs; + public IntPtr Graph_GetInputs; + public IntPtr Graph_GetNumOutputs; + public IntPtr Graph_GetOutputs; + public IntPtr Graph_GetNumInitializers; + public IntPtr Graph_GetInitializers; + public IntPtr Graph_GetNumNodes; + public IntPtr Graph_GetNodes; + public IntPtr Graph_GetParentNode; + public IntPtr Graph_GetGraphView; + public IntPtr Node_GetId; + public IntPtr Node_GetName; + public IntPtr Node_GetOperatorType; + public IntPtr Node_GetDomain; + public IntPtr Node_GetSinceVersion; + public IntPtr Node_GetNumInputs; + public IntPtr Node_GetInputs; + public IntPtr Node_GetNumOutputs; + public IntPtr Node_GetOutputs; + public IntPtr Node_GetNumImplicitInputs; + public IntPtr Node_GetImplicitInputs; + public IntPtr Node_GetNumAttributes; + public IntPtr Node_GetAttributes; + public IntPtr Node_GetAttributeByName; + public IntPtr Node_GetTensorAttributeAsOrtValue; + public IntPtr OpAttr_GetType; + public IntPtr OpAttr_GetName; + public IntPtr Node_GetNumSubgraphs; + public IntPtr Node_GetSubgraphs; + public IntPtr Node_GetGraph; + public IntPtr Node_GetEpName; + public IntPtr ReleaseExternalInitializerInfo; + public IntPtr ExternalInitializerInfo_GetFilePath; + public IntPtr ExternalInitializerInfo_GetFileOffset; + public IntPtr ExternalInitializerInfo_GetByteSize; + + public IntPtr GetRunConfigEntry; + + public IntPtr EpDevice_MemoryInfo; + + public IntPtr CreateSharedAllocator; + public IntPtr GetSharedAllocator; + public IntPtr ReleaseSharedAllocator; + + public IntPtr GetTensorData; + + public IntPtr GetSessionOptionsConfigEntries; + + public IntPtr SessionGetMemoryInfoForInputs; + public IntPtr SessionGetMemoryInfoForOutputs; + public IntPtr SessionGetEpDeviceForInputs; + + public IntPtr CreateSyncStreamForEpDevice; + public IntPtr SyncStream_GetHandle; + public IntPtr ReleaseSyncStream; + + public IntPtr CopyTensors; + + public IntPtr Graph_GetModelMetadata; + public IntPtr GetModelCompatibilityForEpDevices; } internal static class NativeMethods @@ -704,6 +786,10 @@ static NativeMethods() (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicyDelegate, typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); + + OrtGetModelCompatibilityForEpDevices = (DOrtGetModelCompatibilityForEpDevices)Marshal.GetDelegateForFunctionPointer( + api_.GetModelCompatibilityForEpDevices, + typeof(DOrtGetModelCompatibilityForEpDevices)); } internal class NativeLib @@ -2456,6 +2542,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtGetEpDevices OrtGetEpDevices; + /// + /// Validate compiled model compatibility for the provided EP devices. + /// + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtGetModelCompatibilityForEpDevices( + IntPtr[] /* const OrtEpDevice* const* */ ep_devices, + UIntPtr /* size_t */ num_ep_devices, + byte[] /* const char* */ compatibility_info, + out int /* OrtCompiledModelCompatibility */ out_status); + + public static DOrtGetModelCompatibilityForEpDevices OrtGetModelCompatibilityForEpDevices; + /// /// Add execution provider devices to the session options. /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 5c70808b82be1..052d5899b52c0 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -7,6 +7,21 @@ namespace Microsoft.ML.OnnxRuntime { + /// + /// Represents the compatibility status of a pre-compiled model with one or more execution provider devices. + /// + /// + /// This enum is used to determine whether a pre-compiled model can be used with specific execution providers + /// and devices, or if recompilation is needed. + /// + public enum OrtCompiledModelCompatibility + { + EP_NOT_APPLICABLE = 0, + EP_SUPPORTED_OPTIMAL = 1, + EP_SUPPORTED_PREFER_RECOMPILATION = 2, + EP_UNSUPPORTED = 3, + } + /// /// Delegate for logging function callback. /// Supply your function and register it with the environment to receive logging callbacks via @@ -361,6 +376,31 @@ public string[] GetAvailableProviders() } } + /// + /// Validate a compiled model's compatibility information for one or more EP devices. + /// + /// The list of EP devices to validate against. + /// The compatibility string from the precompiled model to validate. + /// OrtCompiledModelCompatibility enum value denoting the compatibility status + public OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + IReadOnlyList epDevices, string compatibilityInfo) + { + if (epDevices == null || epDevices.Count == 0) + throw new ArgumentException("epDevices must be non-empty", nameof(epDevices)); + + var devicePtrs = new IntPtr[epDevices.Count]; + for (int i = 0; i < epDevices.Count; ++i) + { + devicePtrs[i] = epDevices[i].Handle; + } + + var infoUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(compatibilityInfo); + NativeApiStatus.VerifySuccess( + NativeMethods.OrtGetModelCompatibilityForEpDevices( + devicePtrs, (UIntPtr)devicePtrs.Length, infoUtf8, out int status)); + return (OrtCompiledModelCompatibility)status; + } + /// /// Get/Set log level property of OrtEnv instance diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs new file mode 100644 index 0000000000000..103fe5bc10106 --- /dev/null +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/EpCompatibilityTests.cs @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// not supported on mobile platforms +#if !(ANDROID || IOS) + +namespace Microsoft.ML.OnnxRuntime.Tests; + +using System; +using System.Linq; +using Xunit; +using System.Collections.Generic; + +public class EpCompatibilityTests +{ + private readonly OrtEnv ortEnvInstance = OrtEnv.Instance(); + + private IReadOnlyList GetDevices() + { + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotNull(epDevices); + Assert.NotEmpty(epDevices); + return epDevices; + } + + [Fact] + public void GetEpCompatibility_InvalidArgs() + { + Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(null, "info")); + Assert.Throws(() => ortEnvInstance.GetModelCompatibilityForEpDevices(new List(), "info")); + } + + [Fact] + public void GetEpCompatibility_SingleDeviceCpuProvider() + { + var devices = GetDevices(); + var someInfo = "arbitrary-compat-string"; + + // Use CPU device + var cpu = devices.First(d => d.EpName == "CPUExecutionProvider"); + Assert.NotNull(cpu); + var selected = new List { cpu }; + var status = ortEnvInstance.GetModelCompatibilityForEpDevices(selected, someInfo); + + // CPU defaults to not applicable in this scenario + Assert.Equal(OrtCompiledModelCompatibility.EP_NOT_APPLICABLE, status); + } +} +#endif diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index b460d753dd542..13675ab447ab1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1076,6 +1076,16 @@ struct EpDevice : detail::EpDeviceImpl { ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); }; +/** \brief Validate a compiled model's compatibility for one or more EP devices. + * + * Throws on error. Returns the resulting compatibility status. + * /// \param ep_devices The EP devices to check compatibility against. + * /// \param compatibility_info The compatibility string from the precompiled model to validate. + */ +OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + const std::vector& ep_devices, + const char* compatibility_info); + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index d0089726812a3..05c86ae4e0c58 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -859,6 +859,26 @@ inline void CustomOpDomain::Add(const OrtCustomOp* op) { ThrowOnError(GetApi().CustomOpDomain_Add(p_, op)); } +inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( + const std::vector& ep_devices, + const char* compatibility_info) { + if (ep_devices.empty()) { + ORT_CXX_API_THROW("ep_devices is empty", ORT_INVALID_ARGUMENT); + } + + std::vector ptrs; + ptrs.reserve(ep_devices.size()); + for (const auto& d : ep_devices) ptrs.push_back(d); + + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + ThrowOnError(GetApi().GetModelCompatibilityForEpDevices( + reinterpret_cast(ptrs.data()), + ptrs.size(), + compatibility_info, + &status)); + return status; +} + inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, OrtAllocator* allocator) { OrtLoraAdapter* p; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 9679da7cea2ff..eb06a65ad5330 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1575,6 +1575,17 @@ void addGlobalMethods(py::module& m) { R"pbdoc(Get the list of available OrtEpDevice instances.)pbdoc", py::return_value_policy::reference); + m.def( + "get_model_compatibility_for_ep_devices", + [](const std::vector& ep_devices, + const std::string& compatibility_info) -> OrtCompiledModelCompatibility { + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + Ort::ThrowOnError(Ort::GetApi().GetModelCompatibilityForEpDevices( + ep_devices.data(), ep_devices.size(), compatibility_info.c_str(), &status)); + return status; + }, + R"pbdoc("Validate a compiled model's compatibility information for one or more EP devices.)pbdoc"); + #if defined(USE_OPENVINO) || defined(USE_OPENVINO_PROVIDER_INTERFACE) m.def( "get_available_openvino_device_ids", []() -> std::vector { @@ -1759,6 +1770,12 @@ void addObjectMethods(py::module& m, ExecutionProviderRegistrationFn ep_registra .value("PRIORITY_BASED", ExecutionOrder::PRIORITY_BASED) .value("MEMORY_EFFICIENT", ExecutionOrder::MEMORY_EFFICIENT); + py::enum_(m, "OrtCompiledModelCompatibility") + .value("EP_NOT_APPLICABLE", OrtCompiledModelCompatibility_EP_NOT_APPLICABLE) + .value("EP_SUPPORTED_OPTIMAL", OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL) + .value("EP_SUPPORTED_PREFER_RECOMPILATION", OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION) + .value("EP_UNSUPPORTED", OrtCompiledModelCompatibility_EP_UNSUPPORTED); + py::enum_(m, "OrtAllocatorType") .value("INVALID", OrtInvalidAllocator) .value("ORT_DEVICE_ALLOCATOR", OrtDeviceAllocator) diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index ee82d4683ab73..a8a83fbe5ceb6 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -15,6 +15,7 @@ #include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/utils.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" #include "core/session/abi_session_options_impl.h" #include "core/framework/error_code_helper.h" #include "dummy_provider.h" @@ -499,3 +500,31 @@ TEST(EpCompatibilityCapiTest, CpuEpReturnsNotApplicableIfNoValidation) { api->ReleaseEnv(env); } + +// ----------------------------- +// C++ API unit tests +// ----------------------------- + +TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) { + Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "EpCompatCxx"}; + auto devices = env.GetEpDevices(); + ASSERT_FALSE(devices.empty()); + + std::vector selected; + for (const auto& d : devices) { + if (std::string{d.EpName()} == "CPUExecutionProvider") { + selected.push_back(d); + break; + } + } + + ASSERT_FALSE(selected.empty()); + + // Pick a status that the CPU EP would never return to ensure the value is set correctly. + OrtCompiledModelCompatibility status = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + ASSERT_NO_FATAL_FAILURE({ + status = Ort::GetModelCompatibilityForEpDevices(selected, "arbitrary-compat-string"); + }); + + ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); +} \ No newline at end of file diff --git a/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py new file mode 100644 index 0000000000000..8e69fdf088103 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_ep_compatibility.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import os +import platform +import sys +import unittest + +from onnxruntime.capi.onnxruntime_pybind11_state import ( + OrtCompiledModelCompatibility, + get_ep_devices, + get_model_compatibility_for_ep_devices, +) + +# handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: # noqa: YTT204 + os.add_dll_directory(os.getcwd()) + + +class TestEpCompatibility(unittest.TestCase): + def test_invalid_args(self): + # empty devices + with self.assertRaises(RuntimeError): + get_model_compatibility_for_ep_devices([], "info") + # None compatibility info should raise TypeError before native call + with self.assertRaises(TypeError): + get_model_compatibility_for_ep_devices(get_ep_devices(), None) # type: ignore[arg-type] + + def test_basic_smoke(self): + devices = list(get_ep_devices()) + if not devices: + self.skipTest("No EP devices available in this build") + + # Always select CPUExecutionProvider; skip if not present. + cpu_devices = [d for d in devices if getattr(d, "ep_name", None) == "CPUExecutionProvider"] + if not cpu_devices: + self.skipTest("CPUExecutionProvider not available in this build") + selected = [cpu_devices[0]] + + # API requires all devices belong to the same EP; we pass only one. + status = get_model_compatibility_for_ep_devices(selected, "arbitrary-compat-string") + self.assertEqual(status, OrtCompiledModelCompatibility.EP_NOT_APPLICABLE) + + +if __name__ == "__main__": + unittest.main() From 4754a1d64e5920a715b0396906f339e6c15742a0 Mon Sep 17 00:00:00 2001 From: qti-hungjuiw Date: Fri, 29 Aug 2025 12:02:47 +0800 Subject: [PATCH 09/23] [QNN-EP] Introduce Level1 Transformer into qnn.preprocess (#25883) ### Description - Introduce Level1 Transformer into qnn.preprocess to support various optimizations. ### Motivation and Context - This change brings in several useful optimizations such as `ConvBnFusion` and `ConstantFolding`, which are part of `TransformerLevel::Level1` and can benefit QNNEP. - The goal is to optimize the ONNX model before quantization by integrating these passes into the Python tooling workflow. --- .../execution_providers/qnn/preprocess.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py index 191edc4c6390d..a12aca47f5b65 100644 --- a/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py +++ b/onnxruntime/python/tools/quantization/execution_providers/qnn/preprocess.py @@ -6,15 +6,15 @@ from __future__ import annotations import logging +import tempfile from pathlib import Path import onnx -from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed +from ....tools.onnx_model_utils import fix_output_shapes, make_input_shape_fixed, optimize_model from ....tools.remove_initializer_from_input import remove_initializer_from_input from ...fusions import FusionGelu, FusionLayerNormalization from ...onnx_model import ONNXModel -from ...quant_utils import save_and_reload_model_with_shape_infer from .fusion_lpnorm import FusionLpNormalization from .fusion_spacetodepth import FusionSpaceToDepth @@ -93,7 +93,7 @@ def qnn_preprocess_model( """ modified = False model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input) - model = save_and_reload_model_with_shape_infer(model) + model = save_and_reload_optimize_model(model, shape_infer=True) onnx_model = ONNXModel(model) # Optionally, fix the dynamic input shapes. @@ -178,6 +178,24 @@ def qnn_preprocess_model( return modified +def save_and_reload_optimize_model(model: onnx.ModelProto, shape_infer: bool) -> onnx.ModelProto: + with tempfile.TemporaryDirectory(prefix="ort.qnn_preproc.") as qnn_preproc_tmp_dir: + model_in_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_input.onnx") + onnx.save_model(model, model_in_path, save_as_external_data=True) + if shape_infer: + model_infer_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_infer.onnx") + onnx.shape_inference.infer_shapes_path(str(model_in_path), str(model_infer_path)) + model_in_path = model_infer_path + model_out_path = Path(qnn_preproc_tmp_dir).joinpath("qnn_proc_output.onnx") + optimize_model(model_in_path, model_out_path) + ret_model = onnx.load_model(model_out_path) + ret_metaprops = {"onnx.infer": "onnxruntime.tools.qnn.preprocess"} + if ret_model.metadata_props: + ret_metaprops.update(ret_model.metadata_props) + onnx.helper.set_model_props(ret_model, ret_metaprops) + return ret_model + + class InputOutputNameMap: def __init__( self, From 3fc9779df89f3e35b5f93454e1b588f64660a649 Mon Sep 17 00:00:00 2001 From: qti-yuduo Date: Fri, 29 Aug 2025 08:35:29 -0700 Subject: [PATCH 10/23] [QNN EP] Minor fix weight name missing when not valid QDQ node group (#25887) ### Description Minor fix weight name missing when not valid QDQ node group ### Motivation and Context Some quantized model failed QDQ node group validation, the weights then won't be folded as initializer. QNN EP failed to handle the dynamic weights here due to the transpose op input name look up. This change make sure we process the weights tensor before adding transposes. --- .../qnn/builder/opbuilder/conv_op_builder.cc | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc index 541ca5ca7ab14..a994c936970f6 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/conv_op_builder.cc @@ -245,6 +245,12 @@ Status ConvOpBuilder::ProcessConv2D3DInputs(QnnModelWrapper& qnn_model_wrapper, bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Add HWCN Transpose node after input: " << input1_name; + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input1_name)) { + QnnTensorWrapper weight_tensor_wrapper; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], weight_tensor_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), "Failed to add weight tensor."); + } + if (conv_type == OnnxConvType::kConv) { ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddNchwToHwcnTranspose(node_unit.Index(), input1_name, @@ -425,7 +431,7 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, // // Input 1: weight - // We need to first reshape the weight inorder to handle 1D convolutions with the Conv2d operator. + // We need to first reshape the weight in order to handle 1D convolutions with the Conv2d operator. // Next, we have to transpose the weight because ORT layout transformations do not change the weight layout. // { @@ -511,6 +517,12 @@ Status ConvOpBuilder::ProcessConv1DInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF(input_info.quant_param.IsPerChannel(), "Non-constant Conv inputs only support per-tensor quantization"); + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(input1_name)) { + QnnTensorWrapper weight_tensor_wrapper; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], weight_tensor_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), "Failed to add weight tensor."); + } + bool is_graph_input = qnn_model_wrapper.IsGraphInput(input1_name); LOGS(logger, VERBOSE) << "Adding Reshape (to 2D) and HWCN Transpose node after input: " << input1_name; ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(input1_name, From ca77b7ed063091b71e54ffa066942e480c26e041 Mon Sep 17 00:00:00 2001 From: Pradeep Sakhamoori Date: Fri, 29 Aug 2025 11:21:48 -0500 Subject: [PATCH 11/23] Add custom ops library_path to EP metadata (#25830) ## Summary Adds EP metadata library path support to enable custom ops DLL registration with proper path resolution. ## Changes - Added `library_path` metadata key to EP metadata infrastructure - Pass resolved library path directly to `EpLibraryProviderBridge` constructor - Simplified implementation per reviewer feedback (removed virtual method complexity) - Added `#include ` for std::move compliance ## Purpose Enables downstream applications (like onnxruntime-genai) to resolve relative custom ops library paths using EP metadata, improving DLL registration reliability. ## Files Modified - `plugin_ep/ep_factory_provider_bridge.h` - `plugin_ep/ep_library.h` - `plugin_ep/ep_library_plugin.h` - `plugin_ep/ep_library_provider_bridge.cc` - `plugin_ep/ep_library_provider_bridge.h` - `utils.cc` --- .../onnxruntime_ep_device_ep_metadata_keys.h | 5 ++++- .../plugin_ep/ep_factory_provider_bridge.cc | 7 +++++++ .../plugin_ep/ep_factory_provider_bridge.h | 15 +++++++++++---- onnxruntime/core/session/plugin_ep/ep_library.h | 1 + .../plugin_ep/ep_library_provider_bridge.cc | 4 +++- .../plugin_ep/ep_library_provider_bridge.h | 9 +++++++-- onnxruntime/core/session/utils.cc | 5 +++-- 7 files changed, 36 insertions(+), 10 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h index 672103bedc437..bbd6a43bb7a41 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h @@ -12,4 +12,7 @@ static const char* const kOrtEpDevice_EpMetadataKey_Version = "version"; // Prefix for execution provider compatibility information stored in model metadata. // Used when generating EP context models to store compatibility strings for each EP. // Full key format: "ep_compatibility_info." -static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; \ No newline at end of file +static const char* const kOrtModelMetadata_EpCompatibilityInfoPrefix = "ep_compatibility_info."; + +// Key for the execution provider library path (for dynamically loaded EPs) +static const char* const kOrtEpDevice_EpMetadataKey_LibraryPath = "library_path"; diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc index d6e51a44c1c69..42b65239de92c 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -4,6 +4,8 @@ #include "core/session/plugin_ep/ep_factory_provider_bridge.h" #include "core/providers/shared_library/provider_host_api.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" namespace onnxruntime { OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, @@ -20,6 +22,11 @@ OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_fa auto* ep_device = ep_devices[i]; if (ep_device) { ep_device->ep_factory = &ep_factory; + + // Add library path to EP metadata if available + if (library_path_.has_value()) { + ep_device->ep_metadata.Add(kOrtEpDevice_EpMetadataKey_LibraryPath, library_path_->string()); + } } } diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 437af62dc2c0c..8c5ef526baba1 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -3,6 +3,10 @@ #pragma once +#include +#include +#include + #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" @@ -12,12 +16,14 @@ namespace onnxruntime { class ProviderBridgeEpFactory : public EpFactoryInternalImpl { public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library, + std::optional library_path = std::nullopt) : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), ep_factory.GetVendor(&ep_factory), ep_factory.GetVendorId(&ep_factory)), ep_factory_{ep_factory}, - provider_library_{provider_library} { + provider_library_{provider_library}, + library_path_{std::move(library_path)} { } private: @@ -59,8 +65,9 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); } - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP + OrtEpFactory& ep_factory_; + ProviderLibrary& provider_library_; + std::optional library_path_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h index 24ab74e1c77fc..af5bc23143e33 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library.h +++ b/onnxruntime/core/session/plugin_ep/ep_library.h @@ -23,6 +23,7 @@ class EpLibrary { virtual Status Load() { return Status::OK(); } virtual const std::vector& GetFactories() = 0; // valid after Load() virtual Status Unload() { return Status::OK(); } + virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc index 06cf54aea4071..da94a9f12ba9d 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -4,6 +4,7 @@ #include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/plugin_ep/ep_factory_provider_bridge.h" +#include "core/session/plugin_ep/ep_library_plugin.h" namespace onnxruntime { Status EpLibraryProviderBridge::Load() { @@ -26,8 +27,9 @@ Status EpLibraryProviderBridge::Load() { // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); + auto factory_impl = std::make_unique(*factory, *provider_library_, library_path_); auto internal_factory = std::make_unique(std::move(factory_impl)); factory_ptrs_.push_back(internal_factory.get()); diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index c7e8ebefc3785..45277b2828f56 100644 --- a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -21,9 +21,11 @@ namespace onnxruntime { class EpLibraryProviderBridge : public EpLibrary { public: EpLibraryProviderBridge(std::unique_ptr provider_library, - std::unique_ptr ep_library_plugin) + std::unique_ptr ep_library_plugin, + std::optional library_path = std::nullopt) : provider_library_{std::move(provider_library)}, - ep_library_plugin_{std::move(ep_library_plugin)} { + ep_library_plugin_{std::move(ep_library_plugin)}, + library_path_{std::move(library_path)} { } const char* RegistrationName() const override { @@ -53,6 +55,9 @@ class EpLibraryProviderBridge : public EpLibrary { // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; + // Library path for EP metadata + std::optional library_path_; + std::vector> factories_; std::vector factory_ptrs_; // for convenience std::vector internal_factory_ptrs_; // for convenience diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index d4041dfce5a7a..7da7fabb15b15 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -421,13 +421,14 @@ Status LoadPluginOrProviderBridge(const std::string& registration_name, << (is_provider_bridge ? " as a provider bridge" : " as a plugin"); // create EpLibraryPlugin to ensure CreateEpFactories and ReleaseEpFactory are available - auto ep_library_plugin = std::make_unique(registration_name, std::move(resolved_library_path)); + auto ep_library_plugin = std::make_unique(registration_name, resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_plugin->Load()); if (is_provider_bridge) { // wrap the EpLibraryPlugin with EpLibraryProviderBridge to add to directly create an IExecutionProvider auto ep_library_provider_bridge = std::make_unique(std::move(provider_library), - std::move(ep_library_plugin)); + std::move(ep_library_plugin), + resolved_library_path); ORT_RETURN_IF_ERROR(ep_library_provider_bridge->Load()); internal_factories = ep_library_provider_bridge->GetInternalFactories(); ep_library = std::move(ep_library_provider_bridge); From 7a919c693692d50f7c222660b76fb5b0c9926738 Mon Sep 17 00:00:00 2001 From: Preetha Veeramalai Date: Fri, 29 Aug 2025 23:04:32 +0530 Subject: [PATCH 12/23] [OVEP] OpenVINO EP Features and bug-fixes for ORT-1.23 (#25884) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description This update introduces multiple improvements, fixes, and feature enhancements to the OpenVINO Execution Provider (OVEP) and related components in ONNX Runtime: #### Configuration & Properties - Updated load_config mapping to act as a passthrough to OpenVINO properties. - Added support for providing layout information to inputs/outputs in OpenVINO. #### Inference & Tensor Handling - Improved OVInferRequest::SetTensor to correctly handle cached binding shape mismatches. - Added support for self-detecting on-the-fly bfloat16 → float16 conversion. - Fixed issues with input ONNX models when used with shared execution contexts. #### Model Handling & Operator Support - Fixed model copying behavior for QDQ stripping. - Updated operator support status for OpenVINO 2025.2. #### Platform & Integration Fixes - Applied multiple PSU Lora fixes and related updates. - Resolved filename confusion issues with wrapped OVIRs in EPCtx. - Enabled memory-mapped native binaries for OpenVINO 2025.3. #### Quality & Maintenance - Addressed linting issues. - Fixed coverage gaps in OVEP. - Added a new test script for OpenVINO with ORT ABI integration. --------- Co-authored-by: Ankit Maheshkar Co-authored-by: Ryan Metcalfe <107415876+RyanMetcalfeInt8@users.noreply.github.com> Co-authored-by: Klimenko, Mikhail Co-authored-by: sfatimar Co-authored-by: Garth Long Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: MayureshV1 <47039074+MayureshV1@users.noreply.github.com> Co-authored-by: Eric Crawford Co-authored-by: jatinwadhwa921 <110383850+jatinwadhwa921@users.noreply.github.com> Co-authored-by: Vishnudas Thaniel S Co-authored-by: Javier Martinez --- .../providers/openvino/backend_manager.cc | 77 +++++++++++- .../core/providers/openvino/backend_utils.cc | 40 ++++++ .../core/providers/openvino/backend_utils.h | 2 + .../openvino/backends/basic_backend.cc | 107 +++------------- .../core/providers/openvino/contexts.h | 4 +- .../core/providers/openvino/ibackend.h | 2 +- .../openvino/onnx_ctx_model_helper.cc | 10 +- .../openvino/onnx_ctx_model_helper.h | 8 +- .../openvino/openvino_execution_provider.cc | 7 +- .../openvino/openvino_parser_utils.cc | 74 +++++++++++ .../openvino/openvino_parser_utils.h | 2 + .../openvino/openvino_provider_factory.cc | 10 +- .../core/providers/openvino/ov_factory.cc | 2 +- .../core/providers/openvino/ov_interface.cc | 14 ++- .../core/providers/openvino/ov_interface.h | 32 ++--- .../openvino/ov_versions/capability.cc | 42 ++++--- .../openvino/ov_versions/data_ops.cc | 18 +-- .../providers/openvino/ov_versions/utils.cc | 18 +++ .../providers/openvino/ov_versions/utils.h | 4 + .../qdq_transformations/qdq_scales_fix.cpp | 72 ++++++++--- .../qdq_transformations/qdq_scales_fix.h | 5 + .../qdq_transformations/qdq_stripping.cc | 37 +++++- .../test/perftest/command_args_parser.cc | 4 +- onnxruntime/test/perftest/ort_test_session.cc | 4 +- .../providers/cpu/controlflow/loop_test.cc | 4 +- .../tensor/dynamic_quantize_linear_test.cc | 3 +- .../openvino_ep_bfloat16_pass_test.cc | 116 ++++++++++++++++++ 27 files changed, 545 insertions(+), 173 deletions(-) create mode 100644 onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index be59b1ae07020..68d15bdfdcee0 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -90,7 +90,12 @@ BackendManager::BackendManager(SessionContext& session_context, "[OpenVINO-EP] Bounded dynamic model execution using provider option reshape_input is not supported for OVEP EPContext model"; ORT_THROW(exception_str); } - model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.onnx_model_path_name.replace_extension("xml").string(), subgraph); + } else { + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + } + } else { model_proto = GetModelProtoFromFusedNode(fused_node, subgraph, logger); } @@ -236,7 +241,9 @@ Status BackendManager::ExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphVie std::ofstream blob_file(blob_filename, std::ios::out | std::ios::trunc | std::ios::binary); if (!blob_file) { - ORT_THROW("Unable to open file for epctx model dump."); + std::ostringstream err_msg; + err_msg << "Unable to open file for epctx model dump: " << blob_filename; + ORT_THROW(err_msg.str()); } compiled_model.export_model(blob_file); model_blob_str = blob_filename.filename().string(); @@ -375,6 +382,56 @@ static bool IsQDQGraph(const onnxruntime::GraphViewer& graph_viewer) { return false; } +static bool IsModelBF16(const onnxruntime::GraphViewer& graph_viewer) { + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (std::size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + for (auto& output : node->OutputDefs()) { + if (output->ToProto().type().tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + return true; + } + } + return false; +} + +static bool Is16BitTensor(const onnxruntime::NodeArg* node_arg) { + const auto* type_proto = node_arg ? node_arg->TypeAsProto() : nullptr; + return type_proto && type_proto->has_tensor_type() && + (type_proto->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_UINT16 || + type_proto->tensor_type().elem_type() == ONNX_NAMESPACE::TensorProto_DataType_INT16); +} + +// Check to see if the graph has Q/DQ nodes with int16 or uint16 quantization +static bool IsQDQGraphWithUint16OrInt16(const onnxruntime::GraphViewer& graph_viewer) { + std::unordered_set qdq_ops = {"QuantizeLinear", "DequantizeLinear"}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + + for (size_t i = 0; i < node_indices.size(); i++) { + gsl::not_null node(graph_viewer.GetNode(node_indices[i])); + + if (qdq_ops.find(node->OpType()) != qdq_ops.end()) { + const auto& input_defs = node->InputDefs(); + + if (node->OpType() == "DequantizeLinear") { + // DequantizeLinear: [quantized_input, scale, zero_point] -> [float_output] + // Check quantized input tensor and optional zero point + if (Is16BitTensor(input_defs.empty() ? nullptr : input_defs[0]) || + (input_defs.size() >= 3 && Is16BitTensor(input_defs[2]))) { + return true; + } + } else if (node->OpType() == "QuantizeLinear") { + // QuantizeLinear: [float_input, scale, zero_point] -> [quantized_output] + const auto& output_defs = node->OutputDefs(); + if (Is16BitTensor(output_defs.empty() ? nullptr : output_defs[0]) || + (input_defs.size() >= 3 && Is16BitTensor(input_defs[2]))) { + return true; + } + } + } + } + return false; +} + static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& onnx_model_path_name, [[maybe_unused]] ONNX_NAMESPACE::ModelProto* model_proto, [[maybe_unused]] const onnxruntime::Node& fused_node) { @@ -433,6 +490,10 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, } #endif + // Check if the graph is QDQ and has int16 or uint16 quantization + // If so, we will apply the QDQ scales fix transformation (for GPU device only) + bool is_qdq_graph_uint16_or_int16 = IsQDQGraphWithUint16OrInt16(subgraph); + const auto& onnx_model_path_name = subgraph.ModelPath(); // QDQ stripping enabled only for the NPU and experimentally on the GPU if ((session_context_.device_type.find("NPU") != std::string::npos) && @@ -446,7 +507,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; } else if ((session_context_.device_type.find("GPU") != std::string::npos) && - enable_ovep_qdq_optimizer) { + is_qdq_graph_uint16_or_int16) { // Create a copy of the model std::unique_ptr model; Status status = qdq_scales_fix::Transform(subgraph, logger, model); @@ -456,6 +517,16 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; + } else if (IsModelBF16(subgraph)) { + LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP bfloat16->float16 optimization pass is enabled"; + std::unique_ptr model; + Status status = bfloat16_fix::Transform(subgraph, logger, model); + auto model_proto = model->ToProto(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + print_model_proto_duration(); + DumpOpenVINOEPModel(onnx_model_path_name, model_proto.get(), fused_node); + ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); + return model_proto; } else { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] OVEP QDQ optimization pass is disabled"; auto model = subgraph.CreateModel(logger); diff --git a/onnxruntime/core/providers/openvino/backend_utils.cc b/onnxruntime/core/providers/openvino/backend_utils.cc index 73fbe9a0fa76f..7027861f0c4dc 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.cc +++ b/onnxruntime/core/providers/openvino/backend_utils.cc @@ -150,6 +150,11 @@ CreateOVModel(std::string&& model, LOGS_DEFAULT(INFO) << log_tag << "Reshaping the ov tensor to specified shape"; ov_model->reshape(session_context.reshape); } + + if (!session_context.layout.empty()) { + LOGS_DEFAULT(INFO) << log_tag << "Setting the ov tensor layout to specified layout"; + ov_model = Set_Layout(ov_model, session_context.layout); + } // Check for Constant Folding if ((session_context.device_type != "NPU") && !session_context.is_wholly_supported_graph) { ov::pass::ConstantFolding pass_const_obj; @@ -199,6 +204,41 @@ GetOutputTensor(Ort::KernelContext& context, return context.GetOutput(index, output_shape); } +std::shared_ptr Set_Layout(std::shared_ptr ov_model, const layout_t& layout) { + ov::preprocess::PrePostProcessor preproc(ov_model); + + const auto& inputs = ov_model->inputs(); + const auto& outputs = ov_model->outputs(); + + auto find_tensor_index = [](const std::vector>& tensors, const std::string& name) -> std::optional { + for (size_t i = 0; i < tensors.size(); ++i) { + const auto& tensor = tensors[i]; + if (tensor.get_any_name() == name || tensor.get_tensor().get_names().count(name) > 0) { + return i; + } + } + return std::nullopt; + }; + + for (const auto& [tensor_name, layout_value] : layout) { + bool tensor_found = false; + + if (auto input_idx = find_tensor_index(inputs, tensor_name)) { + preproc.input(*input_idx).tensor().set_layout(layout_value); + tensor_found = true; + } else if (auto output_idx = find_tensor_index(outputs, tensor_name)) { + preproc.output(*output_idx).tensor().set_layout(layout_value); + tensor_found = true; + } + + if (!tensor_found) { + LOGS_DEFAULT(WARNING) << "Tensor '" << tensor_name << "' not found in model inputs or outputs"; + } + } + + return preproc.build(); +} + int GetFirstAvailableDevice(SessionContext& session_context) { int i = 0; // Get the first available VAD-M device and set the device to busy diff --git a/onnxruntime/core/providers/openvino/backend_utils.h b/onnxruntime/core/providers/openvino/backend_utils.h index 15145df651fa2..27f791c7a5bd1 100644 --- a/onnxruntime/core/providers/openvino/backend_utils.h +++ b/onnxruntime/core/providers/openvino/backend_utils.h @@ -79,6 +79,8 @@ int GetFirstAvailableDevice(SessionContext& session_context); void FillOutputsWithConstantData(std::shared_ptr node, Ort::UnownedValue& out_tensor); +std::shared_ptr Set_Layout(std::shared_ptr ov_model, const layout_t& layout); + template void FillOutputHelper(Ort::UnownedValue& out_tensor, std::shared_ptr node); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index 6efd866d47c3c..2f174110dd31b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -59,7 +59,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr }; // If the EPContext node with OVIR Encapsulation, then create // an executable network from EP_CACHE_CONTEXT using read_model() & compile_model() - exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream, + exe_network_ = OVCore::Get()->ImportEPCtxOVIREncapsulation(*model_stream->stream_, hw_target, device_config, enable_causallm, @@ -98,6 +98,7 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr !subgraph_context_.has_dynamic_input_shape && !session_context_.so_context_enable && session_context_.reshape.empty() && + session_context_.layout.empty() && !enable_causallm && !eligible_for_cpu_fallback && auto_unified_compile); @@ -213,101 +214,29 @@ void BasicBackend::PopulateConfigValue(ov::AnyMap& device_config) { if (!session_context_.load_config.empty()) { const std::map& target_config = session_context_.load_config; - if ((session_context_.device_type.find("NPU") != std::string::npos) && session_context_.enable_causallm) { - if (target_config.find("NPU") != target_config.end()) { - auto npu_genai_config = target_config.at("NPU"); - CausalLMConfig().ApplyConfig(npu_genai_config, device_config); - } else { - LOGS_DEFAULT(WARNING) << "ORT GenAI CausalLMConfig Configuration not found."; - } - } + // Extract device names from device string and apply their configs + // Examples: "GPU" -> ["GPU"], "AUTO:GPU.0,CPU" -> ["AUTO", "GPU", "CPU"] + auto apply_device_config = [&](std::string_view device) { + if (device.empty()) return; - if (session_context_.device_type.find("NPU") != std::string::npos) { - auto npuw_config = target_config.at("NPU"); - - // Check if "NPU_USE_NPUW" exists and is set to "YES" - auto npu_use_npuw_it = npuw_config.find("NPU_USE_NPUW"); - if (npu_use_npuw_it != npuw_config.end() && - npu_use_npuw_it->second.is() && - npu_use_npuw_it->second.as() == "YES") { - // Only add NPUW-related keys if NPU_USE_NPUW is "YES" - for (const auto& [key, value] : npuw_config) { - if (key.find("NPUW") != std::string::npos) { - if (!value.is()) { - LOGS_DEFAULT(ERROR) << "Invalid value type for key: " << key; - continue; - } - device_config[key] = value; - } - } - } else { - // Check if there are any "NPUW" keys and log a warning - if (std::any_of(npuw_config.begin(), npuw_config.end(), - [&](const auto& pair) { return pair.first.find("NPUW") != std::string::npos; })) { - LOGS_DEFAULT(WARNING) << "Skipping NPUW-related configurations as NPU_USE_NPUW is not set to 'YES'."; - } - } - } - auto find_device_type_mode = [&](const std::string& device_type) -> std::string { - std::string device_mode = ""; - auto delimiter_pos = device_type.find(':'); - if (delimiter_pos != std::string::npos) { - std::stringstream str_stream(device_type.substr(0, delimiter_pos)); - std::getline(str_stream, device_mode, ','); - } - return device_mode; - }; - - // Parse device types like "AUTO:CPU,GPU" and extract individual devices - auto parse_individual_devices = [&](const std::string& device_type) -> std::vector { - std::vector devices; - auto delimiter_pos = device_type.find(':'); - if (delimiter_pos != std::string::npos) { - std::stringstream str_stream(device_type.substr(delimiter_pos + 1)); - std::string device; - while (std::getline(str_stream, device, ',')) { - devices.emplace_back(device); - } - } else { - devices.emplace_back(device_type); - } - return devices; - }; + // Remove device index: "GPU.0" -> "GPU" + auto base_device = device.substr(0, device.find('.')); - // Set properties, Validation will be handled by OpenVINO Core - auto set_target_properties = [&](const std::string& device, const ov::AnyMap& config_options) { - for (const auto& [key, value] : config_options) { - if ((key.find("NPUW") != std::string::npos) || - ((device_config.find(key) != device_config.end()) && session_context_.enable_causallm)) { - continue; + if (auto config_it = target_config.find(std::string(base_device)); config_it != target_config.end()) { + for (const auto& [key, value] : config_it->second) { + device_config[key] = value; } - OVCore::Get()->core.set_property(device, ov::AnyMap{{key, value}}); } }; - // Check if the device type is AUTO, HETERO, or MULTI - if (session_context_.device_type.find("AUTO") == 0 || - session_context_.device_type.find("HETERO") == 0 || - session_context_.device_type.find("MULTI") == 0) { - //// Parse to get the device mode (e.g., "AUTO:CPU,GPU" -> "AUTO") - std::unordered_set supported_mode = {"AUTO", "HETERO", "MULTI"}; - auto device_mode = find_device_type_mode(session_context_.device_type); - ORT_ENFORCE(supported_mode.find(device_mode) != supported_mode.end(), " Invalid device mode is passed : ", session_context_.device_type); - // Parse individual devices (e.g., "AUTO:CPU,GPU" -> ["CPU", "GPU"]) - auto individual_devices = parse_individual_devices(session_context_.device_type); - if (!device_mode.empty()) individual_devices.emplace_back(device_mode); - - // Set properties only for individual devices (e.g., "CPU", "GPU") - for (const std::string& device : individual_devices) { - if (target_config.count(device)) { - // Set properties for the device - set_target_properties(device, target_config.at(device)); + // Parse device string by splitting on ':' and ',' delimiters + const auto& device_str = session_context_.device_type; + for (size_t start = 0, pos = 0; pos <= device_str.size(); ++pos) { + if (pos == device_str.size() || device_str[pos] == ':' || device_str[pos] == ',') { + if (pos > start) { + apply_device_config(std::string_view(device_str).substr(start, pos - start)); } - } - } else { - if (target_config.count(session_context_.device_type)) { - set_target_properties(session_context_.device_type, - target_config.at(session_context_.device_type)); + start = pos + 1; } } } diff --git a/onnxruntime/core/providers/openvino/contexts.h b/onnxruntime/core/providers/openvino/contexts.h index 6a2b375d733f9..07b09899ac214 100644 --- a/onnxruntime/core/providers/openvino/contexts.h +++ b/onnxruntime/core/providers/openvino/contexts.h @@ -70,6 +70,7 @@ class SharedContext : public WeakSingleton { using config_t = std::map; using reshape_t = std::map; +using layout_t = std::map; struct ProviderInfo { std::string device_type{""}; // [device_type]: Overrides the accelerator hardware type and @@ -88,6 +89,7 @@ struct ProviderInfo { // (GPU) feature. If blob files are already present, // it will be directly loaded. reshape_t reshape{}; // Used for reshaping the ov input tensor shape at runtime. + layout_t layout{}; // Used for specifying the ov input/output tensor layout at runtime. std::string model_priority{"DEFAULT"}; // High-level OpenVINO model priority hint // Defines what model should be provided with more performant // bounded resource first @@ -110,7 +112,7 @@ struct ProviderInfo { const ConfigOptions* config_options{NULL}; const std::unordered_set valid_provider_keys = {"device_type", "device_id", "device_luid", "cache_dir", "precision", "load_config", "context", "num_of_threads", "model_priority", "num_streams", "enable_opencl_throttling", "enable_qdq_optimizer", - "enable_causallm", "disable_dynamic_shapes", "reshape_input"}; + "enable_causallm", "disable_dynamic_shapes", "reshape_input", "layout"}; }; // Holds context applicable to the entire EP instance. diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index ec38425f602eb..365a4625815d6 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -19,7 +19,7 @@ class IBackend { virtual ~IBackend() = default; virtual void RewindKVCache(size_t index) {} }; -using ptr_stream_t = std::unique_ptr; +using ptr_stream_t = std::unique_ptr; class BackendFactory { public: static std::shared_ptr diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 9e70756a254aa..051a39bd4f205 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -100,7 +100,8 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { +std::unique_ptr +EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); ORT_ENFORCE(node != nullptr); @@ -113,10 +114,11 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy bool embed_mode = static_cast(attrs.at(EMBED_MODE).i()); std::unique_ptr result; + std::filesystem::path blob_filepath{}; if (embed_mode) { result.reset((std::istream*)new std::istringstream(ep_cache_context)); } else { - auto blob_filepath = so_context_file_path; + blob_filepath = so_context_file_path; if (blob_filepath.empty() && !graph_viewer.ModelPath().empty()) { blob_filepath = graph_viewer.ModelPath(); } @@ -126,16 +128,18 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesy } bool isXML = backend_utils::IsModelStreamXML(*result); + std::filesystem::path native_blob_path{}; if (!isXML) { // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was // exported with must match the version that is currently running. + native_blob_path = std::move(blob_filepath); ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); } LOGS_DEFAULT(VERBOSE) << "[OpenVINO EP] Read blob from EPContext Node"; - return result; + return std::make_unique(std::move(result), native_blob_path); } bool EPCtxHandler::CheckForOVEPCtxNodeInGraph(const GraphViewer& graph_viewer) const { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index b9ddb40a7a233..f207f5014ca1f 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -12,6 +12,12 @@ namespace onnxruntime { namespace openvino_ep { +struct ModelBlobWrapper { + ModelBlobWrapper(std::unique_ptr stream, const std::filesystem::path& native_blob_path) : stream_(std::move(stream)), maybe_native_blob_path_(native_blob_path) {} + std::unique_ptr stream_; + std::filesystem::path maybe_native_blob_path_; +}; + // Utilities to handle EPContext node export and parsing of an EPContext node // to create the compiled_model object to infer on static const char EPCONTEXT_OP[] = "EPContext"; @@ -31,7 +37,7 @@ class EPCtxHandler { const std::string& graph_name, const bool embed_mode, std::string&& model_blob_str) const; - std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; + std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const; InlinedVector GetEPCtxNodes() const; bool CheckEPCacheContextAttribute(const GraphViewer& graph_viewer, const std::string& target_attr_extn) const; diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index 1b19517b07363..a0fa885cbfc38 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -94,18 +94,23 @@ common::Status OpenVINOExecutionProvider::Compile( auto& logger = *GetLogger(); Status status = Status::OK(); + bool is_epctx_model = false; if (!fused_nodes.empty()) { // Assume these properties are constant for all the model subgraphs, otherwise move to SubGraphContext const auto& graph_body_viewer_0 = fused_nodes[0].filtered_graph.get(); session_context_.onnx_model_path_name = graph_body_viewer_0.ModelPath().string(); session_context_.onnx_opset_version = graph_body_viewer_0.DomainToVersionMap().at(kOnnxDomain); + + // OVIR wrapped in epctx should be treated as source but this code does not + // This corner case is not in use and will be addressed in a future commit + is_epctx_model = ep_ctx_handle_.CheckForOVEPCtxNodeInGraph(graph_body_viewer_0); } // The block below is executed during EP context model inference auto& metadata = shared_context_->shared_weights.metadata; // Metadata object in memory if (session_context_.so_share_ep_contexts && - !session_context_.so_context_enable && + is_epctx_model && metadata.empty()) { fs::path context_model_file_path = session_context_.so_context_file_path; if (context_model_file_path.empty()) { diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc index 21fc7f935da23..a290fea73e0e8 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.cc +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.cc @@ -236,5 +236,79 @@ ov::Dimension OpenVINOParserUtils::ParseDimensionRange(const std::string& range_ return ov::Dimension(range_start, range_end); } +layout_t OpenVINOParserUtils::ParseLayout(const std::string& layout_definition) { + layout_t parsed_layout_map; + + // Return empty map for empty input + if (layout_definition.empty()) { + ORT_THROW("Empty layout definition provided in layout parameter"); + } + + // Regular expression for parsing layout definitions + const std::regex layout_pattern(R"(([^\[\],]+)\s*\[(.*?)\])"); // e.g. "input_1[NC],data[CHW]" + + // Find all tensor layout definitions using regex + auto layout_begin = std::sregex_iterator( + layout_definition.begin(), + layout_definition.end(), + layout_pattern); + auto layout_end = std::sregex_iterator(); + + // If no matches found, throw error + if (layout_begin == layout_end) { + ORT_THROW("Invalid layout definition format: " + layout_definition); + } + + // Process each tensor definition + for (std::sregex_iterator i = std::move(layout_begin); i != layout_end; ++i) { + std::smatch layout_match = *i; + + // Extract tensor name and trim whitespace + std::string tensor_name = layout_match[1].str(); // Group 1: tensor name e.g. "input_1" + tensor_name = TrimWhitespace(tensor_name); + + if (tensor_name.empty()) { + ORT_THROW("Empty tensor name provided in layout parameter"); + } + + // Extract dimensions string + std::string dimensions_str = layout_match[2].str(); // Group 2: dimensions string [e.g. "NC", "CHW"] + + if (!Check_Valid_Layout(dimensions_str, tensor_name)) { + ORT_THROW("Invalid dimensions string provided in layout parameter"); + } + + // Store parsed shape in result map + parsed_layout_map[tensor_name] = ov::Layout(dimensions_str); + } + + return parsed_layout_map; +} + +bool OpenVINOParserUtils::Check_Valid_Layout(const std::string& layout_str, const std::string& tensor_name) { + // Check if the layout string is empty + if (layout_str.empty()) { + return false; + } + + std::unordered_set seen_alphabets; + for (char c : layout_str) { + if (std::isalpha(c)) { + char upper_c = static_cast(std::toupper(c)); // Convert to uppercase for case-insensitive comparison + if (seen_alphabets.find(upper_c) != seen_alphabets.end()) { + ORT_THROW("Repeated Dim '" + std::string(1, c) + + "' found in layout dimensions for tensor '" + tensor_name + "'"); + } + seen_alphabets.insert(upper_c); + } else if (c != '?') { + // Only '?' is allowed as non-alphabetic character + ORT_THROW("Invalid character '" + std::string(1, c) + + "' found in layout dimensions for tensor '" + tensor_name + "'"); + } + } + + return true; +} + } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/openvino_parser_utils.h b/onnxruntime/core/providers/openvino/openvino_parser_utils.h index e6aa0e0a46a3b..a0936d627df40 100644 --- a/onnxruntime/core/providers/openvino/openvino_parser_utils.h +++ b/onnxruntime/core/providers/openvino/openvino_parser_utils.h @@ -18,8 +18,10 @@ class OpenVINOParserUtils { std::string& device_type, const std::string& option_name); static reshape_t ParseInputShape(const std::string& reshape_input_definition); + static layout_t ParseLayout(const std::string& layout_definition); static std::string TrimWhitespace(const std::string& str); static ov::Dimension ParseDimensionRange(const std::string& range_str, const std::string& tensor_name); + static bool Check_Valid_Layout(const std::string& layout_str, const std::string& tensor_name); }; } // namespace openvino_ep diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 9dba8623031d0..1a10d9849d5cc 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -171,7 +171,7 @@ std::string ParseDeviceType(std::shared_ptr ov_core, const ProviderOptio if (!device_mode.empty()) { selected_device = device_mode + ":" + ov_luid_devices; for (const auto& dev_str : devices_to_check) { - const auto default_dev = split(dev_str, '.')[0]; + const std::string default_dev = split(dev_str, '.')[0]; if (ov_luid_devices.find(default_dev) == std::string::npos) selected_device = selected_device + "," + dev_str; @@ -230,6 +230,10 @@ static void ParseProviderInfo(const ProviderOptions& provider_options, pi.reshape = OpenVINOParserUtils::ParseInputShape(provider_options.at("reshape_input")); } + if (provider_options.contains("layout")) { + pi.layout = OpenVINOParserUtils::ParseLayout(provider_options.at("layout")); + } + if (provider_options.contains("load_config")) { auto parse_config = [&](const std::string& config_str) -> std::map { // If the config string is empty, return an empty map and skip processing @@ -526,7 +530,7 @@ struct OpenVINO_Provider : Provider { std::string ov_device_string; if (is_meta_device_factory) { // Build up a meta device string based on the devices that are passed in. E.g. AUTO:NPU,GPU.0,CPU - ov_device_string = ov_meta_device_type; + ov_device_string = std::move(ov_meta_device_type); ov_device_string += ":"; } @@ -539,7 +543,7 @@ struct OpenVINO_Provider : Provider { prepend_comma = true; } - provider_options["device_type"] = ov_device_string; + provider_options["device_type"] = std::move(ov_device_string); // Parse provider info with the device type ProviderInfo pi; diff --git a/onnxruntime/core/providers/openvino/ov_factory.cc b/onnxruntime/core/providers/openvino/ov_factory.cc index 8860405338409..2853cc17726ab 100644 --- a/onnxruntime/core/providers/openvino/ov_factory.cc +++ b/onnxruntime/core/providers/openvino/ov_factory.cc @@ -105,7 +105,7 @@ OrtStatus* OpenVINOEpPluginFactory::GetSupportedDevices(const OrtHardwareDevice* std::string ov_device_name; auto get_gpu_device_id = [&](const std::string& ov_device) { try { - auto device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); + const std::string device_id_str = ov_core_->get_property(ov_device, "GPU_DEVICE_ID").as(); return static_cast(std::stoul(device_id_str, nullptr, 0)); } catch (ov::Exception&) { return 0u; // If we can't get the GPU_DEVICE_ID info, we won't have a device ID. diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 2d29df8eb4197..899845d4890cf 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -11,6 +11,7 @@ #include "core/providers/openvino/backend_utils.h" #include "core/providers/openvino/backends/basic_backend.h" #include "core/providers/openvino/ov_stateful_patch_utils.h" +#include "core/providers/openvino/onnx_ctx_model_helper.h" namespace onnxruntime { namespace openvino_ep { @@ -191,14 +192,23 @@ OVExeNetwork OVCore::CompileModel(const std::string& onnx_model, "Exception while Loading Network for graph {}", name); } -OVExeNetwork OVCore::ImportModel(std::istream& model_stream, +OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name) { return OvExceptionBoundary([&]() { ov::CompiledModel obj; - obj = core.import_model(model_stream, hw_target, device_config); +#if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) + if (!model_blob.maybe_native_blob_path_.empty()) { + obj = core.import_model(ov::read_tensor_data(model_blob.maybe_native_blob_path_), hw_target, device_config); + } else { + obj = core.import_model(*model_blob.stream_, hw_target, device_config); + } +#else + obj = core.import_model(*model_blob.stream_, hw_target, device_config); +#endif OVExeNetwork exe(obj, hw_target); + #ifndef NDEBUG printDebugInfo(exe.Get()); #endif diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 6d1db4366410b..38ea883078e85 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -26,6 +26,7 @@ namespace openvino_ep { class OVCore; class OVInferRequest; class OVExeNetwork; +struct ModelBlobWrapper; typedef ov::Tensor OVTensor; typedef ov::ProfilingInfo OVProfilingInfo; @@ -82,7 +83,7 @@ struct OVCore : WeakSingleton { ov::AnyMap& device_config, const std::string& name); // OV Interface for Import model Stream - OVExeNetwork ImportModel(std::istream& model_stream, + OVExeNetwork ImportModel(ModelBlobWrapper& model_blob, std::string hw_target, const ov::AnyMap& device_config, std::string name); @@ -126,29 +127,16 @@ class OVInferRequest { OVTensorPtr GetTensor(const std::string& name); std::string GetInputTensorName(uint32_t index); - // Set tensor described param_info and ort_ptr. Overrides shape in param_info with shape_override. Call infer req tensor if ort_ptr is last set. + // Set tensor call infer req tensor if ort_ptr differs from last set ptr. void SetTensor(const std::string& name, const ov::element::Type& type, const ov::Shape& shape, void* ort_ptr) { auto& cached_binding = bindings_cache_[name]; - if (cached_binding.ort_ptr != ort_ptr) { - auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); - SetTensor(name, tensor_ptr); - cached_binding = {tensor_ptr, ort_ptr}; - } else if (ort_ptr == nullptr) { - // a null ort_ptr is expected for a tensor that has 0 elements. - // for example, a tensor of shape=[1, 8, 0, 64], which is valid. - // So, we check to see if at least one shape entry is 0. - auto contains_zero = [](const ov::Shape& shape) { - for (auto& s : shape) - if (s == 0) return true; - return false; - }; - if (contains_zero(shape)) { - // if there are zero elements (i.e. at least one shape entry is 0), - // then create and set the tensor anyway. - auto tensor_ptr = std::make_shared(type, shape); - SetTensor(name, tensor_ptr); - cached_binding = {tensor_ptr, ort_ptr}; - } + if (cached_binding.ort_ptr != ort_ptr || + !cached_binding.tensor_ptr || + cached_binding.tensor_ptr->get_shape() != shape) { + cached_binding.tensor_ptr.reset(); + auto ov_tensor = std::make_shared(type, shape, const_cast(ort_ptr)); + ovInfReq.set_tensor(name, *ov_tensor); + cached_binding = {std::move(ov_tensor), ort_ptr}; } } diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 2309ff3de751b..1893700cab09c 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -166,17 +166,28 @@ std::vector> GetCapability::Execute() { auto connected_clusters = GetConnectedClusters(graph_viewer_, ng_clusters); int no_of_clusters = 0; - + size_t cluster_index = 0; + size_t total_clusters = connected_clusters.size(); for (auto this_cluster : connected_clusters) { - // If subgraph has less then three, graph is considered trivial unless its an epctx cluster - if (this_cluster.size() < 3) { - bool is_epctx_node = false; - for (auto node_idx : this_cluster) { - if (graph_viewer_.GetNode(node_idx)->OpType() == "EPContext") - is_epctx_node = true; + bool omit_subgraph = false; + + if (this_cluster.size() == 1) { + // check next cluster + auto index = this_cluster.at(0); + size_t j = cluster_index; + if (graph_viewer_.GetNode(index)->OpType() == "EPContext") { + omit_subgraph = false; + } else if (j < total_clusters - 1) { + bool append_node = false; + while (j < total_clusters && !append_node) { + j = j + 1; + append_node = AddTrivialClusterToNextClusterIfConnected(graph_viewer_, index, connected_clusters[j]); + } + if (append_node) { + connected_clusters[j].emplace_back(index); + } + omit_subgraph = true; } - if (!is_epctx_node) - continue; } std::vector cluster_graph_inputs, cluster_inputs, cluster_outputs; @@ -188,7 +199,6 @@ std::vector> GetCapability::Execute() { cluster_inputs, cluster_outputs); - bool omit_subgraph = false; // Omitting zero dim subgraphs for (auto index : this_cluster) { const Node* node = graph_viewer_.GetNode(index); @@ -217,15 +227,17 @@ std::vector> GetCapability::Execute() { } } } - if (omit_subgraph) - continue; /* In scenarios, when there are no inputs or all inputs being initializers, ConstantFolding optimization in onnxruntime pre-computes the value.*/ - if (!cluster_inputs.empty()) { - AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); - no_of_clusters++; + if (!omit_subgraph) { + if (!cluster_inputs.empty()) { + AppendClusterToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); + no_of_clusters++; + } } + + cluster_index = cluster_index + 1; } LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Supported subgraphs on OpenVINO: " << no_of_clusters; } diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 17e69ad080b90..f848b89ed10c8 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -121,6 +121,7 @@ std::vector supported_op_mode = { {"DepthToSpace", V_2020_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2021_4, {"CPU", "GPU"}}, {"DequantizeLinear", V_2024_4, {"NPU"}}, + {"DynamicQuantizeLinear", V_2025_2, {"CPU", "GPU"}}, {"DynamicQuantizeMatMul", V_2025_0, {"CPU", "GPU"}}, {"Div", V_2020_4, {"CPU", "GPU"}}, {"Dropout", V_2020_4, {"CPU", "GPU"}}, @@ -172,6 +173,7 @@ std::vector supported_op_mode = { {"LSTM", V_2020_4, {"CPU", "GPU"}}, {"MatMul", V_2020_4, {"CPU", "GPU"}}, {"MatMulInteger", V_2022_1, {"CPU"}}, + {"MatMulInteger", V_2025_2, {"GPU"}}, {"MatMulNBits", V_2024_5, {"CPU", "GPU"}}, {"Max", V_2020_4, {"CPU", "GPU"}}, {"MaxPool", V_2020_4, {"CPU", "GPU"}}, @@ -191,7 +193,7 @@ std::vector supported_op_mode = { {"Pad", V_2020_4, {"CPU", "GPU"}}, {"Pow", V_2020_4, {"CPU", "GPU"}}, {"PRelu", V_2020_4, {"CPU", "GPU"}}, - {"QLinearMatMul", V_2022_3, {"CPU"}}, + // {"QLinearMatMul", V_2022_3, {"CPU"}}, {"QuantizeLinear", V_2021_4, {"CPU", "GPU"}}, {"QuickGelu", V_2025_0, {"CPU", "GPU"}}, {"RNN", V_2023_1, {"CPU", "GPU"}}, @@ -361,6 +363,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Clip", V_2022_1, {"All"}}); no_dimension_supported_.push_back({"Div", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"DequantizeLinear", V_2021_4, {"All"}}); + no_dimension_supported_.push_back({"DynamicQuantizeLinear", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}}); @@ -374,6 +377,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Loop", V_2021_4, {"All"}}); no_dimension_supported_.push_back({"Max", V_2024_4, {"All"}}); no_dimension_supported_.push_back({"Min", V_2020_4, {"All"}}); + no_dimension_supported_.push_back({"MatMulInteger", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); no_dimension_supported_.push_back({"Neg", V_2023_0, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Pow", V_2023_0, {"CPU", "GPU"}}); @@ -555,8 +559,13 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { return false; } + auto dtype = type_proto->tensor_type().elem_type(); + // Enable bfloat16 -> float16 on-the-fly conversion + if (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BFLOAT16 || + dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || + dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16) + return true; if (is_initializer) { - auto dtype = type_proto->tensor_type().elem_type(); for (auto const& var : supported_types_initializer_) { if ((var.first <= version_id_) && (var.second == dtype)) { @@ -571,8 +580,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { #endif return false; } else { - auto dtype = type_proto->tensor_type().elem_type(); - if (device_id_.find("HETERO") != std::string::npos || device_id_.find("MULTI") != std::string::npos || device_id_.find("AUTO") != std::string::npos) { for (auto const& var : supported_types_npu_) { @@ -609,9 +616,6 @@ bool DataOps::type_is_supported(const NodeArg* node_arg, bool is_initializer) { (var.second == dtype)) { return true; } - // experimentally for GPU and qdq stripping mode allow int16 types - if (npu_qdq_optimizer_enabled_ && (dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT16 || dtype == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT16)) - return true; } #ifndef NDEBUG if (openvino_ep::backend_utils::IsDebugEnabled()) { diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.cc b/onnxruntime/core/providers/openvino/ov_versions/utils.cc index f924fa0c8205c..791341218913f 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.cc @@ -153,6 +153,24 @@ GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector& search_cluster) { + for (auto index : search_cluster) { + auto curr_node = graph_viewer.GetNode(index); + for (auto node = curr_node->InputNodesBegin(); node != curr_node->InputNodesEnd(); ++node) { + if ((*node).Index() == curr_node_index) + return true; + } + + for (auto node = curr_node->OutputNodesBegin(); node != curr_node->OutputNodesEnd(); ++node) { + if ((*node).Index() == curr_node_index) + return true; + } + } + return false; +} + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/ov_versions/utils.h b/onnxruntime/core/providers/openvino/ov_versions/utils.h index 34aa762ba9b67..bdad047a422c1 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/utils.h +++ b/onnxruntime/core/providers/openvino/ov_versions/utils.h @@ -40,6 +40,10 @@ void IdentifyConnectedNodes( std::vector> GetConnectedClusters(const GraphViewer& graph_viewer, const std::vector>& clusters); +bool AddTrivialClusterToNextClusterIfConnected(const GraphViewer& graph_viewer, + const NodeIndex index, + const std::vector& search_cluster); + void GetInputsOutputsOfCluster(const GraphViewer& graph_viewer, const std::vector& cluster, const std::unordered_set& ng_required_initializers, diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index d159930d52845..3a39152b5d17d 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -3,6 +3,8 @@ #include "qdq_scales_fix.h" #include "core/providers/openvino/ov_protobuf_utils.h" +#include "core/framework/ort_value.h" +#include "core/framework/float16.h" #include #include @@ -903,22 +905,11 @@ Status copy_model(const GraphViewer& src_graph_viewer, } for (auto& [name, tensor_proto] : src_graph.GetAllInitializedTensors()) { - dst_graph.AddInitializedTensor(*tensor_proto); - } - - for (auto node_arg : src_graph.GetInputsIncludingInitializers()) { - auto check_inputs = [node_arg](auto input_node_arg) { - return input_node_arg->Name() == node_arg->Name(); - }; - if (std::find_if(dst_graph_inputs.begin(), dst_graph_inputs.end(), check_inputs) != dst_graph_inputs.end()) - continue; - - auto src_tensor_proto = src_graph.GetConstantInitializer(node_arg->Name(), true); - if (src_tensor_proto) { - auto dst_tensor_proto = onnx::TensorProto::Create(); - dst_tensor_proto->copy_from(src_tensor_proto); - dst_graph.AddInitializedTensor(*dst_tensor_proto); - } + auto ort_value = OrtValue(); + if (src_graph.GetOrtValueInitializer(name, ort_value)) + ORT_RETURN_IF_ERROR(dst_graph.AddInitializedOrtValue(*tensor_proto, ort_value)); + else + dst_graph.AddInitializedTensor(*tensor_proto); } ORT_RETURN_IF_ERROR(dst_graph.Resolve()); @@ -940,5 +931,54 @@ Status Transform(const GraphViewer& src_graph_viewer, return status; } } // namespace qdq_scales_fix + +namespace bfloat16_fix { +void replace_bf16_with_fp16(qdq_scales_fix::CustomGraph& gen_graph) { + for (auto& const_node : gen_graph.original_graph.Nodes()) { + auto node = const_cast(const_node); + if (node->OpType() == "Cast") { + for (auto& [name, const_attribute] : node->GetAttributes()) { + auto& attribute = const_cast(const_attribute); + if (name == "to" && attribute.type() == ONNX_NAMESPACE::AttributeProto_AttributeType_INT) + if (attribute.i() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + attribute.set_i(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + } + for (auto& output : node->OutputDefs()) { + auto& output_proto = const_cast(output->ToProto().type()); + if (output_proto.mutable_tensor_type()->elem_type() == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) + output_proto.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + } + } + + const auto& init_set = gen_graph.original_graph.GetAllInitializedTensors(); + for (auto& [key, const_tensor_proto] : init_set) { + auto tensor_proto = const_cast(const_tensor_proto); + auto dt = tensor_proto->data_type(); + if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) { + auto raw_data = tensor_proto->has_raw_data() ? reinterpret_cast(tensor_proto->mutable_raw_data()->data()) : nullptr; + if (raw_data) { + tensor_proto->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); + std::int64_t size = 1; + for (int i = 0; i < tensor_proto->dims_size(); ++i) + size *= tensor_proto->dims()[i]; + for (std::int64_t i = 0; i < size; ++i) { + raw_data[i] = onnxruntime::MLFloat16(onnxruntime::BFloat16::FromBits(raw_data[i])).val; + } + } + } + } +} + +Status Transform(const GraphViewer& src_graph_viewer, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model) { + auto status = qdq_scales_fix::copy_model(src_graph_viewer, logger, model); + auto g = qdq_scales_fix::generate_graph_from_onnx(model->MainGraph()); + + replace_bf16_with_fp16(g); + return status; +} +} // namespace bfloat16_fix } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h index c54c531e1bd40..2182850d96c43 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.h @@ -15,5 +15,10 @@ Status Transform(const GraphViewer& src_graph, const logging::Logger& logger, /*out*/ std::unique_ptr& model); } +namespace bfloat16_fix { +Status Transform(const GraphViewer& src_graph, + const logging::Logger& logger, + /*out*/ std::unique_ptr& model); +} } // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc index 24e8892622175..e010851f22e50 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc @@ -677,6 +677,27 @@ static void AddInitializerAsInput(onnxruntime::Graph& dst_graph, } } +// To check if the input parameters of a DQ or Q node are quantization parameters +// Scale and Zero point parameters are quantization parameters +static bool IsQuantizationParameter(const std::string& initializer_name, + const onnxruntime::GraphViewer& src_graph) { + // Check if this initializer is used as scale or zero_point in any DQ/Q node + for (auto& node_idx : src_graph.GetNodesInTopologicalOrder()) { + const auto* node = src_graph.GetNode(node_idx); + if (node->OpType() == "DequantizeLinear" || node->OpType() == "QuantizeLinear") { + const auto& input_defs = node->InputDefs(); + // Check if this initializer is used as scale (input 1) or zero_point (input 2) + if (input_defs.size() >= 2 && input_defs[1]->Name() == initializer_name) { + return true; // This is a scale parameter + } + if (input_defs.size() >= 3 && input_defs[2]->Name() == initializer_name) { + return true; // This is a zero_point parameter + } + } + } + return false; +} + // Creates a new model without the DQ/Q operators in the src graph. Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, const logging::Logger& logger, @@ -845,10 +866,20 @@ Status CreateModelWithStrippedQDQNodes(const GraphViewer& src_graph, if (!init_with_data && utils::HasExternalData(initializer_tensor) && enable_ovep_weight_sharing) { - insert_metadata(initializer_tensor); + // Only convert to input if it's not a quantization parameter + bool is_quant_param = IsQuantizationParameter(name, src_graph); + + if (!is_quant_param) { + // This is actual weight data - so to convert to input for weight sharing + insert_metadata(initializer_tensor); + AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); + } else { + // This is a quantization parameter - keep as initializer even if external - // Add initializer with external data as input - AddInitializerAsInput(dst_graph, accumulated_inputs, src_graph, name); + if (initializers_to_keep.count(name) > 0) { + dst_graph.AddInitializedTensor(initializer_tensor); + } + } } else { // Add as an initialized tensor if it does not have external data if (initializers_to_keep.count(name) > 0) { diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index a22375320edae..46958843872d7 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -66,7 +66,9 @@ ABSL_FLAG(std::string, i, "", " [OpenVINO only] [num_of_threads]: Overrides the accelerator hardware type and precision with these values at runtime.\n" " [OpenVINO only] [cache_dir]: Explicitly specify the path to dump and load the blobs(Model caching) or cl_cache (Kernel Caching) files feature. If blob files are already present, it will be directly loaded.\n" " [OpenVINO only] [enable_opencl_throttling]: Enables OpenCL queue throttling for GPU device(Reduces the CPU Utilization while using GPU) \n" - " [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" + " [OpenVINO only] [reshape_input]: Sets model input shapes with support for bounded dynamic dimensions using 'min..max' syntax (e.g., [1..10,3,224,224]) \n" + " [OpenVINO only] [layout]: Specifies the layout for inputs/outputs to interpret tensor dimensions correctly. \n" + " [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU num_of_threads|5 enable_opencl_throttling|true reshape_input|[1,3,60,60..100] layout|[NCHW] cache_dir|\"\"\"\n" "\n" " [QNN only] [backend_type]: QNN backend type. E.g., 'cpu', 'htp'. Mutually exclusive with 'backend_path'.\n" " [QNN only] [backend_path]: QNN backend path. E.g., '/folderpath/libQnnHtp.so', '/winfolderpath/QnnHtp.dll'. Mutually exclusive with 'backend_type'.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index f1a40b1da8651..1ba3078efdb1a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -863,12 +863,14 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); ov_options[key] = value; } else if (key == "reshape_input") { ov_options[key] = value; + } else if (key == "layout") { + ov_options[key] = value; } else { ORT_THROW( "[ERROR] [OpenVINO] wrong key type entered. Choose from the following runtime key options that are available for OpenVINO." " ['device_type', 'device_id', 'num_of_threads', 'load_config', 'cache_dir', 'num_streams', " "'enable_opencl_throttling', 'disable_dynamic_shapes', 'enable_qdq_optimizer'," - " 'enable_causallm', 'model_priority'] \n"); + " 'enable_causallm', 'reshape_input', 'layout', 'model_priority'] \n"); } } session_options.AppendExecutionProvider_OpenVINO_V2(ov_options); diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index a5fd37361a255..dc50a75873034 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -688,7 +688,7 @@ TEST(Loop, SubgraphTypeOverride) { Graph::ResolveOptions options; options.override_types = true; test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kTensorrtExecutionProvider}, &session_run_options, nullptr, + {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}, &session_run_options, nullptr, ExecutionMode::ORT_SEQUENTIAL, options); } @@ -1162,7 +1162,7 @@ TEST(Loop, SequenceAsLoopCarriedDependency) { test.AddSeqOutput("loop_var_0_final", seq_output); // Disable TensorRT on unsupported data type BOOL - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } #if !defined(DISABLE_OPTIONAL_TYPE) diff --git a/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc index f4d8cad90a714..1a71da6d95135 100644 --- a/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/dynamic_quantize_linear_test.cc @@ -11,7 +11,8 @@ namespace test { // range = [-ve, +ve] TEST(QuantizeLinearOpTest, DynamicQuantizeLinear) { // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { + if (DefaultDmlExecutionProvider().get() != nullptr || + DefaultOpenVINOExecutionProvider().get() != nullptr) { GTEST_SKIP() << "Skipping because of the following error: Expected equality of these values: 26 and 25"; } diff --git a/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc new file mode 100644 index 0000000000000..fc90563a61bb1 --- /dev/null +++ b/onnxruntime/test/providers/openvino/openvino_ep_bfloat16_pass_test.cc @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/float16.h" + +#include "test/util/include/test/test_environment.h" +#include "test/optimizer/qdq_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::logging; + +extern std::unique_ptr ort_env; + +class OVEP_BF16_Tests : public ::testing::TestWithParam {}; + +namespace detail { +auto ConstructModel() { + using namespace onnxruntime; + using namespace test; + + std::unordered_map domain_to_version; + domain_to_version[kOnnxDomain] = 19; + Model model("Bfloat16Tester", true, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, {}, DefaultLoggingManager().DefaultLogger()); + + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + auto dim = 4; + std::vector input_data(dim, 1.0f); + auto* input = builder.MakeInput({dim}, input_data); + builder.graph_.SetInputs({input}); + + auto* cast_to_bf16 = builder.MakeIntermediate(); + Node& cast_node = builder.AddNode("Cast", {input}, {cast_to_bf16}, ""); + cast_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)); + + std::vector weight_data(dim * dim); + for (std::size_t i = 0; i < weight_data.size(); ++i) + weight_data[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); + auto* weights = builder.MakeInitializer({dim, dim}, weight_data); + + auto* matmul_out = builder.MakeIntermediate(); + builder.AddNode("MatMul", {cast_to_bf16, weights}, {matmul_out}); + + std::vector weight_data_2(dim * dim); + for (std::size_t i = 0; i < weight_data_2.size(); ++i) + weight_data_2[i] = onnxruntime::BFloat16(static_cast(i % dim) / dim); + auto* weights_2 = builder.MakeInitializer({dim, dim}, weight_data_2); + + auto* matmul_out_2 = builder.MakeIntermediate(); + builder.AddNode("MatMul", {matmul_out, weights_2}, {matmul_out_2}); + + auto* output = builder.MakeOutput(); + Node& cast2_node = builder.AddNode("Cast", {matmul_out_2}, {output}); + cast2_node.AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + builder.SetGraphOutputs(); + auto st = model.MainGraph().Resolve(); + if (st != Status::OK()) + throw std::runtime_error(st.ErrorMessage()); + return model; +} + +auto ProbeDevice(const std::string& device) { + static std::map is_present; + if (is_present.find(device) == is_present.end()) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + ov_options["device_type"] = device; + try { + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + is_present[device] = true; + } catch (...) { + is_present[device] = false; + } + } + return is_present[device]; +} +} // namespace detail + +namespace onnxruntime { +namespace test { + +TEST_P(OVEP_BF16_Tests, TestModelConversion) { + Ort::SessionOptions sessionOptions; + std::unordered_map ov_options; + const auto& device = GetParam(); + if (!::detail::ProbeDevice(device)) + GTEST_SKIP() << device + " is not available on this machine"; + + ov_options["device_type"] = device; + auto model = ::detail::ConstructModel(); + sessionOptions.AppendExecutionProvider_OpenVINO_V2(ov_options); + + std::string model_data; + model.ToProto().SerializeToString(&model_data); + auto model_data_span = AsByteSpan(model_data.data(), model_data.size()); + try { + Ort::Session session(*ort_env, model_data_span.data(), model_data_span.size(), sessionOptions); + } catch (...) { + FAIL(); + } +} +INSTANTIATE_TEST_SUITE_P(OVEP_Tests, + OVEP_BF16_Tests, + ::testing::Values("CPU", "GPU", "NPU")); +} // namespace test +} // namespace onnxruntime From c9bdbd70c50dabf4d060a9e0189cc2125e955e48 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 29 Aug 2025 14:24:21 -0400 Subject: [PATCH 13/23] [java] Auto EP and compile model support (#25131) ### Description Java API for compile model and EP discovery APIs. Roughly equivalent to the C# version in #24604. cc: @skottmckay. I haven't quite got the CMake configured so the Java tests for the ep registration only run when the ONNX Runtime shared provider support is built, but everything else works. I expect that to be a quick fix, but I'm not sure in what conditions it should be built and how we should handle it so I don't know where/when to plumb it through. ### Motivation and Context API parity for Java. --- cmake/onnxruntime_java.cmake | 4 +- cmake/onnxruntime_unittests.cmake | 4 + .../main/java/ai/onnxruntime/OnnxRuntime.java | 18 +- .../java/ai/onnxruntime/OrtEnvironment.java | 82 ++++- .../main/java/ai/onnxruntime/OrtEpDevice.java | 117 ++++++++ .../onnxruntime/{providers => }/OrtFlags.java | 4 +- .../ai/onnxruntime/OrtHardwareDevice.java | 156 ++++++++++ .../OrtModelCompilationOptions.java | 280 ++++++++++++++++++ .../main/java/ai/onnxruntime/OrtSession.java | 78 +++-- .../src/main/java/ai/onnxruntime/OrtUtil.java | 51 +++- .../ai/onnxruntime/providers/CoreMLFlags.java | 4 +- .../ai/onnxruntime/providers/NNAPIFlags.java | 4 +- java/src/main/native/OrtJniUtil.c | 30 ++ java/src/main/native/OrtJniUtil.h | 2 + .../main/native/ai_onnxruntime_OnnxRuntime.c | 13 + .../native/ai_onnxruntime_OrtEnvironment.c | 70 +++++ .../main/native/ai_onnxruntime_OrtEpDevice.c | 82 +++++ .../native/ai_onnxruntime_OrtHardwareDevice.c | 96 ++++++ ...i_onnxruntime_OrtModelCompilationOptions.c | 193 ++++++++++++ ...ai_onnxruntime_OrtSession_SessionOptions.c | 53 +++- .../java/ai/onnxruntime/CompileApiTest.java | 53 ++++ .../java/ai/onnxruntime/EpDeviceTest.java | 123 ++++++++ 22 files changed, 1484 insertions(+), 33 deletions(-) create mode 100644 java/src/main/java/ai/onnxruntime/OrtEpDevice.java rename java/src/main/java/ai/onnxruntime/{providers => }/OrtFlags.java (88%) create mode 100644 java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java create mode 100644 java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java create mode 100644 java/src/main/native/ai_onnxruntime_OrtEpDevice.c create mode 100644 java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c create mode 100644 java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c create mode 100644 java/src/test/java/ai/onnxruntime/CompileApiTest.java create mode 100644 java/src/test/java/ai/onnxruntime/EpDeviceTest.java diff --git a/cmake/onnxruntime_java.cmake b/cmake/onnxruntime_java.cmake index 6b638b3e5d8bc..7da63b523be70 100644 --- a/cmake/onnxruntime_java.cmake +++ b/cmake/onnxruntime_java.cmake @@ -159,7 +159,7 @@ if (WIN32) if(NOT onnxruntime_ENABLE_STATIC_ANALYSIS) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) + if (TARGET onnxruntime_providers_shared) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) @@ -207,7 +207,7 @@ if (WIN32) else() add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_JNI_DIR}/$) - if (onnxruntime_USE_CUDA OR onnxruntime_USE_DNNL OR onnxruntime_USE_OPENVINO OR onnxruntime_USE_TENSORRT OR (onnxruntime_USE_QNN AND NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB)) + if (TARGET onnxruntime_providers_shared) add_custom_command(TARGET onnxruntime4j_jni POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_PACKAGE_LIB_DIR}/$) endif() if (onnxruntime_USE_CUDA) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 6847db64004ca..b31849440c426 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1640,6 +1640,10 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR}) endif() + if (WIN32) + set(EXAMPLE_PLUGIN_EP_DST_FILE_NAME $,$,$>) + add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $ ${JAVA_NATIVE_TEST_DIR}/${EXAMPLE_PLUGIN_EP_DST_FILE_NAME}) + endif() # delegate to gradle's test runner diff --git a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java index 97423ffb37251..3bb61698f5da7 100644 --- a/java/src/main/java/ai/onnxruntime/OnnxRuntime.java +++ b/java/src/main/java/ai/onnxruntime/OnnxRuntime.java @@ -42,6 +42,8 @@ final class OnnxRuntime { private static final int ORT_API_VERSION_13 = 13; // Post 1.13 builds of the ORT API private static final int ORT_API_VERSION_14 = 14; + // Post 1.22 builds of the ORT API + private static final int ORT_API_VERSION_23 = 23; // The initial release of the ORT training API. private static final int ORT_TRAINING_API_VERSION_1 = 1; @@ -103,6 +105,9 @@ final class OnnxRuntime { /** The Training API handle. */ static long ortTrainingApiHandle; + /** The Compile API handle. */ + static long ortCompileApiHandle; + /** Is training enabled in the native library */ static boolean trainingEnabled; @@ -176,12 +181,13 @@ static synchronized void init() throws IOException { } load(ONNXRUNTIME_JNI_LIBRARY_NAME); - ortApiHandle = initialiseAPIBase(ORT_API_VERSION_14); + ortApiHandle = initialiseAPIBase(ORT_API_VERSION_23); if (ortApiHandle == 0L) { throw new IllegalStateException( "There is a mismatch between the ORT class files and the ORT native library, and the native library could not be loaded"); } - ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_14); + ortTrainingApiHandle = initialiseTrainingAPIBase(ortApiHandle, ORT_API_VERSION_23); + ortCompileApiHandle = initialiseCompileAPIBase(ortApiHandle); trainingEnabled = ortTrainingApiHandle != 0L; providers = initialiseProviders(ortApiHandle); version = initialiseVersion(); @@ -499,6 +505,14 @@ private static EnumSet initialiseProviders(long ortApiHandle) { */ private static native long initialiseTrainingAPIBase(long apiHandle, int apiVersionNumber); + /** + * Get a reference to the compile API struct. + * + * @param apiHandle The ORT API struct pointer. + * @return A pointer to the compile API struct. + */ + private static native long initialiseCompileAPIBase(long apiHandle); + /** * Gets the array of available providers. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 8382ef06e26e5..497772baf5357 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -8,7 +8,11 @@ import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; import java.util.EnumSet; +import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.logging.Logger; @@ -442,6 +446,48 @@ public static EnumSet getAvailableProviders() { return OnnxRuntime.providers.clone(); } + /** + * Registers an execution provider library with this OrtEnvironment. + * + * @param registrationName The name to register the library with (used to remove it later with + * {@link #unregisterExecutionProviderLibrary(String)}). + * @param libraryPath The path to the library binary on disk. + * @throws OrtException If the library could not be registered. + */ + public void registerExecutionProviderLibrary(String registrationName, String libraryPath) + throws OrtException { + registerExecutionProviderLibrary( + OnnxRuntime.ortApiHandle, nativeHandle, registrationName, libraryPath); + } + + /** + * Unregisters an execution provider library from this OrtEnvironment. + * + * @param registrationName The name the library was registered under. + * @throws OrtException If the library could not be removed. + */ + public void unregisterExecutionProviderLibrary(String registrationName) throws OrtException { + unregisterExecutionProviderLibrary(OnnxRuntime.ortApiHandle, nativeHandle, registrationName); + } + + /** + * Get the list of all execution provider and device combinations that are available. + * + * @see OrtSession.SessionOptions#addExecutionProvider(List, Map) + * @return The list of execution provider and device combinations. + * @throws OrtException If the devices could not be listed. + */ + public List getEpDevices() throws OrtException { + long[] deviceHandles = getEpDevices(OnnxRuntime.ortApiHandle, nativeHandle); + + List devicesList = new ArrayList<>(); + for (long deviceHandle : deviceHandles) { + devicesList.add(new OrtEpDevice(deviceHandle)); + } + + return Collections.unmodifiableList(devicesList); + } + /** * Creates the native object. * @@ -476,6 +522,40 @@ private static native long createHandle( */ private static native long getDefaultAllocator(long apiHandle) throws OrtException; + /** + * Registers the specified execution provider with this OrtEnvironment. + * + * @param apiHandle The API handle. + * @param nativeHandle The OrtEnvironment handle. + * @param registrationName The name of the execution provider. + * @param libraryPath The path to the execution provider binary. + * @throws OrtException If the registration failed. + */ + private static native void registerExecutionProviderLibrary( + long apiHandle, long nativeHandle, String registrationName, String libraryPath) + throws OrtException; + + /** + * Removes the specified execution provider from this OrtEnvironment. + * + * @param apiHandle The API handle. + * @param nativeHandle The OrtEnvironment handle. + * @param registrationName The name of the execution provider. + * @throws OrtException If the removal failed. + */ + private static native void unregisterExecutionProviderLibrary( + long apiHandle, long nativeHandle, String registrationName) throws OrtException; + + /** + * Gets handles for the EP device tuples available in this OrtEnvironment. + * + * @param apiHandle The API handle to use. + * @param nativeHandle The OrtEnvironment handle. + * @return An array of OrtEpDevice handles. + * @throws OrtException If the call failed. + */ + private static native long[] getEpDevices(long apiHandle, long nativeHandle) throws OrtException; + /** * Closes the OrtEnvironment, frees the handle. * diff --git a/java/src/main/java/ai/onnxruntime/OrtEpDevice.java b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java new file mode 100644 index 0000000000000..f63dec1dbaf83 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtEpDevice.java @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.Map; + +/** A tuple of Execution Provider information and the hardware device. */ +public final class OrtEpDevice { + + private final long nativeHandle; + + private final String epName; + private final String epVendor; + private final Map epMetadata; + private final Map epOptions; + private final OrtHardwareDevice device; + + /** + * Construct an OrtEpDevice tuple from the native pointer. + * + * @param nativeHandle The native pointer. + */ + OrtEpDevice(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.epName = getName(OnnxRuntime.ortApiHandle, nativeHandle); + this.epVendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.epMetadata = OrtUtil.convertToMap(metadata); + String[][] options = getOptions(OnnxRuntime.ortApiHandle, nativeHandle); + this.epOptions = OrtUtil.convertToMap(options); + this.device = new OrtHardwareDevice(getDeviceHandle(OnnxRuntime.ortApiHandle, nativeHandle)); + } + + /** + * Return the native pointer. + * + * @return The native pointer. + */ + long getNativeHandle() { + return nativeHandle; + } + + /** + * Gets the EP name. + * + * @return The EP name. + */ + public String getName() { + return epName; + } + + /** + * Gets the vendor name. + * + * @return The vendor name. + */ + public String getVendor() { + return epVendor; + } + + /** + * Gets an unmodifiable view on the EP metadata. + * + * @return The EP metadata. + */ + public Map getMetadata() { + return epMetadata; + } + + /** + * Gets an unmodifiable view on the EP options. + * + * @return The EP options. + */ + public Map getOptions() { + return epOptions; + } + + /** + * Gets the device information. + * + * @return The device information. + */ + public OrtHardwareDevice getDevice() { + return device; + } + + @Override + public String toString() { + return "OrtEpDevice{" + + "epName='" + + epName + + '\'' + + ", epVendor='" + + epVendor + + '\'' + + ", epMetadata=" + + epMetadata + + ", epOptions=" + + epOptions + + ", device=" + + device + + '}'; + } + + private static native String getName(long apiHandle, long nativeHandle); + + private static native String getVendor(long apiHandle, long nativeHandle); + + private static native String[][] getMetadata(long apiHandle, long nativeHandle); + + private static native String[][] getOptions(long apiHandle, long nativeHandle); + + private static native long getDeviceHandle(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java b/java/src/main/java/ai/onnxruntime/OrtFlags.java similarity index 88% rename from java/src/main/java/ai/onnxruntime/providers/OrtFlags.java rename to java/src/main/java/ai/onnxruntime/OrtFlags.java index 73d3eeae6499c..f57fd945dbeec 100644 --- a/java/src/main/java/ai/onnxruntime/providers/OrtFlags.java +++ b/java/src/main/java/ai/onnxruntime/OrtFlags.java @@ -1,8 +1,8 @@ /* - * Copyright (c) 2021, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ -package ai.onnxruntime.providers; +package ai.onnxruntime; import java.util.EnumSet; diff --git a/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java new file mode 100644 index 0000000000000..bd99f5599fd14 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtHardwareDevice.java @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.util.Map; +import java.util.logging.Logger; + +/** Hardware information for a specific device. */ +public final class OrtHardwareDevice { + + /** The hardware device types. */ + // Must be updated in concert with the native OrtHardwareDeviceType enum in the C API + public enum OrtHardwareDeviceType { + /** A CPU device. */ + CPU(0), + /** A GPU device. */ + GPU(1), + /** A NPU (Neural Processing Unit) device. */ + NPU(2); + private final int value; + + private static final Logger logger = Logger.getLogger(OrtHardwareDeviceType.class.getName()); + private static final OrtHardwareDeviceType[] values = new OrtHardwareDeviceType[3]; + + static { + for (OrtHardwareDeviceType ot : OrtHardwareDeviceType.values()) { + values[ot.value] = ot; + } + } + + OrtHardwareDeviceType(int value) { + this.value = value; + } + + /** + * Gets the native value associated with this device type. + * + * @return The native value. + */ + public int getValue() { + return value; + } + + /** + * Maps from the C API's int enum to the Java enum. + * + * @param deviceType The index of the Java enum. + * @return The Java enum. + */ + public static OrtHardwareDeviceType mapFromInt(int deviceType) { + if ((deviceType >= 0) && (deviceType < values.length)) { + return values[deviceType]; + } else { + logger.warning("Unknown device type '" + deviceType + "' setting to CPU"); + return CPU; + } + } + } + + private final long nativeHandle; + + private final OrtHardwareDeviceType type; + private final int vendorId; + private final String vendor; + private final int deviceId; + private final Map metadata; + + OrtHardwareDevice(long nativeHandle) { + this.nativeHandle = nativeHandle; + this.type = + OrtHardwareDeviceType.mapFromInt(getDeviceType(OnnxRuntime.ortApiHandle, nativeHandle)); + this.vendorId = getVendorId(OnnxRuntime.ortApiHandle, nativeHandle); + this.vendor = getVendor(OnnxRuntime.ortApiHandle, nativeHandle); + this.deviceId = getDeviceId(OnnxRuntime.ortApiHandle, nativeHandle); + String[][] metadata = getMetadata(OnnxRuntime.ortApiHandle, nativeHandle); + this.metadata = OrtUtil.convertToMap(metadata); + } + + long getNativeHandle() { + return nativeHandle; + } + + /** + * Gets the device type. + * + * @return The device type. + */ + public OrtHardwareDeviceType getType() { + return type; + } + + /** + * Gets the vendor ID number. + * + * @return The vendor ID number. + */ + public int getVendorId() { + return vendorId; + } + + /** + * Gets the device ID number. + * + * @return The device ID number. + */ + public int getDeviceId() { + return deviceId; + } + + /** + * Gets an unmodifiable view on the device metadata. + * + * @return The device metadata. + */ + public Map getMetadata() { + return metadata; + } + + /** + * Gets the vendor name. + * + * @return The vendor name. + */ + public String getVendor() { + return vendor; + } + + @Override + public String toString() { + return "OrtHardwareDevice{" + + "type=" + + type + + ", vendorId=" + + vendorId + + ", vendor='" + + vendor + + '\'' + + ", deviceId=" + + deviceId + + ", metadata=" + + metadata + + '}'; + } + + private static native String getVendor(long apiHandle, long nativeHandle); + + private static native String[][] getMetadata(long apiHandle, long nativeHandle); + + private static native int getDeviceType(long apiHandle, long nativeHandle); + + private static native int getDeviceId(long apiHandle, long nativeHandle); + + private static native int getVendorId(long apiHandle, long nativeHandle); +} diff --git a/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java new file mode 100644 index 0000000000000..09b3064b72b93 --- /dev/null +++ b/java/src/main/java/ai/onnxruntime/OrtModelCompilationOptions.java @@ -0,0 +1,280 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import java.nio.ByteBuffer; +import java.util.EnumSet; + +/** Configuration options for compiling ONNX models. */ +public final class OrtModelCompilationOptions implements AutoCloseable { + /** Flags representing options when compiling a model. */ + public enum OrtCompileApiFlags implements OrtFlags { + /** Default. Do not enable any additional compilation options. */ + NONE(0), + + /** + * Force compilation to return an error (ORT_FAIL) if no nodes were compiled. Otherwise, a model + * with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. + */ + ERROR_IF_NO_NODES_COMPILED(1), + + /** + * Force compilation to return an error (ORT_FAIL) if a file with the same filename as the + * output model exists. Otherwise, compilation will automatically overwrite the output file if + * it exists. + */ + ERROR_IF_OUTPUT_FILE_EXISTS(1 << 1); + + /** The native value of the enum. */ + public final int value; + + OrtCompileApiFlags(int value) { + this.value = value; + } + + @Override + public int getValue() { + return value; + } + } + + private final long nativeHandle; + private boolean closed = false; + + // Used to ensure the byte buffer doesn't get GC'd before the model is compiled. + private ByteBuffer buffer; + + OrtModelCompilationOptions(long nativeHandle) { + this.nativeHandle = nativeHandle; + } + + /** + * Creates a model compilation options from an existing SessionOptions. + * + *

An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX + * model. The OrtSessionOptions object has the execution providers with which the model will be + * compiled. + * + * @param env The OrtEnvironment. + * @param sessionOptions The session options to use. + * @return A constructed model compilation options instance. + * @throws OrtException If the construction failed. + */ + public static OrtModelCompilationOptions createFromSessionOptions( + OrtEnvironment env, OrtSession.SessionOptions sessionOptions) throws OrtException { + long handle = + createFromSessionOptions( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + env.getNativeHandle(), + sessionOptions.getNativeHandle()); + return new OrtModelCompilationOptions(handle); + } + + /** + * Checks if the OrtModelCompilationOptions is closed, if so throws {@link IllegalStateException}. + */ + private void checkClosed() { + if (closed) { + throw new IllegalStateException("Trying to use a closed OrtModelCompilationOptions."); + } + } + + @Override + public void close() { + if (!closed) { + close(OnnxRuntime.ortCompileApiHandle, nativeHandle); + closed = true; + } else { + throw new IllegalStateException("Trying to close a closed OrtModelCompilationOptions."); + } + } + + /** + * Sets the file path to the input ONNX model. + * + *

The input model's location must be set either to a path on disk with this method, or by + * supplying an in-memory reference with {@link #setInputModelFromBuffer}. + * + * @param inputModelPath The path to the model on disk. + * @throws OrtException If the set failed. + */ + public void setInputModelPath(String inputModelPath) throws OrtException { + checkClosed(); + setInputModelPath( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, inputModelPath); + } + + /** + * Uses the supplied buffer as the input ONNX model. + * + *

The input model's location must be set either to an in-memory reference with this method, or + * by supplying a path on disk with {@link #setInputModelPath(String)}. + * + *

If the {@link ByteBuffer} is not direct it is copied into a direct buffer. In either case + * this object holds a reference to the buffer to prevent it from being GC'd. + * + * @param inputModelBuffer The buffer. + * @throws OrtException If the buffer could not be set. + */ + public void setInputModelFromBuffer(ByteBuffer inputModelBuffer) throws OrtException { + checkClosed(); + if (!inputModelBuffer.isDirect()) { + // if it's not a direct buffer, copy it. + buffer = ByteBuffer.allocateDirect(inputModelBuffer.remaining()); + int tmpPos = inputModelBuffer.position(); + buffer.put(inputModelBuffer); + buffer.rewind(); + inputModelBuffer.position(tmpPos); + } else { + buffer = inputModelBuffer; + } + int bufferPos = buffer.position(); + int bufferRemaining = buffer.remaining(); + setInputModelFromBuffer( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + buffer, + bufferPos, + bufferRemaining); + } + + /** + * Sets the file path for the output compiled ONNX model. + * + *

If this is unset it will append `_ctx` to the file name, e.g., my_model.onnx becomes + * my_model_ctx.onnx. + * + * @param outputModelPath The output model path. + * @throws OrtException If the path could not be set. + */ + public void setOutputModelPath(String outputModelPath) throws OrtException { + checkClosed(); + setOutputModelPath( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, outputModelPath); + } + + /** + * Optionally sets the file that stores initializers for the compiled ONNX model. If unset then + * initializers are stored inside the model. + * + *

Only initializers for nodes that were not compiled are stored in the external initializers + * file. Compiled nodes contain their initializer data within the `ep_cache_context` attribute of + * EPContext nodes. + * + * @see OrtModelCompilationOptions#setEpContextEmbedMode + * @param outputExternalInitializersPath Path to the file. + * @param sizeThreshold Initializers larger than this threshold are stored in the file. + * @throws OrtException If the path could not be set. + */ + public void setOutputExternalInitializersPath( + String outputExternalInitializersPath, long sizeThreshold) throws OrtException { + checkClosed(); + // check positive + setOutputExternalInitializersPath( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + outputExternalInitializersPath, + sizeThreshold); + } + + /** + * Enables or disables the embedding of EPContext binary data into the ep_cache_context attribute + * of EPContext nodes. + * + *

Defaults to false. When enabled, the `ep_cache_context` attribute of EPContext nodes will + * store the context binary data, which may include weights for compiled subgraphs. When disabled, + * the `ep_cache_context` attribute of EPContext nodes will contain the path to the file + * containing the context binary data. The path is set by the execution provider creating the + * EPContext node. + * + *

For more details see the EPContext design + * document. + * + * @param embedEpContext True to embed EPContext binary data into the EPContext node's + * ep_cache_context attribute. + * @throws OrtException If the set operation failed. + */ + public void setEpContextEmbedMode(boolean embedEpContext) throws OrtException { + checkClosed(); + setEpContextEmbedMode( + OnnxRuntime.ortApiHandle, OnnxRuntime.ortCompileApiHandle, nativeHandle, embedEpContext); + } + + /** + * Sets the specified compilation flags. + * + * @param flags The compilation flags. + * @throws OrtException If the set operation failed. + */ + public void setCompilationFlags(EnumSet flags) throws OrtException { + checkClosed(); + setCompilationFlags( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + nativeHandle, + OrtFlags.aggregateToInt(flags)); + } + + /** + * Compiles the ONNX model with the configuration described by this instance of + * OrtModelCompilationOptions. + * + * @throws OrtException If the compilation failed. + */ + public void compileModel() throws OrtException { + checkClosed(); + // Safe as the environment must exist to create one of these objects. + OrtEnvironment env = OrtEnvironment.getEnvironment(); + compileModel( + OnnxRuntime.ortApiHandle, + OnnxRuntime.ortCompileApiHandle, + env.getNativeHandle(), + nativeHandle); + } + + private static native long createFromSessionOptions( + long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; + + private static native void close(long compileApiHandle, long nativeHandle); + + private static native void setInputModelPath( + long apiHandle, long compileApiHandle, long nativeHandle, String inputModelPath) + throws OrtException; + + private static native void setInputModelFromBuffer( + long apiHandle, + long compileApiHandle, + long nativeHandle, + ByteBuffer inputBuffer, + long bufferPos, + long bufferRemaining) + throws OrtException; + + private static native void setOutputModelPath( + long apiHandle, long compileApiHandle, long nativeHandle, String outputModelPath) + throws OrtException; + + private static native void setOutputExternalInitializersPath( + long apiHandle, + long compileApiHandle, + long nativeHandle, + String externalInitializersPath, + long sizeThreshold) + throws OrtException; + + private static native void setEpContextEmbedMode( + long apiHandle, long compileApiHandle, long nativeHandle, boolean embedEpContext) + throws OrtException; + + private static native void setCompilationFlags( + long apiHandle, long compileApiHandle, long nativeHandle, int flags) throws OrtException; + + private static native void compileModel( + long apiHandle, long compileApiHandle, long envHandle, long nativeHandle) throws OrtException; +} diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index a399d5080ca16..42dc90b71cb80 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates * Licensed under the MIT License. */ @@ -8,7 +8,6 @@ import ai.onnxruntime.providers.CoreMLFlags; import ai.onnxruntime.providers.NNAPIFlags; import ai.onnxruntime.providers.OrtCUDAProviderOptions; -import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; import java.nio.ByteBuffer; @@ -624,6 +623,10 @@ private native OnnxModelMetadata constructMetadata( *

Used to set the number of threads, optimisation level, computation backend and other * options. * + *

The order execution providers are added to an options instance is the order they will be + * considered for op node assignment, with the EP added first having priority. The CPU EP is a + * fallback and added by default. + * *

Modifying this after the session has been constructed will have no effect. * *

The SessionOptions object must not be closed until all sessions which use it are closed, as @@ -730,7 +733,7 @@ public SessionOptions() { @Override public void close() { if (!closed) { - if (customLibraryHandles.size() > 0) { + if (!customLibraryHandles.isEmpty()) { long[] longArray = new long[customLibraryHandles.size()]; for (int i = 0; i < customLibraryHandles.size(); i++) { longArray[i] = customLibraryHandles.get(i); @@ -917,10 +920,10 @@ public void registerCustomOpLibrary(String path) throws OrtException { * *

 OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); * - *

See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for more - * information on custom ops. See - * https://github.com/microsoft/onnxruntime/blob/342a5bf2b756d1a1fc6fdc582cfeac15182632fe/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc#L115 - * for an example of a custom op library registration function. + *

See Add + * Custom Op for more information on custom ops. See an example of a custom op library + * registration function here. * * @param registrationFuncName The name of the registration function to call. * @throws OrtException If there was an error finding or calling the registration function. @@ -1273,10 +1276,47 @@ public void addCoreML(EnumSet flags) throws OrtException { addCoreML(OnnxRuntime.ortApiHandle, nativeHandle, OrtFlags.aggregateToInt(flags)); } + /** + * Adds the specified execution provider and device tuples as an execution backend. + * + *

Execution provider priority is in the order added, i.e., the first provider added to a + * session options will be used first for op node assignment. + * + * @param devices The EP and device tuples. Each element must use the same EP, though they can + * use different devices. + * @param providerOptions Configuration options for the execution provider. Refer to the + * specific execution provider's documentation. + * @throws OrtException If there was an error in native code. + */ + public void addExecutionProvider(List devices, Map providerOptions) + throws OrtException { + checkClosed(); + if (devices.isEmpty()) { + throw new IllegalArgumentException("Must supply at least one OrtEpDevice"); + } + long[] deviceHandles = new long[devices.size()]; + for (int i = 0; i < devices.size(); i++) { + deviceHandles[i] = devices.get(i).getNativeHandle(); + } + String[][] optsArray = OrtUtil.unpackMap(providerOptions); + // This is valid as the environment must have been created to create the OrtEpDevice list. + long envHandle = OrtEnvironment.getEnvironment().getNativeHandle(); + addExecutionProvider( + OnnxRuntime.ortApiHandle, + envHandle, + nativeHandle, + deviceHandles, + optsArray[0], + optsArray[1]); + } + /** * Adds the named execution provider (backend) as an execution backend. This generic function * only allows a subset of execution providers. * + *

Execution provider priority is in the order added, i.e., the first provider added to a + * session options will be used first for op node assignment. + * * @param providerName The name of the execution provider. * @param providerOptions Configuration options for the execution provider. Refer to the * specific execution provider's documentation. @@ -1285,20 +1325,9 @@ public void addCoreML(EnumSet flags) throws OrtException { private void addExecutionProvider(String providerName, Map providerOptions) throws OrtException { checkClosed(); - String[] providerOptionKey = new String[providerOptions.size()]; - String[] providerOptionVal = new String[providerOptions.size()]; - int i = 0; - for (Map.Entry entry : providerOptions.entrySet()) { - providerOptionKey[i] = entry.getKey(); - providerOptionVal[i] = entry.getValue(); - i++; - } + String[][] optsArray = OrtUtil.unpackMap(providerOptions); addExecutionProvider( - OnnxRuntime.ortApiHandle, - nativeHandle, - providerName, - providerOptionKey, - providerOptionVal); + OnnxRuntime.ortApiHandle, nativeHandle, providerName, optsArray[0], optsArray[1]); } /** @@ -1484,6 +1513,15 @@ private native void addExecutionProvider( String[] providerOptionKey, String[] providerOptionVal) throws OrtException; + + private native void addExecutionProvider( + long apiHandle, + long envHandle, + long nativeHandle, + long[] deviceHandles, + String[] providerOptionKey, + String[] providerOptionVal) + throws OrtException; } /** Used to control logging and termination of a call to {@link OrtSession#run}. */ diff --git a/java/src/main/java/ai/onnxruntime/OrtUtil.java b/java/src/main/java/ai/onnxruntime/OrtUtil.java index 2f44236e4ef67..ee91fdb292baa 100644 --- a/java/src/main/java/ai/onnxruntime/OrtUtil.java +++ b/java/src/main/java/ai/onnxruntime/OrtUtil.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2025, Oracle and/or its affiliates. All rights reserved. * Copyright (c) Microsoft Corporation. All rights reserved. * Licensed under the MIT License. */ @@ -16,6 +16,9 @@ import java.nio.ShortBuffer; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.logging.Logger; /** Util code for interacting with Java arrays. */ @@ -370,6 +373,52 @@ public static boolean validateShape(long[] shape) { return valid && shape.length <= TensorInfo.MAX_DIMENSIONS; } + /** + * Converts the output of a OrtKeyValuePairs into a Java unmodifiable HashMap. + * + * @param zippedString The zipped keys and values. + * @return An unmodifiable Map. + */ + static Map convertToMap(String[][] zippedString) { + if (zippedString.length != 2) { + throw new IllegalArgumentException("Invalid zipped string, must have two arrays."); + } else if (zippedString[0].length != zippedString[1].length) { + throw new IllegalArgumentException( + "Invalid zipped string, must have two arrays of the same length."); + } + Map map = new HashMap<>(capacityFromSize(zippedString[0].length)); + for (int i = 0; i < zippedString[0].length; i++) { + map.put(zippedString[0][i], zippedString[1][i]); + } + return Collections.unmodifiableMap(map); + } + + /** + * Converts a Java string map into a pair of arrays suitable for constructing a native + * OrtKeyValuePairs object. + * + * @param map A map from string to string, with no null keys or values. + * @return A pair of String arrays. + */ + static String[][] unpackMap(Map map) { + String[] keys = new String[map.size()]; + String[] values = new String[map.size()]; + int i = 0; + for (Map.Entry entry : map.entrySet()) { + if (entry.getKey() == null || entry.getValue() == null) { + throw new IllegalArgumentException( + "Invalid map, keys and values must not be null, found key = " + + entry.getKey() + + ", value = " + + entry.getValue()); + } + keys[i] = entry.getKey(); + values[i] = entry.getValue(); + i++; + } + return new String[][] {keys, values}; + } + /** * Flatten a multidimensional String array into a single dimensional String array, reading it in a * multidimensional row-major order. diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index 22bf940844774..15fe459dad7c8 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; +import ai.onnxruntime.OrtFlags; + /** Flags for the CoreML provider. */ public enum CoreMLFlags implements OrtFlags { /** diff --git a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java index eeaf6cc8d53bc..dd30684078717 100644 --- a/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/NNAPIFlags.java @@ -1,9 +1,11 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; +import ai.onnxruntime.OrtFlags; + /** Flags for the NNAPI provider. */ public enum NNAPIFlags implements OrtFlags { /** Enables fp16 support. */ diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index 5d8efd7b476cb..96ea8e79bc978 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1014,6 +1014,36 @@ jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAlloca } } +jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp) { + // extract pair arrays + const char* const* keys = NULL; + const char* const* values = NULL; + size_t numKeys = 0; + api->GetKeyValuePairs(kvp, &keys, &values, &numKeys); + jsize jNumKeys = safecast_size_t_to_jsize(numKeys); + + // create Java String[] + jclass stringClazz = (*jniEnv)->FindClass(jniEnv, "java/lang/String"); + jobjectArray keyArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); + jobjectArray valueArray = (*jniEnv)->NewObjectArray(jniEnv, jNumKeys, stringClazz, NULL); + + // populate Java arrays + for (jsize i = 0; i < jNumKeys; i++) { + jstring key = (*jniEnv)->NewStringUTF(jniEnv, keys[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, keyArray, i, key); + jstring value = (*jniEnv)->NewStringUTF(jniEnv, values[i]); + (*jniEnv)->SetObjectArrayElement(jniEnv, valueArray, i, value); + } + + // create Java String[][] + jclass stringArrClazz = (*jniEnv)->GetObjectClass(jniEnv, keyArray); + jobjectArray pair = (*jniEnv)->NewObjectArray(jniEnv, 2, stringArrClazz, 0); + (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 0, keyArray); + (*jniEnv)->SetObjectArrayElement(jniEnv, pair, 1, valueArray); + + return pair; +} + jint throwOrtException(JNIEnv *jniEnv, int messageId, const char *message) { jstring messageStr = (*jniEnv)->NewStringUTF(jniEnv, message); diff --git a/java/src/main/native/OrtJniUtil.h b/java/src/main/native/OrtJniUtil.h index 7f41e06371f2a..040fd41264c10 100644 --- a/java/src/main/native/OrtJniUtil.h +++ b/java/src/main/native/OrtJniUtil.h @@ -78,6 +78,8 @@ jobject createMapInfoFromValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator jobject convertOrtValueToONNXValue(JNIEnv *jniEnv, const OrtApi * api, OrtAllocator* allocator, OrtValue* onnxValue); +jobjectArray convertOrtKeyValuePairsToArrays(JNIEnv *jniEnv, const OrtApi * api, const OrtKeyValuePairs * kvp); + jint throwOrtException(JNIEnv *env, int messageId, const char *message); jint convertErrorCode(OrtErrorCode code); diff --git a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c index 659f34e1fb66f..d8f5f1a3cb2db 100644 --- a/java/src/main/native/ai_onnxruntime_OnnxRuntime.c +++ b/java/src/main/native/ai_onnxruntime_OnnxRuntime.c @@ -32,6 +32,19 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseTrainingAPIBas return (jlong) trainingApi; } +/* + * Class: ai_onnxruntime_OnnxRuntime + * Method: initialiseCompileAPIBase + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OnnxRuntime_initialiseCompileAPIBase + (JNIEnv * jniEnv, jclass clazz, jlong apiHandle) { + (void)jniEnv; (void)clazz; // required JNI parameters not needed by functions which don't call back into Java. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = api->GetCompileApi(); + return (jlong) compileApi; +} + /* * Class: ai_onnxruntime_OnnxRuntime * Method: getAvailableProviders diff --git a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c index e1b1ff1c05fe1..77b096d62ec76 100644 --- a/java/src/main/native/ai_onnxruntime_OrtEnvironment.c +++ b/java/src/main/native/ai_onnxruntime_OrtEnvironment.c @@ -60,6 +60,76 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEnvironment_getDefaultAllocator return (jlong)allocator; } +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: registerExecutionProviderLibrary + * Signature: (JJLjava/lang/String;Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_registerExecutionProviderLibrary + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name, jstring libraryPath) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, libraryPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, libraryPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, libraryPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, libraryPath, NULL); + checkOrtStatus(jniEnv, api, api->RegisterExecutionProviderLibrary(env, cName, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, libraryPath, cPath); +#endif + (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); +} + +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: unregisterExecutionProviderLibrary + * Signature: (JJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtEnvironment_unregisterExecutionProviderLibrary + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle, jstring name) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + const char* cName = (*jniEnv)->GetStringUTFChars(jniEnv, name, NULL); + checkOrtStatus(jniEnv, api, api->UnregisterExecutionProviderLibrary(env, cName)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, name, cName); +} + +/* + * Class: ai_onnxruntime_OrtEnvironment + * Method: getEpDevices + * Signature: (JJ)[J + */ +JNIEXPORT jlongArray JNICALL Java_ai_onnxruntime_OrtEnvironment_getEpDevices + (JNIEnv * jniEnv, jclass jobj, jlong apiHandle, jlong nativeHandle) { + (void) jobj; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEnv* env = (OrtEnv*) nativeHandle; + size_t numDevices = 0; + const OrtEpDevice* const* devicesArr = NULL; + OrtErrorCode code = checkOrtStatus(jniEnv, api, api->GetEpDevices(env, &devicesArr, &numDevices)); + if (code != ORT_OK) { + return NULL; + } else { + jsize numDevicesInt = safecast_size_t_to_jsize(numDevices); + jlongArray outputArr = (*jniEnv)->NewLongArray(jniEnv, numDevicesInt); + (*jniEnv)->SetLongArrayRegion(jniEnv, outputArr, 0, numDevicesInt, (jlong*)devicesArr); + return outputArr; + } +} + /* * Class: ai_onnxruntime_OrtEnvironment * Method: close diff --git a/java/src/main/native/ai_onnxruntime_OrtEpDevice.c b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c new file mode 100644 index 0000000000000..5a1e3092b0fb9 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtEpDevice.c @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtEpDevice.h" + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getName + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getName + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const char* name = api->EpDevice_EpName(epDevice); + jstring nameStr = (*jniEnv)->NewStringUTF(jniEnv, name); + return nameStr; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getVendor + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtEpDevice_getVendor + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const char* vendor = api->EpDevice_EpVendor(epDevice); + jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); + return vendorStr; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getMetadata + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getMetadata + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->EpDevice_EpMetadata(epDevice); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getOptions + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtEpDevice_getOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->EpDevice_EpOptions(epDevice); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtEpDevice + * Method: getDeviceHandle + * Signature: (JJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtEpDevice_getDeviceHandle + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtEpDevice* epDevice = (OrtEpDevice*) nativeHandle; + const OrtHardwareDevice* device = api->EpDevice_Device(epDevice); + return (jlong) device; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c new file mode 100644 index 0000000000000..3191a89c26ba1 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtHardwareDevice.c @@ -0,0 +1,96 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtHardwareDevice.h" + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getVendor + * Signature: (JJ)Ljava/lang/String; + */ +JNIEXPORT jstring JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendor + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + const char* vendor = api->HardwareDevice_Vendor(device); + jstring vendorStr = (*jniEnv)->NewStringUTF(jniEnv, vendor); + return vendorStr; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getMetadata + * Signature: (JJ)[[Ljava/lang/String; + */ +JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getMetadata + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + const OrtKeyValuePairs* kvp = api->HardwareDevice_Metadata(device); + jobjectArray pair = convertOrtKeyValuePairsToArrays(jniEnv, api, kvp); + return pair; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getDeviceType + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceType + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + OrtHardwareDeviceType type = api->HardwareDevice_Type(device); + jint output = 0; + // Must be kept aligned with the Java OrtHardwareDeviceType enum. + switch (type) { + case OrtHardwareDeviceType_CPU: + output = 0; + break; + case OrtHardwareDeviceType_GPU: + output = 1; + break; + case OrtHardwareDeviceType_NPU: + output = 2; + break; + default: + throwOrtException(jniEnv, convertErrorCode(ORT_NOT_IMPLEMENTED), "Unexpected device type found. Only CPU, GPU and NPU are supported."); + break; + } + return output; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getDeviceId + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getDeviceId + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + uint32_t id = api->HardwareDevice_DeviceId(device); + return (jint) id; +} + +/* + * Class: ai_onnxruntime_OrtHardwareDevice + * Method: getVendorId + * Signature: (JJ)I + */ +JNIEXPORT jint JNICALL Java_ai_onnxruntime_OrtHardwareDevice_getVendorId + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong nativeHandle) { + (void) jniEnv; (void) jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtApi* api = (const OrtApi*) apiHandle; + OrtHardwareDevice* device = (OrtHardwareDevice*) nativeHandle; + uint32_t id = api->HardwareDevice_VendorId(device); + return (jint) id; +} diff --git a/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c new file mode 100644 index 0000000000000..4f79383d09766 --- /dev/null +++ b/java/src/main/native/ai_onnxruntime_OrtModelCompilationOptions.c @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2025 Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +#include +#include "onnxruntime/core/session/onnxruntime_c_api.h" +#include "OrtJniUtil.h" +#include "ai_onnxruntime_OrtModelCompilationOptions.h" + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: createFromSessionOptions + * Signature: (JJJJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_createFromSessionOptions + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong sessionOptionsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + const OrtEnv* env = (const OrtEnv*)envHandle; + const OrtSessionOptions* sessionOptions = (const OrtSessionOptions*) sessionOptionsHandle; + OrtModelCompilationOptions* output = NULL; + checkOrtStatus(jniEnv, api, compileApi->CreateModelCompilationOptionsFromSessionOptions(env, sessionOptions, &output)); + return (jlong) output; +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: close + * Signature: (JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_close + (JNIEnv * jniEnv, jclass jclazz, jlong compileApiHandle, jlong nativeHandle) { + (void)jniEnv; (void)jclazz; // Required JNI parameters not needed by functions which don't need to access their host object or the JVM. + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + compileApi->ReleaseModelCompilationOptions((OrtModelCompilationOptions *)nativeHandle); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setInputModelPath + * Signature: (JJJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring modelPath) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, modelPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, modelPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, modelPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, modelPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelPath(compOpts, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, modelPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setInputModelFromBuffer + * Signature: (JJJLjava/nio/ByteBuffer;JJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setInputModelFromBuffer + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jobject buffer, jlong bufferPos, jlong bufferRemaining) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + // Cast to pointers + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetInputModelFromBuffer(compOpts, bufferArr, bufferRemaining)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setOutputModelPath + * Signature: (JJJLjava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputModelPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring outputPath) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, outputPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, outputPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, newString)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, outputPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, outputPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelPath(compOpts, cPath)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, outputPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setOutputExternalInitializersPath + * Signature: (JJJLjava/lang/String;J)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setOutputExternalInitializersPath + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jstring initializersPath, jlong threshold) { + (void) jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*) apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*) compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; +#ifdef _WIN32 + const jchar* cPath = (*jniEnv)->GetStringChars(jniEnv, initializersPath, NULL); + size_t stringLength = (*jniEnv)->GetStringLength(jniEnv, initializersPath); + wchar_t* newString = (wchar_t*)calloc(stringLength + 1, sizeof(wchar_t)); + if (newString == NULL) { + (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); + throwOrtException(jniEnv, 1, "Not enough memory"); + return; + } + wcsncpy_s(newString, stringLength + 1, (const wchar_t*)cPath, stringLength); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, newString, threshold)); + free(newString); + (*jniEnv)->ReleaseStringChars(jniEnv, initializersPath, cPath); +#else + const char* cPath = (*jniEnv)->GetStringUTFChars(jniEnv, initializersPath, NULL); + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetOutputModelExternalInitializersFile(compOpts, cPath, threshold)); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, initializersPath, cPath); +#endif +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setEpContextEmbedMode + * Signature: (JJJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setEpContextEmbedMode + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jboolean embedMode) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetEpContextEmbedMode(compOpts, (bool) embedMode)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: setCompilationFlags + * Signature: (JJJI)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_setCompilationFlags + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong nativeHandle, jint flags) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->ModelCompilationOptions_SetFlags(compOpts, flags)); +} + +/* + * Class: ai_onnxruntime_OrtModelCompilationOptions + * Method: compileModel + * Signature: (JJJJ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtModelCompilationOptions_compileModel + (JNIEnv * jniEnv, jclass jclazz, jlong apiHandle, jlong compileApiHandle, jlong envHandle, jlong nativeHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + const OrtCompileApi* compileApi = (const OrtCompileApi*)compileApiHandle; + const OrtEnv* env = (const OrtEnv*)envHandle; + OrtModelCompilationOptions* compOpts = (OrtModelCompilationOptions *) nativeHandle; + checkOrtStatus(jniEnv, api, compileApi->CompileModel(env, compOpts)); +} diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff6b7fa703e6e..95bcdf7af9746 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -718,11 +718,11 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addROC } /* - * Class:: ai_onnxruntime_OrtSession_SessionOptions + * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: addExecutionProvider - * Signature: (JILjava/lang/String)V + * Signature: (JJLjava/lang/String;[Ljava/lang/String;[Ljava/lang/String;)V */ -JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider( +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJLjava_lang_String_2_3Ljava_lang_String_2_3Ljava_lang_String_2( JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jstring jepName, jobjectArray configKeyArr, jobjectArray configValueArr) { (void)jobj; @@ -756,3 +756,50 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExe free((void*)jkeyArray); free((void*)jvalueArray); } + +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: addExecutionProvider + * Signature: (JJJ[J[Ljava/lang/String;[Ljava/lang/String;)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addExecutionProvider__JJJ_3J_3Ljava_lang_String_2_3Ljava_lang_String_2 + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong envHandle, jlong optionsHandle, jlongArray deviceHandleArr, jobjectArray configKeyArr, jobjectArray configValueArr) { + (void)jobj; + + const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*) envHandle; + OrtSessionOptions* options = (OrtSessionOptions*)optionsHandle; + jsize deviceCount = (*jniEnv)->GetArrayLength(jniEnv, deviceHandleArr); + jsize keyCount = (*jniEnv)->GetArrayLength(jniEnv, configKeyArr); + + const char** keyArray = (const char**)allocarray(keyCount, sizeof(const char*)); + const char** valueArray = (const char**)allocarray(keyCount, sizeof(const char*)); + jstring* jkeyArray = (jstring*)allocarray(keyCount, sizeof(jstring)); + jstring* jvalueArray = (jstring*)allocarray(keyCount, sizeof(jstring)); + const OrtEpDevice** devicePtrs = allocarray(deviceCount, sizeof(OrtEpDevice *)); + + jlong* deviceHandleElements = (*jniEnv)->GetLongArrayElements(jniEnv, deviceHandleArr, NULL); + for (jsize i = 0; i < deviceCount; i++) { + devicePtrs[i] = (OrtEpDevice*) deviceHandleElements[i]; + } + (*jniEnv)->ReleaseLongArrayElements(jniEnv, deviceHandleArr, deviceHandleElements, JNI_ABORT); + + for (jsize i = 0; i < keyCount; i++) { + jkeyArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configKeyArr, i)); + jvalueArray[i] = (jstring)((*jniEnv)->GetObjectArrayElement(jniEnv, configValueArr, i)); + keyArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jkeyArray[i], NULL); + valueArray[i] = (*jniEnv)->GetStringUTFChars(jniEnv, jvalueArray[i], NULL); + } + + checkOrtStatus(jniEnv, api, api->SessionOptionsAppendExecutionProvider_V2(options, env, devicePtrs, deviceCount, keyArray, valueArray, keyCount)); + + for (jsize i = 0; i < keyCount; i++) { + (*jniEnv)->ReleaseStringUTFChars(jniEnv, jkeyArray[i], keyArray[i]); + (*jniEnv)->ReleaseStringUTFChars(jniEnv, jvalueArray[i], valueArray[i]); + } + free((void*)devicePtrs); + free((void*)keyArray); + free((void*)valueArray); + free((void*)jkeyArray); + free((void*)jvalueArray); +} diff --git a/java/src/test/java/ai/onnxruntime/CompileApiTest.java b/java/src/test/java/ai/onnxruntime/CompileApiTest.java new file mode 100644 index 0000000000000..b70f4dca5cbd0 --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/CompileApiTest.java @@ -0,0 +1,53 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import ai.onnxruntime.OrtSession.SessionOptions; +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +/** Test for the compilation API. */ +public class CompileApiTest { + private final OrtEnvironment env = OrtEnvironment.getEnvironment(); + + @Test + public void basicUsage() throws OrtException, IOException { + SessionOptions so = new SessionOptions(); + try (OrtModelCompilationOptions compileOptions = + OrtModelCompilationOptions.createFromSessionOptions(env, so)) { + // mainly checking these don't throw which ensures all the plumbing for the binding works. + compileOptions.setInputModelPath("model.onnx"); + compileOptions.setOutputModelPath("compiled_model.onnx"); + + compileOptions.setOutputExternalInitializersPath("external_data.bin", 512); + compileOptions.setEpContextEmbedMode(true); + } + + try (OrtModelCompilationOptions compileOptions = + OrtModelCompilationOptions.createFromSessionOptions(env, so)) { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + byte[] modelBytes = Files.readAllBytes(modelPath); + ByteBuffer modelBuffer = ByteBuffer.wrap(modelBytes); + compileOptions.setInputModelFromBuffer(modelBuffer); + compileOptions.setOutputModelPath("compiled_model.onnx"); + + File f = new File("compiled_model.onnx"); + + compileOptions.compileModel(); + + // Check the compiled model is valid + try (OrtSession session = env.createSession(f.toString(), so)) { + Assertions.assertNotNull(session); + } + + f.delete(); + } + } +} diff --git a/java/src/test/java/ai/onnxruntime/EpDeviceTest.java b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java new file mode 100644 index 0000000000000..ec4c977508c8c --- /dev/null +++ b/java/src/test/java/ai/onnxruntime/EpDeviceTest.java @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2025, Oracle and/or its affiliates. All rights reserved. + * Licensed under the MIT License. + */ +package ai.onnxruntime; + +import ai.onnxruntime.OrtHardwareDevice.OrtHardwareDeviceType; +import ai.onnxruntime.OrtSession.SessionOptions; +import java.io.File; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.condition.OS; + +/** Tests for {@link OrtEpDevice} and {@link OrtHardwareDevice}. */ +@EnabledOnOs(value = OS.WINDOWS) +public class EpDeviceTest { + private final OrtEnvironment ortEnv = OrtEnvironment.getEnvironment(); + + private void readHardwareDeviceValues(OrtHardwareDevice device) { + OrtHardwareDeviceType type = device.getType(); + + Assertions.assertTrue( + type == OrtHardwareDeviceType.CPU + || type == OrtHardwareDeviceType.GPU + || type == OrtHardwareDeviceType.NPU); + + if (type == OrtHardwareDeviceType.CPU) { + Assertions.assertFalse(device.getVendor().isEmpty()); + } else { + Assertions.assertTrue(device.getVendorId() != 0); + Assertions.assertTrue(device.getDeviceId() != 0); + } + + Map metadata = device.getMetadata(); + Assertions.assertNotNull(metadata); + for (Map.Entry kvp : metadata.entrySet()) { + Assertions.assertFalse(kvp.getKey().isEmpty()); + } + } + + @Test + public void getEpDevices() throws OrtException { + List epDevices = ortEnv.getEpDevices(); + Assertions.assertNotNull(epDevices); + Assertions.assertFalse(epDevices.isEmpty()); + for (OrtEpDevice epDevice : epDevices) { + Assertions.assertFalse(epDevice.getName().isEmpty()); + Assertions.assertFalse(epDevice.getVendor().isEmpty()); + Map metadata = epDevice.getMetadata(); + Assertions.assertNotNull(metadata); + Map options = epDevice.getOptions(); + Assertions.assertNotNull(options); + readHardwareDeviceValues(epDevice.getDevice()); + } + } + + @Test + public void registerUnregisterLibrary() throws OrtException { + String libFullPath = TestHelpers.getResourcePath("/example_plugin_ep.dll").toString(); + Assertions.assertTrue( + new File(libFullPath).exists(), "Expected lib " + libFullPath + " does not exist."); + + // example plugin ep uses the registration name as the ep name + String epName = "java_ep"; + + // register. shouldn't throw + ortEnv.registerExecutionProviderLibrary(epName, libFullPath); + + // check OrtEpDevice was found + List epDevices = ortEnv.getEpDevices(); + boolean found = epDevices.stream().anyMatch(a -> a.getName().equals(epName)); + Assertions.assertTrue(found); + + // unregister + ortEnv.unregisterExecutionProviderLibrary(epName); + } + + @Test + public void appendToSessionOptionsV2() { + Consumer>> runTest = + (Supplier> options) -> { + try (SessionOptions sessionOptions = new SessionOptions()) { + sessionOptions.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE); + + List epDevices = ortEnv.getEpDevices(); + + // cpu ep ignores the provider options so we can use any value in epOptions and it won't + // break. + List selectedEpDevices = + epDevices.stream() + .filter(a -> a.getName().equals("CPUExecutionProvider")) + .collect(Collectors.toList()); + + Map epOptions = options.get(); + sessionOptions.addExecutionProvider(selectedEpDevices, epOptions); + + Path model = TestHelpers.getResourcePath("/squeezenet.onnx"); + String modelPath = model.toString(); + + // session should load successfully + try (OrtSession session = ortEnv.createSession(modelPath, sessionOptions)) { + Assertions.assertNotNull(session); + } + } catch (OrtException e) { + throw new RuntimeException(e); + } + }; + + // empty options + runTest.accept(Collections::emptyMap); + + // dummy options + runTest.accept(() -> Collections.singletonMap("random_key", "value")); + } +} From d51430c9f338bcc47f7dafe25ae5d33b0398eb7f Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 29 Aug 2025 13:53:07 -0700 Subject: [PATCH 14/23] Add error handling to extract_nuget_files.ps1 (#25866) ### Description 1. Check process exit code when running 7z.exe . Currently the errors were silently ignored. 2. Add snld20 flag to the 7z.exe commands, which is needed to be compatible with the latest 7z release. --- .../github/windows/extract_nuget_files.ps1 | 148 ++++++++++-------- .../windows/extract_nuget_files_gpu.ps1 | 86 +++++++--- 2 files changed, 141 insertions(+), 93 deletions(-) diff --git a/tools/ci_build/github/windows/extract_nuget_files.ps1 b/tools/ci_build/github/windows/extract_nuget_files.ps1 index ff8f63a85b97a..20d6c1f2b63a5 100644 --- a/tools/ci_build/github/windows/extract_nuget_files.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files.ps1 @@ -1,105 +1,119 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline +# This file is used by Zip-Nuget-Java Packaging Pipeline -# Re-construct a build directory that contains binaries from all the different platforms we're including -# in the native ORT nuget package +# Define the directory for NuGet artifacts. $nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" -New-Item -Path $nuget_artifacts_dir -ItemType directory +# Create the directory if it doesn't exist. +New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue ## .zip files -# unzip directly -# exclude the iOS xcframework as we need to leave that zipped up to preserve symlinks -Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\nuget-artifact\* -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | +# Unzip files directly, excluding the iOS xcframework to preserve its symlinks. +Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact\*" -Include *.zip -Exclude onnxruntime_ios_xcframework.*.zip | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + # Directly call 7z.exe using the call operator '&' + & 7z.exe $arguments + # Check the exit code of the last command. A non-zero code indicates an error. + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } ## .tgz files -# first extract the tar file from the tgz -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | +# First, extract the .tar file from the .tgz archive. +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -# now extract the actual folder structure from the tar file to the build dir -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | +# Now, extract the contents from the .tar file into the final directory. +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$nuget_artifacts_dir" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -# process iOS xcframework -$xcframeworks = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter onnxruntime_ios_xcframework.*.zip +# Process iOS xcframework +$xcframeworks = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter onnxruntime_ios_xcframework.*.zip if ($xcframeworks.Count -eq 1) { - $xcframework = $xcframeworks[0] - $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" - # remove version info from filename and use required filename format - $target_file = "$target_dir\onnxruntime.xcframework.zip" - New-Item -Path $target_dir -ItemType directory + $xcframework = $xcframeworks[0] + $target_dir = "$nuget_artifacts_dir\onnxruntime-ios-xcframework" + # Use the required filename format, removing version info. + $target_file = "$target_dir\onnxruntime.xcframework.zip" + New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue - Write-Output "Copy-Item $($xcframework.FullName) $target_file" - Copy-Item $xcframework.FullName $target_file + Write-Output "Copying $($xcframework.FullName) to $target_file" + Copy-Item $xcframework.FullName $target_file } elseif ($xcframeworks.Count -gt 1) { - Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" + Write-Error "Expected at most one onnxruntime_ios_xcframework*.zip file but got: [$xcframeworks]" } - -# copy android AAR. -# for full build of onnxruntime Android AAR, there should only be one .aar file -# called onnxruntime-android-x.y.z.aar or onnxruntime-training-android-x.y.z.aar but sanity check that -$aars = Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.aar +# Copy Android AAR file. +# There should only be one .aar file for a full build. +$aars = Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.aar if ($aars.Count -eq 1) { - $aar = $aars[0] - $aar_prefix = "onnxruntime" - if ($aar -like "onnxruntime-training*") { - $aar_prefix = "onnxruntime-training" - } - $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" - $target_file = "$target_dir\onnxruntime.aar" # remove '-mobile' and version info from filename - New-Item -Path $target_dir -ItemType directory + $aar = $aars[0] + $aar_prefix = "onnxruntime" + if ($aar.Name -like "onnxruntime-training*") { + $aar_prefix = "onnxruntime-training" + } + $target_dir = "$nuget_artifacts_dir\$aar_prefix-android-aar" + # Remove version info from the filename for consistency. + $target_file = "$target_dir\onnxruntime.aar" + New-Item -Path $target_dir -ItemType directory -ErrorAction SilentlyContinue - Write-Output "Copy-Item $($aar.FullName) $target_file" - Copy-Item $aar.FullName $target_file + Write-Output "Copying $($aar.FullName) to $target_file" + Copy-Item $aar.FullName $target_file } elseif ($aars.Count -gt 1) { - Write-Error "Expected at most one Android .aar file but got: [$aars]" + Write-Error "Expected at most one Android .aar file but got: [$aars]" } -# Check whether this is a training pipeline -$is_training_pipeline = $false -if (Test-Path -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*) { - $is_training_pipeline = $true - Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." +# Check if this is a training pipeline by looking for a specific directory. +$is_training_pipeline = Test-Path -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*" +if ($is_training_pipeline) { + Write-Output "onnxruntime-training-win-x64-* dir exists. This is a training pipeline." } -# Copy onnxruntime and protoc binaries to the binaries dir as these are required -# by Microsoft.ML.OnnxRuntime.Tests.NetCoreApp +# Copy onnxruntime and protoc binaries required by tests. +$destinationDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" if ($is_training_pipeline) { - Copy-Item -Path $nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo + Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-training-win-x64-*\lib\*" -Destination $destinationDir -Recurse } else { - Copy-Item -Path $nuget_artifacts_dir\onnxruntime-win-x64-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo + Copy-Item -Path "$nuget_artifacts_dir\onnxruntime-win-x64-*\lib\*" -Destination $destinationDir -Recurse } -"Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-*" -$ort_dirs = Get-ChildItem -Directory -Path $nuget_artifacts_dir\onnxruntime-* -foreach ($ort_dir in $ort_dirs) -{ - # remove the last '-xxx' segment from the dir name. typically that's the architecture. - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $nuget_artifacts_dir\$dirname +# Rename directories to remove the architecture-specific suffix. +Write-Output "Renaming onnxruntime directories..." +Get-ChildItem -Directory -Path "$nuget_artifacts_dir\onnxruntime-*" | ForEach-Object { + $dirname = $_.Name + # Find the last hyphen and remove the suffix. + $lastHyphenIndex = $dirname.LastIndexOf('-') + if ($lastHyphenIndex -gt -1) { + $newName = $dirname.Substring(0, $lastHyphenIndex) + $newPath = Join-Path -Path $_.Parent.FullName -ChildPath $newName + Write-Output "Renaming '$($_.FullName)' to '$newPath'" + Rename-Item -Path $_.FullName -NewName $newName + } } -# List artifacts -"Post copy artifacts" -Get-ChildItem -Recurse $nuget_artifacts_dir\ +# List the final artifacts. +Write-Output "Post-copy artifacts:" +Get-ChildItem -Recurse $nuget_artifacts_dir \ No newline at end of file diff --git a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 index 01a8eebe75df2..29946dcb73f8a 100644 --- a/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 +++ b/tools/ci_build/github/windows/extract_nuget_files_gpu.ps1 @@ -2,47 +2,81 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget-Java Packaging Pipeline -New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts -ItemType directory +# Define the directory for NuGet artifacts. +$nuget_artifacts_dir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" +# Create the directory if it doesn't exist. +New-Item -Path $nuget_artifacts_dir -ItemType directory -ErrorAction SilentlyContinue -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.zip | +## .zip files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.zip | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tgz | +## .tgz files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tgz | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" # *.tar will be created after *.tgz is extracted - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + # *.tar will be created after *.tgz is extracted + $arguments = "x", "$($_.FullName)", "-y", "-o$Env:BUILD_BINARIESDIRECTORY\nuget-artifact", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } -Get-ChildItem $Env:BUILD_BINARIESDIRECTORY\nuget-artifact -Filter *.tar | +## .tar files +Get-ChildItem "$Env:BUILD_BINARIESDIRECTORY\nuget-artifact" -Filter *.tar | Foreach-Object { - $cmd = "7z.exe x $($_.FullName) -y -o$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts" - Write-Output $cmd - Invoke-Expression -Command $cmd + # The -snld20 flag is used to bypass security checks for creating symbolic links (added in 7-Zip 25.01). + $arguments = "x", "$($_.FullName)", "-y", "-o$nuget_artifacts_dir", "-snld20" + Write-Output "Executing: 7z.exe $arguments" + & 7z.exe $arguments + if ($LASTEXITCODE -ne 0) { + throw "Error extracting '$($_.FullName)'. Exit code: $LASTEXITCODE" + } } +# Create directory for protobuf build dependencies. +New-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" -ItemType directory -ErrorAction SilentlyContinue -New-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo -ItemType directory - -Copy-Item -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\* -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo +# Copy CUDA libraries. +Copy-Item -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-win-x64-cuda-*\lib\*" -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo" +# Install protoc via dotnet. $protocInstallDir = "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build" dotnet new console dotnet add package Google.Protobuf.Tools --version 3.21.12 --package-directory $protocInstallDir +if ($LASTEXITCODE -ne 0) { + throw "Error adding Google.Protobuf.Tools package. Exit code: $LASTEXITCODE" +} + +# Find and copy the protoc executable. $protocDir = Get-ChildItem -Path $protocInstallDir -Recurse -Filter "protoc.exe" | Select-Object -ExpandProperty DirectoryName -First 1 -Write-Output $protocDir -Copy-Item -Path $protocDir -Destination $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo - -$ort_dirs = Get-ChildItem -Path $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-* -Directory -foreach ($ort_dir in $ort_dirs) -{ - $dirname = Split-Path -Path $ort_dir -Leaf - $dirname = $dirname.SubString(0,$dirname.LastIndexOf('-')) - Write-Output "Renaming $ort_dir to $dirname" - Rename-Item -Path $ort_dir -NewName $Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\$dirname +if ($protocDir) { + Write-Output "Found protoc directory: $protocDir" + Copy-Item -Path $protocDir -Destination "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\_deps\protobuf-build\RelWithDebInfo" +} +else { + Write-Error "Could not find protoc.exe in $protocInstallDir" } +# Rename onnxruntime directories to a generic format. +$ort_dirs = Get-ChildItem -Path "$Env:BUILD_BINARIESDIRECTORY\RelWithDebInfo\RelWithDebInfo\nuget-artifacts\onnxruntime-*" -Directory +foreach ($ort_dir in $ort_dirs) { + $dirname = Split-Path -Path $ort_dir -Leaf + $lastHyphenIndex = $dirname.LastIndexOf('-') + if ($lastHyphenIndex -gt -1) { + $newName = $dirname.Substring(0, $lastHyphenIndex) + $newPath = Join-Path -Path $ort_dir.Parent.FullName -ChildPath $newName + Write-Output "Renaming '$($ort_dir.FullName)' to '$newPath'" + Rename-Item -Path $ort_dir.FullName -NewName $newName + } +} From 928df7cf25f0425e4b3250901a08a72acfd561b0 Mon Sep 17 00:00:00 2001 From: mingyue <131847423+mingyueliuh@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:27:23 -0500 Subject: [PATCH 15/23] [Fix] illegal memory access in GetInputIndices with optional inputs (#25881) ### Description Fix illegal memory access in GetInputIndices with optional inputs ### Motivation and Context When an input is optional, its ValueInfo may be nullptr. The current implementation directly calls InputValueInfo->GetName(), leading to illegal memory access. Update logic to skip optional inputs when valueInfo is nullptr . --- onnxruntime/core/graph/ep_api_types.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 0d9b93631ee8a..92eb31f0ad385 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -327,6 +327,9 @@ static Status GetInputIndices(const EpNode& consumer_node, [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void { for (size_t i = 0; i < input_value_infos.size(); i++) { + if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional + continue; + } if (input_value_infos[i]->GetName() == value_info_name) { indices.push_back(is_implicit ? -1 : static_cast(i)); found = true; From 69ec7b17307f2b94e20199220108e7a415377cf1 Mon Sep 17 00:00:00 2001 From: Edward Chen <18449977+edgchen1@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:59:36 -0700 Subject: [PATCH 16/23] Re-enable cpuinfo for ARM64EC (#25863) ### Description Re-enable cpuinfo for ARM64EC build and fix `CPUIDINFO_ARCH_ARM` so it is actually used. Patch cpuinfo to support vcpkg ARM64EC build. See https://github.com/pytorch/cpuinfo/pull/324. ### Motivation and Context Fix for workaround in #25831. --- cmake/CMakeLists.txt | 8 +- .../external/onnxruntime_external_deps.cmake | 61 ++++++------- cmake/onnxruntime.cmake | 13 ++- cmake/onnxruntime_common.cmake | 57 +----------- cmake/onnxruntime_nodejs.cmake | 1 + .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 91 +++++++++++++++++++ .../cpuinfo/patch_vcpkg_arm64ec_support.patch | 91 +++++++++++++++++++ cmake/vcpkg-ports/cpuinfo/portfile.cmake | 1 + .../core/common/cpuid_arch_definition.h | 2 +- .../test/platform/device_discovery_test.cc | 4 +- 10 files changed, 232 insertions(+), 97 deletions(-) create mode 100644 cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch create mode 100644 cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 98548957d0b42..40e6a8da28e45 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1607,7 +1607,6 @@ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Linux") endif() endif() - #Now the 'onnxruntime_EXTERNAL_LIBRARIES' variable should be sealed. It will be used in onnxruntime.cmake which will be included in the next. #The order of the following targets matters. Right depends on left. If target A appears before target B. Then A.cmake can not use variables defined in B.cmake. set(ONNXRUNTIME_CMAKE_FILES onnxruntime_flatbuffers onnxruntime_common onnxruntime_mlas onnxruntime_graph onnxruntime_lora onnxruntime_framework onnxruntime_util onnxruntime_providers onnxruntime_optimizer onnxruntime_session ${ONNXRUNTIME_EAGER_CMAKE_FILE_NAME}) @@ -1623,9 +1622,6 @@ if (onnxruntime_USE_WINML) list(APPEND ONNXRUNTIME_CMAKE_FILES winml) endif() # if (onnxruntime_USE_WINML) -if (onnxruntime_BUILD_APPLE_FRAMEWORK AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin|iOS|visionOS|tvOS") - message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") -endif() list(APPEND ONNXRUNTIME_CMAKE_FILES onnxruntime) if (onnxruntime_BUILD_JAVA) @@ -1690,8 +1686,8 @@ if (WIN32 AND NOT GDK_PLATFORM AND NOT CMAKE_CROSSCOMPILING) endif() endif() -foreach(target_name ${ONNXRUNTIME_CMAKE_FILES}) - include(${target_name}.cmake) +foreach(onnxruntime_cmake_file ${ONNXRUNTIME_CMAKE_FILES}) + include(${onnxruntime_cmake_file}.cmake) endforeach() if (UNIX) option(BUILD_PKGCONFIG_FILES "Build and install pkg-config files" ON) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 3095968795d1a..827be3e6dea2a 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -313,41 +313,32 @@ onnxruntime_fetchcontent_makeavailable(nlohmann_json) if (onnxruntime_ENABLE_CPUINFO) # Adding pytorch CPU info library # TODO!! need a better way to find out the supported architectures - list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) + set(CPUINFO_SUPPORTED FALSE) if (APPLE) + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) set(CPUINFO_SUPPORTED TRUE) - elseif (onnxruntime_BUILD_APPLE_FRAMEWORK) - # We stitch multiple static libraries together when onnxruntime_BUILD_APPLE_FRAMEWORK is true, - # but that would not work for universal static libraries - message(FATAL_ERROR "universal binary is not supported for apple framework") - endif() - else() - # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo - # so we don't set CPUINFO_SUPPORTED in the CXX flags below. - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_USE_XNNPACK) - set(CPUINFO_SUPPORTED FALSE) else() + message(WARNING "cpuinfo is not supported when CMAKE_OSX_ARCHITECTURES has more than one value.") + endif() + elseif (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo. + if (onnxruntime_USE_XNNPACK) set(CPUINFO_SUPPORTED TRUE) endif() - if (WIN32) - # There's an error when linking with cpuinfo on arm64ec with a vcpkg build (--use_vcpkg). - # TODO Fix it and then re-enable cpuinfo on arm64ec. - if (onnxruntime_target_platform STREQUAL "ARM64EC") - set(CPUINFO_SUPPORTED FALSE) - else() - set(CPUINFO_SUPPORTED TRUE) - endif() - elseif (NOT ${onnxruntime_target_platform} MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") - message(WARNING - "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo. " - "cpuinfo not included." - ) - set(CPUINFO_SUPPORTED FALSE) + elseif (WIN32) + set(CPUINFO_SUPPORTED TRUE) + else() + if (onnxruntime_target_platform MATCHES "^(i[3-6]86|AMD64|x86(_64)?|armv[5-8].*|aarch64|arm64)$") + set(CPUINFO_SUPPORTED TRUE) + else() + message(WARNING "Target processor architecture \"${onnxruntime_target_platform}\" is not supported in cpuinfo.") endif() endif() -else() - set(CPUINFO_SUPPORTED FALSE) + + if(NOT CPUINFO_SUPPORTED) + message(WARNING "onnxruntime_ENABLE_CPUINFO was set but cpuinfo is not supported.") + endif() endif() if (CPUINFO_SUPPORTED) @@ -358,23 +349,26 @@ if (CPUINFO_SUPPORTED) # if this is a wasm build with xnnpack (only type of wasm build where cpuinfo is involved) # we do not use cpuinfo in ORT code, so don't define CPUINFO_SUPPORTED. - if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - string(APPEND CMAKE_CXX_FLAGS " -DCPUINFO_SUPPORTED") + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND onnxruntime_USE_XNNPACK) + else() + add_compile_definitions(CPUINFO_SUPPORTED) endif() - set(CPUINFO_BUILD_TOOLS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_UNIT_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if (onnxruntime_target_platform STREQUAL "ARM64EC" OR onnxruntime_target_platform STREQUAL "ARM64") - message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") + message(STATUS "Applying patches for Windows ARM64/ARM64EC in cpuinfo") onnxruntime_fetchcontent_declare( pytorch_cpuinfo URL ${DEP_URL_pytorch_cpuinfo} URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} EXCLUDE_FROM_ALL - PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch + PATCH_COMMAND + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch && + # https://github.com/pytorch/cpuinfo/pull/324 + ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch FIND_PACKAGE_ARGS NAMES cpuinfo ) else() @@ -584,8 +578,7 @@ endif() set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES_XNNPACK} ${WIL_TARGET} nlohmann_json::nlohmann_json onnx onnx_proto ${PROTOBUF_LIB} re2::re2 Boost::mp11 safeint_interface - flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date - ${ONNXRUNTIME_CLOG_TARGET_NAME} Eigen3::Eigen) + flatbuffers::flatbuffers ${GSL_TARGET} ${ABSEIL_LIBS} date::date Eigen3::Eigen) # The source code of onnx_proto is generated, we must build this lib first before starting to compile the other source code that uses ONNX protobuf types. # The other libs do not have the problem. All the sources are already there. We can compile them in any order. diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 010696a61022c..e1d98109208d4 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -350,8 +350,19 @@ if (winml_is_inbox) endif() endif() -# Assemble the Apple static framework (iOS and macOS) +# Assemble the Apple static framework if(onnxruntime_BUILD_APPLE_FRAMEWORK) + if (NOT CMAKE_SYSTEM_NAME MATCHES "Darwin|iOS|visionOS|tvOS") + message(FATAL_ERROR "onnxruntime_BUILD_APPLE_FRAMEWORK can only be enabled for macOS or iOS or visionOS or tvOS.") + endif() + + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) + if (CMAKE_OSX_ARCHITECTURES_LEN GREATER 1) + # We stitch multiple static libraries together when onnxruntime_BUILD_APPLE_FRAMEWORK is true, + # but that would not work for universal static libraries + message(FATAL_ERROR "universal binary is not supported for apple framework") + endif() + # when building for mac catalyst, the CMAKE_OSX_SYSROOT is set to MacOSX as well, to avoid duplication, # we specify as `-macabi` in the name of the output static apple framework directory. if (PLATFORM_NAME STREQUAL "macabi") diff --git a/cmake/onnxruntime_common.cmake b/cmake/onnxruntime_common.cmake index d927489372e7c..0218994e537a0 100644 --- a/cmake/onnxruntime_common.cmake +++ b/cmake/onnxruntime_common.cmake @@ -194,59 +194,10 @@ if(APPLE) target_link_libraries(onnxruntime_common PRIVATE "-framework Foundation") endif() -if(MSVC) - if(onnxruntime_target_platform STREQUAL "ARM64") - set(ARM64 TRUE) - elseif (onnxruntime_target_platform STREQUAL "ARM") - set(ARM TRUE) - elseif(onnxruntime_target_platform STREQUAL "x64") - set(X64 TRUE) - elseif(onnxruntime_target_platform STREQUAL "x86") - set(X86 TRUE) - endif() -elseif(APPLE) - if(CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) - set(X64 TRUE) - endif() -elseif(NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - execute_process( - COMMAND ${CMAKE_C_COMPILER} -dumpmachine - OUTPUT_VARIABLE dumpmachine_output - ERROR_QUIET - ) - if(dumpmachine_output MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(dumpmachine_output MATCHES "^arm.*") - set(ARM TRUE) - elseif(dumpmachine_output MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^riscv64.*") - set(RISCV64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - endif() - endif() -endif() - -if (RISCV64 OR ARM64 OR ARM OR X86 OR X64 OR X86_64) - # Link cpuinfo if supported - if (CPUINFO_SUPPORTED) - onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) - list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo ${ONNXRUNTIME_CLOG_TARGET_NAME}) - endif() +if(CPUINFO_SUPPORTED) + # Link cpuinfo if supported + onnxruntime_add_include_to_target(onnxruntime_common cpuinfo::cpuinfo) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES cpuinfo::cpuinfo) endif() if (NOT onnxruntime_BUILD_SHARED_LIB) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index b28bda6c94276..cce0810c5bbe8 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -10,6 +10,7 @@ include(node_helper.cmake) # setup ARCH if (APPLE) + list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN GREATER 1) message(FATAL_ERROR "CMake.js does not support multi-architecture for macOS") endif() diff --git a/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch new file mode 100644 index 0000000000000..af0f039b6c2a3 --- /dev/null +++ b/cmake/patches/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -0,0 +1,91 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index aedc983..dab589e 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am + ENDIF() + IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") + SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") ++ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") ++ SET(CPUINFO_TARGET_PROCESSOR "x86") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") ++ SET(CPUINFO_TARGET_PROCESSOR "x86_64") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID MATCHES "^(ARM64|ARM64EC)$") ++ SET(CPUINFO_TARGET_PROCESSOR "arm64") ++ ELSE() ++ MESSAGE(FATAL_ERROR "Unsupported MSVC compiler architecture ID \"${CMAKE_C_COMPILER_ARCHITECTURE_ID}\"") ++ ENDIF() + ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_VS_PLATFORM_NAME) + IF(CMAKE_VS_PLATFORM_NAME STREQUAL "Win32") + SET(CPUINFO_TARGET_PROCESSOR "x86") +@@ -88,7 +99,7 @@ ENDIF() + + # ---[ Build flags + SET(CPUINFO_SUPPORTED_PLATFORM TRUE) +-IF(NOT CMAKE_SYSTEM_PROCESSOR) ++IF(NOT CPUINFO_TARGET_PROCESSOR) + IF(NOT IOS) + MESSAGE(WARNING + "Target processor architecture is not specified. " +@@ -201,12 +212,12 @@ IF(CPUINFO_SUPPORTED_PLATFORM) + src/arm/linux/chipset.c + src/arm/linux/midr.c + src/arm/linux/hwcap.c) +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") ++ IF(CPUINFO_TARGET_PROCESSOR MATCHES "^armv[5-8]") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) + IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") + SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) + ENDIF() +- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$") ++ ELSEIF(CPUINFO_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) + ENDIF() + ELSEIF(IS_APPLE_OS AND CPUINFO_TARGET_PROCESSOR MATCHES "arm64.*") +@@ -395,7 +406,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a)$") + ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) + TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) +@@ -577,7 +588,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-sl-test COMMAND xperia-sl-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") + ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) + TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) +@@ -774,7 +785,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-c4-dual-test COMMAND xperia-c4-dual-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(i686|x86_64)$") + ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) + TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) +@@ -831,7 +842,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) + ADD_TEST(NAME brand-string-test COMMAND brand-string-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) + CPUINFO_TARGET_ENABLE_C99(android_properties_interface) + CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) +@@ -879,7 +890,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) + TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) + INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + +- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) + CPUINFO_TARGET_ENABLE_C99(auxv-dump) + CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) diff --git a/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch new file mode 100644 index 0000000000000..af0f039b6c2a3 --- /dev/null +++ b/cmake/vcpkg-ports/cpuinfo/patch_vcpkg_arm64ec_support.patch @@ -0,0 +1,91 @@ +diff --git a/CMakeLists.txt b/CMakeLists.txt +index aedc983..dab589e 100644 +--- a/CMakeLists.txt ++++ b/CMakeLists.txt +@@ -72,6 +72,17 @@ IF(CMAKE_SYSTEM_NAME MATCHES "FreeBSD" AND CPUINFO_TARGET_PROCESSOR STREQUAL "am + ENDIF() + IF(IS_APPLE_OS AND CMAKE_OSX_ARCHITECTURES MATCHES "^(x86_64|arm64.*)$") + SET(CPUINFO_TARGET_PROCESSOR "${CMAKE_OSX_ARCHITECTURES}") ++ELSEIF(MSVC AND CMAKE_VERSION VERSION_GREATER_EQUAL "3.10") ++ # Use CMAKE_C_COMPILER_ARCHITECTURE_ID. MSVC values are documented as available since CMake 3.10. ++ IF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "X86") ++ SET(CPUINFO_TARGET_PROCESSOR "x86") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID STREQUAL "x64") ++ SET(CPUINFO_TARGET_PROCESSOR "x86_64") ++ ELSEIF(CMAKE_C_COMPILER_ARCHITECTURE_ID MATCHES "^(ARM64|ARM64EC)$") ++ SET(CPUINFO_TARGET_PROCESSOR "arm64") ++ ELSE() ++ MESSAGE(FATAL_ERROR "Unsupported MSVC compiler architecture ID \"${CMAKE_C_COMPILER_ARCHITECTURE_ID}\"") ++ ENDIF() + ELSEIF(CMAKE_GENERATOR MATCHES "^Visual Studio " AND CMAKE_VS_PLATFORM_NAME) + IF(CMAKE_VS_PLATFORM_NAME STREQUAL "Win32") + SET(CPUINFO_TARGET_PROCESSOR "x86") +@@ -88,7 +99,7 @@ ENDIF() + + # ---[ Build flags + SET(CPUINFO_SUPPORTED_PLATFORM TRUE) +-IF(NOT CMAKE_SYSTEM_PROCESSOR) ++IF(NOT CPUINFO_TARGET_PROCESSOR) + IF(NOT IOS) + MESSAGE(WARNING + "Target processor architecture is not specified. " +@@ -201,12 +212,12 @@ IF(CPUINFO_SUPPORTED_PLATFORM) + src/arm/linux/chipset.c + src/arm/linux/midr.c + src/arm/linux/hwcap.c) +- IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv[5-8]") ++ IF(CPUINFO_TARGET_PROCESSOR MATCHES "^armv[5-8]") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch32-isa.c) + IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND ANDROID_ABI STREQUAL "armeabi") + SET_SOURCE_FILES_PROPERTIES(src/arm/linux/aarch32-isa.c PROPERTIES COMPILE_FLAGS -marm) + ENDIF() +- ELSEIF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm64)$") ++ ELSEIF(CPUINFO_TARGET_PROCESSOR MATCHES "^(aarch64|arm64)$") + LIST(APPEND CPUINFO_SRCS src/arm/linux/aarch64-isa.c) + ENDIF() + ELSEIF(IS_APPLE_OS AND CPUINFO_TARGET_PROCESSOR MATCHES "arm64.*") +@@ -395,7 +406,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + TARGET_COMPILE_DEFINITIONS(cpuinfo_mock PRIVATE _GNU_SOURCE=1) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a)$") + ADD_EXECUTABLE(atm7029b-tablet-test test/mock/atm7029b-tablet.cc) + TARGET_INCLUDE_DIRECTORIES(atm7029b-tablet-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(atm7029b-tablet-test PRIVATE cpuinfo_mock gtest) +@@ -577,7 +588,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-sl-test COMMAND xperia-sl-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv5te|armv7-a|aarch64)$") + ADD_EXECUTABLE(alcatel-revvl-test test/mock/alcatel-revvl.cc) + TARGET_INCLUDE_DIRECTORIES(alcatel-revvl-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alcatel-revvl-test PRIVATE cpuinfo_mock gtest) +@@ -774,7 +785,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_MOCK_TESTS) + ADD_TEST(NAME xperia-c4-dual-test COMMAND xperia-c4-dual-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(i686|x86_64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(i686|x86_64)$") + ADD_EXECUTABLE(alldocube-iwork8-test test/mock/alldocube-iwork8.cc) + TARGET_INCLUDE_DIRECTORIES(alldocube-iwork8-test BEFORE PRIVATE test/mock) + TARGET_LINK_LIBRARIES(alldocube-iwork8-test PRIVATE cpuinfo_mock gtest) +@@ -831,7 +842,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_UNIT_TESTS) + ADD_TEST(NAME brand-string-test COMMAND brand-string-test) + ENDIF() + +- IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME STREQUAL "Android" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_LIBRARY(android_properties_interface STATIC test/name/android-properties-interface.c) + CPUINFO_TARGET_ENABLE_C99(android_properties_interface) + CPUINFO_TARGET_RUNTIME_LIBRARY(android_properties_interface) +@@ -879,7 +890,7 @@ IF(CPUINFO_SUPPORTED_PLATFORM AND CPUINFO_BUILD_TOOLS) + TARGET_LINK_LIBRARIES(cache-info PRIVATE cpuinfo) + INSTALL(TARGETS cache-info RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) + +- IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CMAKE_SYSTEM_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") ++ IF(CMAKE_SYSTEM_NAME MATCHES "^(Android|Linux)$" AND CPUINFO_TARGET_PROCESSOR MATCHES "^(armv[5-8].*|aarch64)$") + ADD_EXECUTABLE(auxv-dump tools/auxv-dump.c) + CPUINFO_TARGET_ENABLE_C99(auxv-dump) + CPUINFO_TARGET_RUNTIME_LIBRARY(auxv-dump) diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index 3fcf76b7adafc..eeb0007195ca3 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -11,6 +11,7 @@ vcpkg_from_github( HEAD_REF master PATCHES patch_cpuinfo_h_for_arm64ec.patch + patch_vcpkg_arm64ec_support.patch # https://github.com/pytorch/cpuinfo/pull/324 ) vcpkg_check_features(OUT_FEATURE_OPTIONS FEATURE_OPTIONS diff --git a/onnxruntime/core/common/cpuid_arch_definition.h b/onnxruntime/core/common/cpuid_arch_definition.h index a541eb66d8ba3..5946b8ca27067 100644 --- a/onnxruntime/core/common/cpuid_arch_definition.h +++ b/onnxruntime/core/common/cpuid_arch_definition.h @@ -9,6 +9,6 @@ #define CPUIDINFO_ARCH_X86 #endif -#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) +#if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 diff --git a/onnxruntime/test/platform/device_discovery_test.cc b/onnxruntime/test/platform/device_discovery_test.cc index 21ddf9a5b1cd7..6b43ccbc8f670 100644 --- a/onnxruntime/test/platform/device_discovery_test.cc +++ b/onnxruntime/test/platform/device_discovery_test.cc @@ -25,9 +25,9 @@ TEST(DeviceDiscoveryTest, HasCpuDevice) { const auto cpu_devices = GetDevicesByType(OrtHardwareDeviceType_CPU); ASSERT_GT(cpu_devices.size(), 0); -#if !defined(__wasm__) +#if defined(CPUINFO_SUPPORTED) ASSERT_NE(cpu_devices[0].vendor_id, 0); -#endif // !defined(__WASM__) +#endif // defined(CPUINFO_SUPPORTED) } } // namespace onnxruntime::test From 5746ba9d3b7b5eaf3a5c64fd24974f3649d71b34 Mon Sep 17 00:00:00 2001 From: Xiaofei Han Date: Mon, 1 Sep 2025 15:19:23 +0800 Subject: [PATCH 17/23] [webgpu] Add back missing code comments for flash decoding (#25879) Restore accidentally removed comments when using WGSL template. --- .../flash_attention_decode_qkt.wgsl.template | 23 ++++++++++++++++ ...sh_attention_decode_split_vx.wgsl.template | 27 +++++++++++++++++++ ...h_attention_decode_vx_reduce.wgsl.template | 11 ++++++++ 3 files changed, 61 insertions(+) diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template index c2be08b2186d4..7f41f2518b84b 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_qkt.wgsl.template @@ -6,6 +6,29 @@ #param tile_size_k_vec #param sub_tile_count +// Note that this shader adopts similar algorithm with dp4a generation shader. +// +// This algorithm works to compute dot product of keys with queries parallelly, +// by processing on the k (head_size) dimension at each step amongst +// tile_size_k_vec threads, and utilizing the remaining threads in the workgroup +// to process additional rows of |present_key| in parallel (such that the values +// in shared memory (tile_q) for |q| can be reused). For each load of q, the +// tile_size_k_vec threads also reload |present_key| tile_size/sub_tile_count +// times to compute partial dot products of other |present_key| rows in order to +// complete all tile_size |present_key| rows in this workgroup and also reusing +// the loaded in register values of |q|. + +// 1. Each workgroup processes one row of |q| and tile_size rows of |present_key| +// +// 2. Computation Process: +// - Reads [tile_size][tile_size_k_vec] block of |present_key| data at a time +// - Each thread within workgroup computes dot products of 4 A*B elements +// since each k represents 4 elements of |present_key| +// - Stores intermediate results in shared memory (inner_qk_values) +// - Iterates through columns (head_size_vec) accumulating results in +// inner_qk_values +// - Performs final reduction sum in inner_qk_values for output + var tile_q: array; var inner_qk_values: array, tile_size>; var tile_qk: array; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template index 8d8519fec79b2..c7593af311ce2 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_split_vx.wgsl.template @@ -6,6 +6,33 @@ #param tile_size_k_vec #param sub_tile_count +// Note that this shader adopts similar algorithm with dp4a generation shader. +// +// This algorithm works to compute dot product of v with qk parallelly, by +// processing on the head_size dimension at each step amongst tile_size_k_vec +// threads, and utilizing the remaining threads in the workgroup to process +// additional rows of |present_value| in parallel (such that the values in +// shared memory (tile_qk) for |qk| can be reused). The tile_size_k_vec threads +// also reload |present_value| tile_size/sub_tile_count times to compute partial +// dot products of other |present_value| rows in order to complete all tile_size +// |present_value| rows in this workgroup and also reusing the values in +// tile_qk. +// +// The difference with FlashAttentionDecodeQKTProgram is that the dot products +// go through the rows (total_sequence_length) of |present_value| instead of +// columns (head_size_vec). And each workgroup only calculate current +// tile_size's dot products instead of iterating the whole row +// |total_sequence_length|. That's why this shader is a split shader. The final +// reduce will be done in FlashAttentionDecodeReduceProgram. + +// TODO: Ideally, there should only be two shaders FlashAttentionDecodeSplitVx +// and FlashAttentionDecodeVxReduce, which can also reduce the intermediate +// memory. The FlashAttentionDecodeQKT can be merged into split shader and do +// the final softmax adjustment in the reduce shader. However, some issues are +// met that when the total sequence length exceeds some value, the result will +// become garbage. Since it can't be resolved in a short time, leave it as TODO +// to fix it in future. + var tile_qk: array; var tile_output: array; var qkv_values: array, sub_tile_count>; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template index 51dcd892338a4..a4381baa638ce 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention_decode_vx_reduce.wgsl.template @@ -3,6 +3,17 @@ #param tile_size +// Inputs are splits of the GQA output, split into num_total_seq_length_tiles +// rows. This shader needs to add these splits across the row dimension to +// arrive at the final result. The column is head size wide. The reduction +// achieves maximum parallelization by splitting this task first into tile_size +// columns that each workgroup is responsible for. Then within each workgroup +// the task of summation over the num_total_seq_length_tile for the tile_size +// columns is further split in two ways. First across the row dimension to have +// WORKGROUP_SIZE/TILE_SIZE parallel computations of summation of TILE_SIZE +// rows. Then across the column dimension where each thread is responsible for 1 +// column of the TILE_SIZE columns the workgroup is responsible for. + var tile_input: array, tile_size>; $MAIN { From af4bf436d70f62afe486bb5b9d9cc7c8f3b5b958 Mon Sep 17 00:00:00 2001 From: Rohanjames1997 Date: Tue, 2 Sep 2025 11:41:52 -0500 Subject: [PATCH 18/23] Replace vmlaq_f32 with vfmaq_f32 (fused multiply-add) (#25669) ### Description The [vfmaq_f32](https://developer.arm.com/architectures/instruction-sets/intrinsics/vfmaq_f32) intrinsic compiles to the [FMLA](https://developer.arm.com/documentation/ddi0596/2021-03/SIMD-FP-Instructions/FMLA--vector---Floating-point-fused-Multiply-Add-to-accumulator--vector--?lang=en) instruction which is more performant than separate `fmul`+`fadd` instructions that [vmlaq_f32](https://developer.arm.com/architectures/instruction-sets/intrinsics/vmlaq_f32) compiles to on latest GCC versions: https://godbolt.org/z/aYc9as5Wh Note that this is not a breaking change, as vmlaq_f32 compiles to FMLA instructions already on the latest clang compilers (which are the default for MacOS ORT builds already) ### Motivation and Context With this change, the NEON version of `MlasMultiplyAddFloat32x4` achieves parity with the x86 version that uses `_mm_fmadd_ps`. It also achieves up to ~15% speedups compared to the current `vmlaq_f32` implementation when tested on top of #25580 --- onnxruntime/core/mlas/lib/mlasi.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index a099bcf8438fe..90d44adbfd286 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -2280,7 +2280,7 @@ MLAS_FLOAT32X4 MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Vector3) { #if defined(MLAS_NEON_INTRINSICS) - return vmlaq_f32(Vector3, Vector1, Vector2); + return vfmaq_f32(Vector3, Vector1, Vector2); #elif defined(MLAS_FMA3_INTRINSICS) return _mm_fmadd_ps(Vector1, Vector2, Vector3); #elif defined(MLAS_SSE2_INTRINSICS) From 53bb79b929c72cbc8222f599336d3ed9858c6a06 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Wed, 3 Sep 2025 03:43:18 +0800 Subject: [PATCH 19/23] Support DynamicQuantizeLinear op (#25905) --- js/web/docs/webnn-operators.md | 1 + .../impl/dynamicQuantizeLinear_op_builder.cc | 182 ++++++++++++++++-- .../core/providers/webnn/builders/map_info.h | 3 +- 3 files changed, 174 insertions(+), 12 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 0fffe99ec4f78..793161aecefaf 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -32,6 +32,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s | Div | ai.onnx(7-12, 13, 14+) | div | | | DequantizeLinear | ai.onnx(10-12, 13-18, 19-20, 21-22, 23+) | dequantizeLinear | The shape of x_scale should be a subsample of the shape of input | | Dropout | ai.onnx(7-9, 10-11, 12, 13-21, 22+) | identity | Only supports test mode | +| DynamicQuantizeLinear | ai.onnx(11+) | cast, clamp, div, div, max, min, quantizeLinear, reduceMax, reduceMin, reshape, roundEven, sub | | | Einsum | ai.onnx(12+) | reshape, transpose, matmul, reduceSum, mul, triangular | | | Elu | ai.onnx(7+) | elu | | | Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal | | diff --git a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc index f3363b1e186d5..80425d5fcf42f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dynamicQuantizeLinear_op_builder.cc @@ -14,34 +14,194 @@ namespace onnxruntime { namespace webnn { -class DynamicQuantizaLinearOpBuilder : public BaseOpBuilder { +class DynamicQuantizeLinearOpBuilder : public BaseOpBuilder { // Add operator related. private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + private: + bool HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; + bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; -Status DynamicQuantizaLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, +// DynamicQuantizeLinear is a function defined as follows: +// DynamicQuantizeLinear (x) => (y, y_scale, y_zero_point) +// { +// Q_Min = Constant () +// Q_Max = Constant () +// X_Min = ReduceMin (x) +// X_Min_Adjusted = Min (X_Min, Q_Min) +// X_Max = ReduceMax (x) +// X_Max_Adjusted = Max (X_Max, Q_Min) +// X_Range = Sub (X_Max_Adjusted, X_Min_Adjusted) +// Scale = Div (X_Range, Q_Max) +// Min_Scaled = Div (X_Min_Adjusted, Scale) +// Initial_ZeroPoint_FP = Sub (Q_Min, Min_Scaled) +// Clipped_ZeroPoint_FP = Clip (Initial_ZeroPoint_FP, Q_Min, Q_Max) +// Rounded_ZeroPoint_FP = Round (Clipped_ZeroPoint_FP) +// Zeropoint = Cast (Rounded_ZeroPoint_FP) +// y_scale = Identity (Scale) (Skip in WebNN) +// y_zero_point = Identity (Zeropoint) (Skip in WebNN) +// y = QuantizeLinear (x, Scale, Zeropoint) +// } +Status DynamicQuantizeLinearOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - emscripten::val output_array; - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - emscripten::val options = emscripten::val::object(); - options.set("label", node.Name()); + emscripten::val common_options = emscripten::val::object(); + + // Q_Min = Constant () + emscripten::val q_min = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 0.0f); + // Q_Max = Constant () + emscripten::val q_max = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 255.0f); + + // X_Min = ReduceMin (x) + common_options.set("label", node.Name() + "_x_min"); + emscripten::val x_min = model_builder.GetBuilder().call("reduceMin", input, common_options); + + // X_Min_Adjusted = Min (X_Min, Q_Min) + common_options.set("label", node.Name() + "_x_min_adjusted"); + emscripten::val x_min_adjusted = model_builder.GetBuilder().call("min", x_min, q_min, common_options); + + // X_Max = ReduceMax (x) + common_options.set("label", node.Name() + "_x_max"); + emscripten::val x_max = model_builder.GetBuilder().call("reduceMax", input, common_options); + + // X_Max_Adjusted = Max (X_Max, Q_Min) + common_options.set("label", node.Name() + "_x_max_adjusted"); + emscripten::val x_max_adjusted = model_builder.GetBuilder().call( + "max", x_max, q_min, common_options); + + // X_Range = Sub (X_Max_Adjusted, X_Min_Adjusted) + common_options.set("label", node.Name() + "_x_range"); + emscripten::val x_range = model_builder.GetBuilder().call( + "sub", x_max_adjusted, x_min_adjusted, common_options); - output_array = model_builder.GetBuilder().call("dynamicQuantizeLinear", input, options); + // Scale = Div (X_Range, Q_Max) + common_options.set("label", node.Name() + "_scale"); + emscripten::val scale = model_builder.GetBuilder().call("div", x_range, q_max, common_options); - for (size_t i = 0, count = output_array["length"].as(); i < count; i++) { - model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i])); + // Min_Scaled = Div (X_Min_Adjusted, Scale) + common_options.set("label", node.Name() + "_min_scaled"); + emscripten::val min_scaled = model_builder.GetBuilder().call( + "div", x_min_adjusted, scale, common_options); + + // Initial_ZeroPoint_FP = Sub (Q_Min, Min_Scaled) + common_options.set("label", node.Name() + "_initial_zero_point_fp"); + emscripten::val initial_zero_point_fp = model_builder.GetBuilder().call( + "sub", q_min, min_scaled, common_options); + + // Clipped_ZeroPoint_FP = Clip (Initial_ZeroPoint_FP, Q_Min, Q_Max) + emscripten::val clip_options = emscripten::val::object(); + clip_options.set("label", node.Name() + "_clipped_zero_point_fp"); + clip_options.set("minValue", 0); + clip_options.set("maxValue", 255); + emscripten::val clipped_zero_point_fp = model_builder.GetBuilder().call( + "clamp", initial_zero_point_fp, clip_options); + + // Rounded_ZeroPoint_FP = Round (Clipped_ZeroPoint_FP) + common_options.set("label", node.Name() + "_rounded_zero_point_fp"); + emscripten::val rounded_zero_point_fp = model_builder.GetBuilder().call( + "roundEven", clipped_zero_point_fp, common_options); + + // Zeropoint = Cast (Rounded_ZeroPoint_FP) + // to: int = 2 means cast to uint8 + common_options.set("label", node.Name() + "_zero_point"); + emscripten::val zero_point = model_builder.GetBuilder().call( + "cast", rounded_zero_point_fp, emscripten::val("uint8"), common_options); + + // The WebNN quantizeLinear op requires the scale and zero_point tensors to have the same rank as the input tensor. + // The scale and zero_point outputs are both scalars, so we need to reshape them to match the input rank. + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); + const auto input_rank = input_shape.size(); + emscripten::val new_scale = scale; + emscripten::val new_zero_point = zero_point; + if (input_rank > 0) { + std::vector new_shape(input_rank, 1); + common_options.set("label", node.Name() + "_reshape_scale"); + new_scale = model_builder.GetBuilder().call( + "reshape", scale, emscripten::val::array(new_shape), common_options); + + common_options.set("label", node.Name() + "_reshape_zero_point"); + new_zero_point = model_builder.GetBuilder().call( + "reshape", zero_point, emscripten::val::array(new_shape), common_options); } + + // y = QuantizeLinear (x, Scale, Zeropoint) + common_options.set("label", node.Name() + "_quantize_linear"); + emscripten::val y = model_builder.GetBuilder().call( + "quantizeLinear", input, new_scale, new_zero_point, common_options); + + // Add output: y + model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(y)); + // Add output: y_scale + model_builder.AddOperand(node.OutputDefs()[1]->Name(), std::move(scale)); + // Add output: y_zero_point + model_builder.AddOperand(node.OutputDefs()[2]->Name(), std::move(zero_point)); + return Status::OK(); } +// Operator support related. +bool DynamicQuantizeLinearOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + + int32_t input_type = 0; + if (!GetType(*input_defs[0], input_type, logger)) { + return false; + } + if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "DynamicQuantizeLinear only supports input data type float."; + return false; + } + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) { + return false; + } + // It's complicated to check all the decomposed ops' input rank support. + // Ensure at least the first input rank is supported by the decomposed ops. + // (reduceMax, reduceMin and quantizeLinear accept the first input). + const std::array operations = {"reduceMax", "reduceMin", "quantizeLinear"}; + for (const auto& op : operations) { + if (!IsInputRankSupported(wnn_limits, op, "input", input_shape.size(), node.Name(), logger)) { + return false; + } + } + + return true; +} + +bool DynamicQuantizeLinearOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& output_defs = node.OutputDefs(); + const std::string_view op_type = node.OpType(); + int32_t y_type, y_scale_type, y_zero_point_type; + if (!GetType(*output_defs[0], y_type, logger) || + !GetType(*output_defs[1], y_scale_type, logger) || + !GetType(*output_defs[2], y_zero_point_type, logger)) { + return false; + } + + // Only need to check the output data type of ops that produce the outputs of DynamicQuantizeLinear. + // 1. QuantizeLinear -> y (uint8) + // 2. Div -> y_scale (float32) (skip it as WebNN should support it by default) + // 3. Cast -> y_zero_point (uint8) + return IsDataTypeSupportedByWebNNOp(op_type, "quantizeLinear", y_type, wnn_limits, "output", "y", logger) && + IsDataTypeSupportedByWebNNOp(op_type, "cast", y_zero_point_type, wnn_limits, "output", "y_zero_point", logger); +} + void CreateDynamicQuantizeLinearOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { - op_registrations.builders.push_back(std::make_unique()); + op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); } diff --git a/onnxruntime/core/providers/webnn/builders/map_info.h b/onnxruntime/core/providers/webnn/builders/map_info.h index 590614edf851c..b7959d1e0c3d4 100644 --- a/onnxruntime/core/providers/webnn/builders/map_info.h +++ b/onnxruntime/core/providers/webnn/builders/map_info.h @@ -47,6 +47,8 @@ constexpr std::array supported_fallback // Use ONNX-to-ONNX op mapping to improve the search complexity for WebNN ops in the op_inputs_map. const std::map> decomposed_op_map = { {"ConvInteger", {"Cast", "Conv", "DequantizeLinear"}}, + {"DynamicQuantizeLinear", + {"Cast", "Clip", "Div", "Max", "Min", "QuantizeLinear", "ReduceMax", "ReduceMin", "Reshape", "Round", "Sub"}}, {"Einsum", {"MatMul", "Mul", "ReduceSum", "Reshape", "Transpose", "Trilu"}}, {"GroupQueryAttention", {"Add", "Cast", "Concat", "CumSum", "Div", "Expand", "Less", "MatMul", "Reshape", "ScatterND", @@ -190,7 +192,6 @@ const std::unordered_map op_inputs_map = { {"GatherND", {"gatherND", {{0, "input"}, {1, "indices"}}}}, {"GreaterOrEqual", {"greaterOrEqual", {{0, "a"}, {1, "b"}}}}, {"Conv", {"conv2d", {{0, "input"}, {1, "filter"}, {2, "bias"}}}}, - {"DynamicQuantizeLinear", {"dynamicQuantizeLinear", {{0, "input"}}}}, {"GatherElements", {"gatherElements", {{0, "input"}, {1, "indices"}}}}, {"ScatterND", {"scatterND", {{0, "input"}, {1, "indices"}, {2, "updates"}}}}, {"Where", {"where", {{0, "condition"}, {1, "trueValue"}, {2, "falseValue"}}}}, From a3ad1f9b0a54de9428b5d8d6c662f026110e379f Mon Sep 17 00:00:00 2001 From: Jeff Kilpatrick Date: Tue, 2 Sep 2025 13:29:06 -0700 Subject: [PATCH 20/23] [QNN EP] Remove workaround for syntax error macro (#25923) ### Description Until QAIRT 2.37.0, `QNN_IR_GRAPH_SERIALIZATION_OPTION_INIT` was unusable due to a missing semicolon. Now that it's been fixed, revert the workaround. --- .../core/providers/qnn/builder/qnn_backend_manager.cc | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5bcb8ca394346..211bcbc753140 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -56,12 +56,6 @@ static const char* DlError() { #endif } -// Workaround for a missing comma in QNN_IR_GRAPH_CUSTOM_CONFIG_INIT. -static QnnIrGraph_CustomConfig_t EmptyIrGraphConfig() { - return { - QNN_IR_GRAPH_CONFIG_OPTION_SERIALIZATION, {QNN_IR_GRAPH_SERIALIZATION_TYPE_FLAT_BUFFER, ""}}; -} - class QnnIrConfig : public QnnSerializerConfig { public: QnnIrConfig(std::string backend_path, std::string dlc_dir) @@ -97,7 +91,7 @@ class QnnIrConfig : public QnnSerializerConfig { private: static QnnConfigsBuilder MakeConfigsBuilder() { - return QnnConfigsBuilder(QNN_GRAPH_CONFIG_INIT, EmptyIrGraphConfig()); + return QnnConfigsBuilder(QNN_GRAPH_CONFIG_INIT, QNN_IR_GRAPH_CUSTOM_CONFIG_INIT); } std::filesystem::path dlc_dir_; From 2705d4b4cff900878036695c3faaf132af249a57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 3 Sep 2025 01:11:55 +0200 Subject: [PATCH 21/23] [TRT RTX EP] Add sync method (#25898) This PR adds a missing sync method and fixes the linux CI --- .../nv_tensorrt_rtx/nv_execution_provider.cc | 14 ++++++++++---- .../nv_tensorrt_rtx/nv_execution_provider.h | 1 + 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index 93b673f2df5bd..f827b807f2408 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. #include #include +#include #include #include "core/providers/shared_library/provider_api.h" #include "core/providers/nv_tensorrt_rtx/nv_provider_options.h" @@ -1056,6 +1057,11 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_op_types_to_exclude: " << op_types_to_exclude_; } +Status NvExecutionProvider::Sync() const { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); + return Status::OK(); +} + NvExecutionProvider::~NvExecutionProvider() { // clean up thread local context caches { @@ -1574,8 +1580,8 @@ SubGraphCollection_t NvExecutionProvider::GetSupportedList(SubGraphCollection_t // the initializer was marked as external data by the ORT graph at load time since it was provided in memory size_t size = 0; const void* ptr = nullptr; - c_api.GetTensorSizeInBytes(&initializer_value, &size); - c_api.GetTensorData(&initializer_value, &ptr); + Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); + Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); userWeights.emplace_back(tp->name(), ptr, size); } else if (utils::HasExternalDataInMemory(*tp)) { // only copy and take ownership of the data if none of the above conditions are met @@ -2394,8 +2400,8 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // the initializer was marked as external data by the ORT graph at load time since it was provided in memory size_t size = 0; const void* ptr = nullptr; - c_api.GetTensorSizeInBytes(&initializer_value, &size); - c_api.GetTensorData(&initializer_value, &ptr); + Ort::ThrowOnError(c_api.GetTensorSizeInBytes(&initializer_value, &size)); + Ort::ThrowOnError(c_api.GetTensorData(&initializer_value, &ptr)); userWeights.emplace_back(tp->name(), ptr, size); } else if (utils::HasExternalDataInMemory(*tp)) { // only copy and take ownership of the data if none of the above conditions are met diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 9e5fd03756f02..dc323cd643032 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -285,6 +285,7 @@ class NvExecutionProvider : public IExecutionProvider { IResourceAccountant* /* resource_accountant */) const override; int GetDeviceId() const { return device_id_; } + Status Sync() const; common::Status Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) override; From daa03069d0b5c18770bd766a3c7606994f8108f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 3 Sep 2025 06:24:52 +0200 Subject: [PATCH 22/23] [TRT RTX EP] Memory map the engine buffer (#25909) ### Description Change from fread to mmap to save on system memory. This also accelerated the load time of a ~4GB model in my testing by 1.5X. --- .../nv_tensorrt_rtx/onnx_ctx_model_helper.cc | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index 1f34a0f25877d..c1626fa4f36ad 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -311,13 +311,19 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const Node& node) { ". Please make sure engine cache is in the same directory or sub-directory of context model."); } - std::ifstream engine_file(engine_cache_path.string(), std::ios::binary | std::ios::in); - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size)); + size_t file_length = 0; + auto path_str = ToPathString(engine_cache_path.string()); + + Env::MappedMemoryPtr engine_buf; + const auto& env = GetDefaultEnv(); + ORT_RETURN_IF_ERROR(env.GetFileLength(path_str.c_str(), file_length)); + if (!file_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "Nv EP could not read engine from cache: " + engine_cache_path.string()); + } + ORT_RETURN_IF_ERROR(env.MapFileIntoMemory(path_str.c_str(), 0, file_length, engine_buf)); + + *(trt_engine_) = std::unique_ptr(trt_runtime_->deserializeCudaEngine(engine_buf.get(), file_length)); if (!(*trt_engine_)) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "Nv EP could not deserialize engine from cache: " + engine_cache_path.string()); From 5537d33fff700a8fc76ed261590df3495c5876da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maximilian=20M=C3=BCller?= <44298237+gedoensmax@users.noreply.github.com> Date: Wed, 3 Sep 2025 06:25:12 +0200 Subject: [PATCH 23/23] [TRT RTX EP] Add support for RTX runtime caches (#25917) ### Description Runtime caches can accelerate the JIT time when deserializing an engine of TRT RTX. Here we introduce a per engine caching in a user specified folder. The cache file will be named after the fused node name - which will also be the node name of an ep context node. @chilo-ms we would like to pick this to 1.23 --- .../nv_tensorrt_rtx/nv_provider_options.h | 1 + .../nv_tensorrt_rtx/nv_execution_provider.cc | 73 ++++++++++++++--- .../nv_tensorrt_rtx/nv_execution_provider.h | 33 ++++++-- .../nv_execution_provider_info.cc | 4 +- .../nv_execution_provider_info.h | 2 +- .../providers/nv_tensorrt_rtx/nv_file_utils.h | 52 ++++++++++++ .../nv_tensorrt_rtx/nv_options_test.cc | 82 +++++++++++++++++++ 7 files changed, 230 insertions(+), 17 deletions(-) create mode 100644 onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h create mode 100644 onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc diff --git a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h index a32f465e44adf..026fc3b2dc0a0 100644 --- a/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h +++ b/include/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_options.h @@ -34,6 +34,7 @@ constexpr const char* kProfilesOptShapes = "nv_profile_opt_shapes"; constexpr const char* kCudaGraphEnable = "enable_cuda_graph"; constexpr const char* kMultiProfileEnable = "nv_multi_profile_enable"; constexpr const char* kUseExternalDataInitializer = "nv_use_external_data_initializer"; +constexpr const char* kRuntimeCacheFile = "nv_runtime_cache_path"; } // namespace provider_option_names namespace run_option_names { diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index f827b807f2408..f1f93c34cbf4b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -655,9 +655,9 @@ void NvExecutionProvider::PerThreadContext::ResetTensorRTContext(std::string fus } } -bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, std::unique_ptr context) { +bool NvExecutionProvider::PerThreadContext::UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context) { if (!context) { - context = std::make_unique(); + context = tensorrt_ptr::unique_pointer_exec_ctx(); } trt_context_map_[fused_node] = std::move(context); @@ -758,11 +758,11 @@ bool NvExecutionProvider::PerThreadContext::IsTensorRTContextInMap(std::string f nvinfer1::IExecutionContext& NvExecutionProvider::PerThreadContext::GetTensorRTContext(std::string fused_node) { auto it = trt_context_map_.find(fused_node); if (it != trt_context_map_.end()) { - return *(it->second); // dereference shared pointer + return *(it->second.get()); // dereference shared pointer } - auto context = std::make_unique(); + auto context = tensorrt_ptr::unique_pointer_exec_ctx(); trt_context_map_[fused_node] = std::move(context); - return *(trt_context_map_[fused_node]); // dereference shared pointer + return *(trt_context_map_[fused_node].get()); // dereference shared pointer } void NvExecutionProvider::ReleasePerThreadContext() const { @@ -871,6 +871,20 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) max_shared_mem_size_ = info.max_shared_mem_size; dump_subgraphs_ = info.dump_subgraphs; weight_stripped_engine_enable_ = info.weight_stripped_engine_enable; + // make runtime cache path absolute and create directory if it doesn't exist + if (!info.runtime_cache_path.empty()) { + std::filesystem::path p(info.runtime_cache_path); + std::filesystem::path abs_path = std::filesystem::absolute(p); + const auto& env = GetDefaultEnv(); + auto status = env.CreateFolder(abs_path.string()); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] The runtime cache directory could not be created at: " << abs_path + << ". Runtime cache is disabled."; + } else { + runtime_cache_ = abs_path; + } + } + onnx_model_folder_path_ = info.onnx_model_folder_path; onnx_model_bytestream_ = info.onnx_bytestream; onnx_model_bytestream_size_ = info.onnx_bytestream_size; @@ -1054,7 +1068,8 @@ NvExecutionProvider::NvExecutionProvider(const NvExecutionProviderInfo& info) << ", nv_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ << ", nv_onnx_external_bytestream_size_: " << onnx_external_data_bytestream_size_ << ", nv_use_external_data_initializer_: " << use_external_data_initializer_ - << ", nv_op_types_to_exclude: " << op_types_to_exclude_; + << ", nv_op_types_to_exclude: " << op_types_to_exclude_ + << ", nv_runtime_cache_path: " << runtime_cache_; } Status NvExecutionProvider::Sync() const { @@ -2637,8 +2652,10 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // // Otherwise engine will be handled at inference time. std::unique_ptr trt_engine; - std::unique_ptr trt_context; + tensorrt_ptr::unique_pointer_exec_ctx trt_context; + std::unique_ptr trt_runtime_cache; std::unique_ptr trt_runtime_config; + std::string runtime_cache_file = ""; // Generate file name for dumping ep context model if (dump_ep_context_model_ && ctx_model_path_.empty()) { @@ -2667,6 +2684,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); } trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + if (!runtime_cache_.empty()) { + runtime_cache_file = (runtime_cache_ / fused_node.Name()).string(); + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + auto cache_data = file_utils::ReadFile(runtime_cache_file); + if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) { + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl; + } + if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl; + } + } if (detailed_build_log_) { auto engine_build_stop = std::chrono::steady_clock::now(); @@ -2727,7 +2756,9 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr // Build context // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_context = std::unique_ptr(trt_engine->createExecutionContext(trt_runtime_config.get())); + trt_context = tensorrt_ptr::unique_pointer_exec_ctx( + trt_engine->createExecutionContext(trt_runtime_config.get()), + tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache))); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); @@ -3008,7 +3039,7 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra std::unordered_map& output_map, std::vector& node_compute_funcs) { std::unique_ptr trt_engine; - std::unique_ptr trt_context; + tensorrt_ptr::unique_pointer_exec_ctx trt_context; std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index std::unordered_map output_types; // TRT engine output name -> ORT output tensor type @@ -3030,11 +3061,33 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const Gra return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage()); } + std::unique_ptr trt_runtime_cache; + auto trt_runtime_config = std::unique_ptr(trt_engine->createRuntimeConfig()); + if (trt_runtime_config && cuda_graph_enable_) { + trt_runtime_config->setDynamicShapesKernelSpecializationStrategy(nvinfer1::DynamicShapesKernelSpecializationStrategy::kEAGER); + } + trt_runtime_config->setExecutionContextAllocationStrategy(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED); + std::string runtime_cache_file = ""; + if (!runtime_cache_.empty()) { + runtime_cache_file = (runtime_cache_ / graph_body_viewer.GetNode(node_idx)->Name()).string(); + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + auto cache_data = file_utils::ReadFile(runtime_cache_file); + if (!trt_runtime_cache->deserialize(cache_data.data(), cache_data.size())) { + trt_runtime_cache = std::unique_ptr(trt_runtime_config->createRuntimeCache()); + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to deserialize the runtime cache, will overwrite with new one" << std::endl; + } + if (!trt_runtime_config->setRuntimeCache(*trt_runtime_cache)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX failed to set the runtime cache" << std::endl; + } + } + // Build context // // Note: Creating an execution context from an engine is thread safe per TRT doc // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_context = std::unique_ptr(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED)); + trt_context = tensorrt_ptr::unique_pointer_exec_ctx( + trt_engine->createExecutionContext(trt_runtime_config.get()), + tensorrt_ptr::IExecutionContextDeleter(runtime_cache_file, std::move(trt_runtime_cache))); if (!trt_context) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP could not build execution context for fused node: " + fused_node.Name()); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index dc323cd643032..bb8f687db094f 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -16,6 +16,7 @@ typedef void* cudnnStatus_t; #include #include "core/providers/cuda/cuda_graph.h" #include "nv_execution_provider_info.h" +#include "core/providers/nv_tensorrt_rtx/nv_file_utils.h" namespace onnxruntime { @@ -58,6 +59,26 @@ class TensorrtLogger : public nvinfer1::ILogger { }; namespace tensorrt_ptr { +/* + * custom deleter that will dump the optimized runtime cache when the execution context is destructed + */ +struct IExecutionContextDeleter { + IExecutionContextDeleter() = default; + IExecutionContextDeleter(const std::string& runtime_cache_path, std::unique_ptr&& runtime_cache) : runtime_cache_path_(runtime_cache_path), runtime_cache_(std::move(runtime_cache)) {}; + void operator()(nvinfer1::IExecutionContext* context) { + if (context != nullptr) { + if (!runtime_cache_path_.empty()) { + auto serialized_cache_data = std::unique_ptr(runtime_cache_->serialize()); + file_utils::WriteFile(runtime_cache_path_, serialized_cache_data->data(), serialized_cache_data->size()); + } + delete context; + } + } + + private: + std::string runtime_cache_path_; + std::unique_ptr runtime_cache_; +}; struct TensorrtInferDeleter { template @@ -70,6 +91,7 @@ struct TensorrtInferDeleter { template using unique_pointer = std::unique_ptr; +using unique_pointer_exec_ctx = std::unique_ptr; }; // namespace tensorrt_ptr // @@ -196,7 +218,7 @@ struct TensorrtFuncState { std::string fused_node_name; nvinfer1::IBuilder* builder; std::unique_ptr* engine = nullptr; - std::unique_ptr* context = nullptr; + tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; std::vector> output_info; @@ -233,7 +255,7 @@ struct TensorrtShortFuncState { AllocatorHandle allocator = nullptr; std::string fused_node_name; std::unique_ptr* engine = nullptr; - std::unique_ptr* context = nullptr; + tensorrt_ptr::unique_pointer_exec_ctx* context = nullptr; std::vector> input_info; std::vector> output_info; std::mutex* tensorrt_mu_ptr = nullptr; @@ -357,6 +379,7 @@ class NvExecutionProvider : public IExecutionProvider { bool detailed_build_log_ = false; bool cuda_graph_enable_ = false; bool multi_profile_enable_ = false; + std::filesystem::path runtime_cache_; std::string cache_prefix_; std::string op_types_to_exclude_; int nv_profile_index_ = 0; @@ -387,7 +410,7 @@ class NvExecutionProvider : public IExecutionProvider { // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. std::unordered_map> engines_; - std::unordered_map> contexts_; + std::unordered_map contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -425,7 +448,7 @@ class NvExecutionProvider : public IExecutionProvider { bool IsTensorRTContextInMap(std::string fused_node); nvinfer1::IExecutionContext& GetTensorRTContext(std::string fused_node); - bool UpdateTensorRTContext(std::string fused_node, std::unique_ptr context); + bool UpdateTensorRTContext(std::string fused_node, tensorrt_ptr::unique_pointer_exec_ctx context); void ResetTensorRTContext(std::string fused_node); // CUDA Graph management @@ -455,7 +478,7 @@ class NvExecutionProvider : public IExecutionProvider { // See more details here: // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading // https://docs.nvidia.com/deeplearning/tensorrt/api/c_api/classnvinfer1_1_1_i_execution_context.html#a63cd95430852038ce864e17c670e0b36 - std::unordered_map> trt_context_map_; + std::unordered_map trt_context_map_; // The profile shape ranges for the engine that the execution context maintained by the PerThreadContext is built with. // TRT EP needs this info to determine whether to rebuild the execution context. diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc index 527a37f6c2b57..f25718114891b 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.cc @@ -51,6 +51,7 @@ NvExecutionProviderInfo NvExecutionProviderInfo::FromProviderOptions(const Provi .AddAssignmentToReference(nv::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable) .AddAssignmentToReference(nv::provider_option_names::kUseExternalDataInitializer, info.use_external_data_initializer) .AddAssignmentToReference(nv::provider_option_names::kMultiProfileEnable, info.multi_profile_enable) + .AddAssignmentToReference(nv::provider_option_names::kRuntimeCacheFile, info.runtime_cache_path) .Parse(options)); // add new provider option here. info.user_compute_stream = user_compute_stream; @@ -105,7 +106,8 @@ ProviderOptions NvExecutionProviderInfo::ToProviderOptions(const NvExecutionProv {nv::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)}, {nv::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)}, {nv::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)}, - {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}}; + {nv::provider_option_names::kUseExternalDataInitializer, MakeStringWithClassicLocale(info.use_external_data_initializer)}, + {nv::provider_option_names::kRuntimeCacheFile, MakeStringWithClassicLocale(info.runtime_cache_path)}}; return options; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h index b826925361b05..372e8196f38c2 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_info.h @@ -37,7 +37,7 @@ struct NvExecutionProviderInfo { bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; - std::string timing_cache_path{""}; + std::string runtime_cache_path{""}; bool detailed_build_log{false}; bool sparsity_enable{false}; int auxiliary_streams{-1}; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h new file mode 100644 index 0000000000000..159aba0507ffb --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_file_utils.h @@ -0,0 +1,52 @@ +#pragma once +#include +#include +#include +#include +#include +#include "core/providers/shared_library/provider_api.h" + +namespace onnxruntime { +namespace file_utils { + +inline std::vector ReadFile(const std::string& path) { + if (!std::filesystem::exists(path)) { + LOGS_DEFAULT(INFO) << "TensorRT RTX could not find the file and will create a new one " << path << std::endl; + return {}; + } + std::ifstream file(path, std::ios::in | std::ios::binary); + if (!file) { + ORT_THROW("Failed to open file: " + path); + } + file.seekg(0, std::ios::end); + std::streamsize size = file.tellg(); + file.seekg(0, std::ios::beg); + std::vector buffer(size); + if (size > 0 && !file.read(buffer.data(), size)) { + ORT_THROW("Failed to read file: " + path); + } + return buffer; +} + +inline void WriteFile(const std::string& path, const void* data, size_t size) { + if (std::filesystem::exists(path)) { + std::ofstream file(path, std::ios::out | std::ios::binary | std::ios::trunc); + if (!file) { + ORT_THROW("Failed to open file for writing: " + path); + } + file.write(static_cast(data), size); + } else { + LOGS_DEFAULT(INFO) << "TensorRT RTX a new file cache was written to " << path << std::endl; + // Create new file + std::ofstream file(path, std::ios::out | std::ios::binary); + if (!file) { + ORT_THROW("Failed to create file: " + path); + } + file.write(static_cast(data), size); + } +} + +inline void WriteFile(const std::string& path, const std::vector& data) { WriteFile(path, data.data(), data.size()); } + +} // namespace file_utils +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc new file mode 100644 index 0000000000000..d415548876153 --- /dev/null +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_options_test.cc @@ -0,0 +1,82 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" +#include "test/providers/provider_test_utils.h" +#include "test/framework/test_utils.h" + +#include "test/util/include/scoped_env_vars.h" +#include "test/common/trt_op_test_utils.h" +#include "test/common/random_generator.h" +#include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h" + +#include +#include + +using namespace std; +using namespace ONNX_NAMESPACE; +using namespace ::onnxruntime::logging; +extern std::unique_ptr ort_env; +namespace onnxruntime { + +namespace test { +size_t countFilesInDirectory(const std::string& dir_path) { + return std::distance(std::filesystem::directory_iterator(dir_path), std::filesystem::directory_iterator{}); +} + +TEST(NvExecutionProviderTest, RuntimeCaching) { + PathString model_name = ORT_TSTR("nv_execution_provider_runtime_caching.onnx"); + PathString model_name_ctx = ORT_TSTR("nv_execution_provider_runtime_caching_ctx.onnx"); + auto model_name_ctx_str = PathToUTF8(model_name_ctx); + clearFileIfExists(model_name_ctx); + std::string graph_name = "test"; + std::vector dims = {1, 3, 2}; + std::string runtime_cache_name = "./runtime_cache/"; + if (std::filesystem::exists(runtime_cache_name)) { + std::filesystem::remove_all(runtime_cache_name); + } + CreateBaseModel(model_name, graph_name, dims); + // AOT time + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, model_name_ctx_str.c_str()); + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", runtime_cache_name.c_str()}}); + Ort::Session session_object(*ort_env, model_name.c_str(), so); + + auto io_binding = generate_io_binding(session_object); + session_object.Run(run_options, io_binding); + } + // the cache will be dumped to disk upon session destruction + ASSERT_TRUE(std::filesystem::exists(runtime_cache_name.c_str())); + ASSERT_TRUE(1 == countFilesInDirectory(runtime_cache_name)); + + // use existing cache + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", runtime_cache_name.c_str()}}); + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); + } + ASSERT_TRUE(1 == countFilesInDirectory(runtime_cache_name)); + + // create new cache + { + Ort::SessionOptions so; + Ort::RunOptions run_options; + std::string new_cache_name = "/tmp/runtime_cache_new/"; + if (std::filesystem::exists(new_cache_name)) { + std::filesystem::remove_all(new_cache_name); + } + so.AppendExecutionProvider(kNvTensorRTRTXExecutionProvider, {{"nv_runtime_cache_path", new_cache_name.c_str()}}); + { + Ort::Session session_object(*ort_env, model_name_ctx.c_str(), so); + } + // the cache will be dumped to disk upon session destruction + ASSERT_TRUE(std::filesystem::exists(new_cache_name.c_str())); + ASSERT_TRUE(1 == countFilesInDirectory(new_cache_name)); + } +} +} // namespace test +} // namespace onnxruntime