diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index a5aebb1fbf47c..c1a89ee778a9c 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -122,6 +122,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") message(FATAL_ERROR "WebGPU EP shared library build is not supported on Emscripten. Please use static library build.") endif() + + # Configure precompiled headers for shared library build + # PCH ensures ep/_pch.h is included first and improves compilation speed + target_precompile_headers(onnxruntime_providers_webgpu PRIVATE + "${REPO_ROOT}/include/onnxruntime/ep/_pch.h" + ) endif() set_target_properties(onnxruntime_providers_webgpu PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 4f245f8b86711..f8878f814672a 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1023,6 +1023,61 @@ endif() partition_provider_test_srcs(all_tests onnxruntime_provider_test_srcs onnxruntime_test_all_srcs) +# Shared settings for onnxruntime test targets. +function(onnxruntime_apply_common_test_target_settings target) + if (UNIX AND (onnxruntime_USE_TENSORRT OR onnxruntime_USE_NV)) + # The test_main.cc includes NvInfer.h where it has many deprecated declarations + # simply ignore them for TensorRT EP build + set_property(TARGET ${target} APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + endif() + + if (MSVC) + # TODO: The test code for OpenVINO, QNN, and WebGPU is getting flagged with a warning from ABSL for unreachable code. + # Need to figure out how those particular targets/build variants are failing, but regular windows is not. + target_compile_options(${target} PRIVATE "/wd4702") + endif() + + # TODO fix shorten-64-to-32 warnings + # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' + if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) + target_compile_options(${target} PRIVATE -Wno-error=shorten-64-to-32) + endif() +endfunction() + +# Set environment variables for plugin EP tests when run via CTest. 
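+# For example, in the WebGPU plugin EP case handled below, the test process ends up with an environment
+# variable along these lines (illustrative; the exact JSON is assembled inside the function):
+#   ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON={"ep_library_registration_name": "WebGPU_PluginEP", "ep_library_path": "onnxruntime_providers_webgpu.dll", "selected_ep_name": "WebGpuExecutionProvider"}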
+function(onnxruntime_set_plugin_ep_test_environment target) + if(onnxruntime_USE_WEBGPU AND NOT onnxruntime_BUILD_WEBGPU_EP_STATIC_LIB) + set(ORT_PLUGIN_EP_JSON_CONFIG "{\"ep_library_registration_name\": \"WebGPU_PluginEP\", \"ep_library_path\": \"onnxruntime_providers_webgpu.dll\", \"selected_ep_name\": \"WebGpuExecutionProvider\"}") + set_tests_properties(${target} PROPERTIES + ENVIRONMENT "ORT_UNIT_TEST_MAIN_DYNAMIC_PLUGIN_EP_CONFIG_JSON=${ORT_PLUGIN_EP_JSON_CONFIG}" + ) + # TODO: add for other plugin EPs if needed + # elseif() + endif() +endfunction() + +function(onnxruntime_apply_emscripten_test_link_settings target) + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + set_target_properties(${target} PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) + set_target_properties(${target} PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) + set_target_properties(${target} PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js\" --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") + endif() + if (onnxruntime_USE_JSEP) + set_target_properties(${target} PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS " --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"") + endif() + + ### + ### if you want to investigate or debug a test failure in ${target}, replace the following line. + ### those flags slow down the CI test significantly, so we don't use them by default. + ### + # set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=2 -s SAFE_HEAP=1 -s STACK_OVERFLOW_CHECK=2") + set_property(TARGET ${target} APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s SAFE_HEAP=0 -s STACK_OVERFLOW_CHECK=1") + endif() +endfunction() + list(APPEND onnxruntime_test_all_srcs ${onnxruntime_unittest_main_src}) AddTest( TARGET onnxruntime_test_all @@ -1035,6 +1090,9 @@ AddTest( ) target_include_directories(onnxruntime_test_all PRIVATE ${ONNXRUNTIME_ROOT}/core/flatbuffers/schema) # ort.fbs.h +onnxruntime_apply_common_test_target_settings(onnxruntime_test_all) +onnxruntime_set_plugin_ep_test_environment(onnxruntime_test_all) + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. @@ -1044,10 +1102,6 @@ if (MSVC) target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd4244>" "$<$>:/wd4244>") - # TODO: The test code for OpenVINO, QNN, and WebGPU is getting flagged with a warning from ABSL for unreachabel code. - # Need to figure out how those particular targets/build variants are failing, but regular windows is not. 
- target_compile_options(onnxruntime_test_all PRIVATE "/wd4702") - # Avoid this compile error in graph_transform_test.cc and qdq_transformer_test.cc: # fatal error C1128: number of sections exceeded object file format limit: compile with /bigobj set_property(SOURCE "${TEST_SRC_DIR}/optimizer/graph_transform_test.cc" @@ -1057,18 +1111,6 @@ else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() -# TODO fix shorten-64-to-32 warnings -# there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' -if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - target_compile_options(onnxruntime_test_all PRIVATE -Wno-error=shorten-64-to-32) -endif() - -if (UNIX AND (onnxruntime_USE_TENSORRT OR onnxruntime_USE_NV)) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations - # simply ignore them for TensorRT EP build - set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") -endif() - if (MSVC AND onnxruntime_ENABLE_STATIC_ANALYSIS) # attention_op_test.cc: Function uses '49152' bytes of stack: exceeds /analyze:stacksize '16384'.. target_compile_options(onnxruntime_test_all PRIVATE "/analyze:stacksize 131072") @@ -1099,25 +1141,7 @@ endif() if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_link_libraries(onnxruntime_test_all PRIVATE Python::Python) endif() -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js\" --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") - endif() - if (onnxruntime_USE_JSEP) - set_target_properties(onnxruntime_test_all PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) - set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"") - endif() - - ### - ### if you want to investigate or debug a test failure in onnxruntime_test_all, replace the following line. - ### those flags slow down the CI test significantly, so we don't use them by default. 
- ### - # set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=2 -s SAFE_HEAP=1 -s STACK_OVERFLOW_CHECK=2") - set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s SAFE_HEAP=0 -s STACK_OVERFLOW_CHECK=1") -endif() +onnxruntime_apply_emscripten_test_link_settings(onnxruntime_test_all) if (onnxruntime_ENABLE_ATEN) target_compile_definitions(onnxruntime_test_all PRIVATE ENABLE_ATEN) @@ -1233,6 +1257,9 @@ block() DEPENDS ${onnxruntime_provider_test_deps} ) + onnxruntime_apply_common_test_target_settings(onnxruntime_provider_test) + onnxruntime_set_plugin_ep_test_environment(onnxruntime_provider_test) + # Expose QNN SDK headers to unit tests via an interface target if(onnxruntime_USE_QNN) add_library(qnn_sdk_headers_include INTERFACE) @@ -1242,49 +1269,9 @@ block() target_link_libraries(onnxruntime_provider_test PRIVATE qnn_sdk_headers_include) endif() - if (UNIX AND (onnxruntime_USE_TENSORRT OR onnxruntime_USE_NV)) - # The test_main.cc includes NvInfer.h where it has many deprecated declarations - # simply ignore them for TensorRT EP build - set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") - endif() - # enable dynamic plugin EP usage target_compile_definitions(onnxruntime_provider_test PRIVATE ORT_UNIT_TEST_ENABLE_DYNAMIC_PLUGIN_EP_USAGE) - - - if (MSVC) - # TODO: The test code for OpenVINO, QNN, and WebGPU is getting flagged with a warning from ABSL for unreachabel code. - # Need to figure out how those particular targets/build variants are failing, but regular windows is not. - target_compile_options(onnxruntime_provider_test PRIVATE "/wd4702") - endif() - - # TODO fix shorten-64-to-32 warnings - # there are some in builds where sizeof(size_t) != sizeof(int64_t), e.g., in 'ONNX Runtime Web CI Pipeline' - if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) - target_compile_options(onnxruntime_provider_test PRIVATE -Wno-error=shorten-64-to-32) - endif() - - # copied from onnxruntime_test_all - # TODO reuse instead of copy? - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - set_target_properties(onnxruntime_provider_test PROPERTIES LINK_DEPENDS ${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js) - set_target_properties(onnxruntime_provider_test PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) - set_target_properties(onnxruntime_provider_test PROPERTIES LINK_FLAGS "-s STACK_SIZE=5242880 -s INITIAL_MEMORY=536870912 -s ALLOW_MEMORY_GROWTH=1 -s MAXIMUM_MEMORY=4294967296 -s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm] --pre-js \"${TEST_SRC_DIR}/wasm/onnxruntime_test_adapter.js\" --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\" -s \"EXPORTED_RUNTIME_METHODS=['FS']\" --preload-file ${CMAKE_CURRENT_BINARY_DIR}/testdata@/testdata -s EXIT_RUNTIME=1") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " -s DEFAULT_PTHREAD_STACK_SIZE=131072 -s PROXY_TO_PTHREAD=1") - endif() - if (onnxruntime_USE_JSEP) - set_target_properties(onnxruntime_provider_test PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) - set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " --pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"") - endif() - - ### - ### if you want to investigate or debug a test failure in onnxruntime_provider_test, replace the following line. 
- ### those flags slow down the CI test significantly, so we don't use them by default. - ### - # set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=2 -s SAFE_HEAP=1 -s STACK_OVERFLOW_CHECK=2") - set_property(TARGET onnxruntime_provider_test APPEND_STRING PROPERTY LINK_FLAGS " -s ASSERTIONS=0 -s SAFE_HEAP=0 -s STACK_OVERFLOW_CHECK=1") - endif() + onnxruntime_apply_emscripten_test_link_settings(onnxruntime_provider_test) if (IOS) add_custom_command( diff --git a/include/onnxruntime/ep/README.md b/include/onnxruntime/ep/README.md new file mode 100644 index 0000000000000..64d85f80313c0 --- /dev/null +++ b/include/onnxruntime/ep/README.md @@ -0,0 +1,7 @@ +## EP adapter + +This folder contains a set of C++ header files. They are used specifically for allowing ONNX Runtime internal kernel-based EPs to use the plugin-style EP API while keep minimal changes to existing code. + +### Usage + +Make sure to include "ep/_pch.h" for all source code in the implementation. Using PCH is recommended. diff --git a/include/onnxruntime/ep/_pch.h b/include/onnxruntime/ep/_pch.h new file mode 100644 index 0000000000000..ba9c3278693eb --- /dev/null +++ b/include/onnxruntime/ep/_pch.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "api.h" +#include "common.h" + +// This header is only used when building WebGPU/CUDA EP as a shared library. +// +// This header file is used as a precompiled header so it is always included first. + +#pragma push_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED") +#define ORT_EP_API_ADAPTER_HEADER_INCLUDED + +#include "adapter/allocator.h" +#include "adapter/logging.h" +#include "adapter/ep.h" +#include "adapter/kernel_registry.h" + +#pragma pop_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED") + +// +// EP specific using declarations +// + +#define EP_SPECIFIC_USING_DECLARATIONS \ + using FuncManager = onnxruntime::ep::adapter::FuncManager; \ + using KernelCreatePtrFn = onnxruntime::ep::adapter::KernelCreatePtrFn; \ + using KernelDefBuilder = onnxruntime::ep::adapter::KernelDefBuilder; \ + using KernelRegistry = onnxruntime::ep::adapter::KernelRegistry; \ + using KernelCreateInfo = onnxruntime::ep::adapter::KernelCreateInfo; \ + using BuildKernelCreateInfoFn = onnxruntime::ep::adapter::KernelCreateInfo (*)(); \ + using OpKernelInfo = onnxruntime::ep::adapter::OpKernelInfo; \ + using OpKernelContext = onnxruntime::ep::adapter::OpKernelContext; \ + using OpKernel = onnxruntime::ep::adapter::OpKernel; \ + using DataTransferManager = onnxruntime::ep::adapter::DataTransferManager; \ + namespace logging { \ + using Logger = onnxruntime::ep::adapter::Logger; \ + } + +namespace onnxruntime { +namespace webgpu { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace webgpu +namespace cuda { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace cuda + +#ifndef DISABLE_CONTRIB_OPS +namespace contrib { +namespace webgpu { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace webgpu +namespace cuda { +EP_SPECIFIC_USING_DECLARATIONS +} // namespace cuda +} // namespace contrib +#endif + +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/allocator.h b/include/onnxruntime/ep/adapter/allocator.h new file mode 100644 index 0000000000000..a9bbb6071c48b --- /dev/null +++ b/include/onnxruntime/ep/adapter/allocator.h @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+
+#include "core/framework/allocator.h"
+
+namespace onnxruntime {
+namespace ep {
+namespace adapter {
+
+/// <summary>
+/// A bridge class between the EP API OrtAllocator and an IAllocator implementation.
+/// </summary>
+class Allocator : public OrtAllocator {
+ public:
+  explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl)
+      : OrtAllocator{}, memory_info_(memory_info), impl_(impl) {
+    version = ORT_API_VERSION;
+    Alloc = AllocImpl;
+    Free = FreeImpl;
+    Info = InfoImpl;
+  }
+
+ private:
+  static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
+    auto* allocator = static_cast<Allocator*>(this_ptr);
+    return allocator->impl_->Alloc(size);
+  }
+
+  static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
+    auto* allocator = static_cast<Allocator*>(this_ptr);
+    allocator->impl_->Free(p);
+  }
+
+  static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
+    auto* allocator = static_cast<const Allocator*>(this_ptr);
+    return allocator->memory_info_;
+  }
+
+  const OrtMemoryInfo* memory_info_;
+  AllocatorPtr impl_;
+};
+
+}  // namespace adapter
+}  // namespace ep
+}  // namespace onnxruntime
diff --git a/include/onnxruntime/ep/adapter/data_transfer_manager.h b/include/onnxruntime/ep/adapter/data_transfer_manager.h
new file mode 100644
index 0000000000000..7b98a440c7050
--- /dev/null
+++ b/include/onnxruntime/ep/adapter/data_transfer_manager.h
@@ -0,0 +1,54 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
+#error "This header should not be included directly. Include ep/_pch.h instead."
+#endif + +#include "data_transfer_manager.h" + +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// Wrapper around IExecutionProvider to expose via OrtEp. +/// +class Ep : public OrtEp { + protected: + explicit Ep(IExecutionProvider* impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator) + : OrtEp{}, + impl_(impl), + data_transfer_manager_{impl->GetDataTransfer()}, + profiler_{impl->GetProfiler()}, + temp_space_cpu_allocator_{temp_space_cpu_allocator}, + temp_space_allocator_{temp_space_allocator} { + } + + public: + inline IExecutionProvider* EpImpl() const noexcept { + return impl_.get(); + } + inline const DataTransferManager& GetDataTransferManager() const noexcept { + return data_transfer_manager_; + } + [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { + *output = temp_space_cpu_allocator_; + return Status::OK(); + } + [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { + *output = temp_space_allocator_; + return Status::OK(); + } + + private: + std::unique_ptr impl_; + DataTransferManager data_transfer_manager_; + std::unique_ptr profiler_; + AllocatorPtr temp_space_cpu_allocator_; + AllocatorPtr temp_space_allocator_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/kernel_def.h b/include/onnxruntime/ep/adapter/kernel_def.h new file mode 100644 index 0000000000000..b3d3c83dd0e90 --- /dev/null +++ b/include/onnxruntime/ep/adapter/kernel_def.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelDef`. +/// +class KernelDef { + public: + explicit KernelDef(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {} + + const std::string OpName() const { + return kernel_info_.GetNodeName(); + } + + const std::string Domain() const { + return kernel_info_.GetOperatorDomain(); + } + + private: + const Ort::ConstKernelInfo kernel_info_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/kernel_def_builder.h b/include/onnxruntime/ep/adapter/kernel_def_builder.h new file mode 100644 index 0000000000000..664c88919cb8a --- /dev/null +++ b/include/onnxruntime/ep/adapter/kernel_def_builder.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +#include "core/framework/data_types.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// Gets an OrtMLDataType for a tensor type. Throws on error. 
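+/// Illustrative usage (a sketch): the returned pointer can be passed to
+/// Ort::KernelDefBuilder::AddTypeConstraint, e.g. GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
+/// for a float32 tensor type constraint.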
+/// +/// +/// +inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) { + const OrtEpApi& ep_api = Ort::GetEpApi(); + const OrtDataType* result = nullptr; + + Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result)); + return result; +} + +inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) { + auto tensor_type = ml_type->AsTensorType(); + EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types."); + auto elem_type = tensor_type->GetElementType(); + auto primitive_type = static_cast(elem_type); + auto onnx_type = static_cast(primitive_type->GetDataType()); + return GetTensorType(onnx_type); +} + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelDefBuilder`. +/// +struct KernelDefBuilder { + static std::unique_ptr Create() { return std::make_unique(); } + + explicit KernelDefBuilder() {} + + KernelDefBuilder& SetName(const char* op_name) { + builder_.SetOperatorType(op_name); + return *this; + } + + KernelDefBuilder& SetDomain(const char* domain) { + builder_.SetDomain(domain); + return *this; + } + + KernelDefBuilder& SinceVersion(int since_version) { + return SinceVersion(since_version, INT_MAX); + } + + KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) { + builder_.SetSinceVersion(since_version_start, since_version_end); + return *this; + } + + KernelDefBuilder& Provider(const char* provider_type) { + builder_.SetExecutionProvider(provider_type); + return *this; + } + + KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector types) { + std::vector ort_types; + ort_types.reserve(types.size()); + for (const auto& type : types) { + ort_types.push_back(MLDataTypeToOrtDataType(type)); + } + builder_.AddTypeConstraint(arg_name, ort_types); + return *this; + } + + KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) { + builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type)); + return *this; + } + + KernelDefBuilder& MayInplace(const std::vector>& inplaces) { + for (const auto& pair : inplaces) { + builder_.AddInputOutputMutableAlias(pair.first, pair.second); + } + return *this; + } + KernelDefBuilder& MayInplace(int input_index, int output_index) { + builder_.AddInputOutputMutableAlias(input_index, output_index); + return *this; + } + + KernelDefBuilder& Alias(const std::vector>& aliases) { + for (const auto& pair : aliases) { + builder_.AddInputOutputAlias(pair.first, pair.second); + } + return *this; + } + KernelDefBuilder& Alias(int input_index, int output_index) { + builder_.AddInputOutputAlias(input_index, output_index); + return *this; + } + + KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) { + builder_.SetInputMemType(input_index, type); + return *this; + } + + KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector& input_indexes) { + for (int input_index : input_indexes) { + builder_.SetInputMemType(input_index, type); + } + return *this; + } + + KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) { + builder_.SetOutputMemType(output_index, type); + return *this; + } + + KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector& output_indexes) { + for (int output_index : output_indexes) { + builder_.SetOutputMemType(output_index, type); + } + return *this; + } + + KernelDefBuilder& ExecQueueId(int queue_id) { return *this; } + + Ort::KernelDef Build() { return builder_.Build(); } + + private: + Ort::KernelDefBuilder builder_; +}; + +} // 
namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/kernel_registry.h b/include/onnxruntime/ep/adapter/kernel_registry.h new file mode 100644 index 0000000000000..01474fa0cb3ae --- /dev/null +++ b/include/onnxruntime/ep/adapter/kernel_registry.h @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +#include "kernel_def_builder.h" +#include "op_kernel_info.h" +#include "op_kernel.h" + +#include "core/graph/basic_types.h" +#include "core/framework/error_code_helper.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +struct FuncManager {}; +using KernelCreatePtrFn = std::add_pointer& out)>::type; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelCreateInfo`. +/// +struct KernelCreateInfo { + Ort::KernelDef kernel_def; + KernelCreatePtrFn kernel_create_func; + Status status; + + KernelCreateInfo(Ort::KernelDef definition, + KernelCreatePtrFn create_func) + : kernel_def(std::move(definition)), + kernel_create_func(create_func) { + assert(kernel_def != nullptr); + } + + KernelCreateInfo(KernelCreateInfo&& other) noexcept + : kernel_def(std::move(other.kernel_def)), + kernel_create_func(std::move(other.kernel_create_func)) {} + + KernelCreateInfo() = default; +}; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::KernelRegistry`. +/// +struct KernelRegistry { + KernelRegistry() = default; + + static OrtStatus* CreateKernel(void* kernel_create_func_state, const OrtKernelInfo* info, OrtKernelImpl** out) { + FuncManager func_mgr; // not used + std::unique_ptr kernel; + KernelCreatePtrFn create_func = reinterpret_cast(kernel_create_func_state); + Status status = create_func(func_mgr, OpKernelInfo(info), kernel); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + *out = nullptr; + status = kernel->CreateControlFlowKernelImpl(info, out); + if (!status.IsOK()) { + return ToOrtStatus(status); + } + if (*out == nullptr) { + *out = new KernelImpl(std::move(kernel)); + } + return nullptr; + } + + Status Register(KernelCreateInfo&& create_info) { + registry_.AddKernel(create_info.kernel_def, + KernelRegistry::CreateKernel, + static_cast(create_info.kernel_create_func)); + return Status::OK(); + } + + // Implicit conversion to OrtKernelRegistry* for compatibility with C API + operator OrtKernelRegistry*() const noexcept { + return registry_.operator OrtKernelRegistry*(); + } + + // Release ownership of the underlying OrtKernelRegistry* + OrtKernelRegistry* release() { + return registry_.release(); + } + + private: + Ort::KernelRegistry registry_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/logging.h b/include/onnxruntime/ep/adapter/logging.h new file mode 100644 index 0000000000000..b93c06bb3f12e --- /dev/null +++ b/include/onnxruntime/ep/adapter/logging.h @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." 
+#endif + +#include "core/common/logging/logging.h" +#include "core/common/path_string.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +struct Logger { + Logger(const OrtLogger* logger) : logger_(logger) {} + + bool OutputIsEnabled(logging::Severity severity, logging::DataType /* data_type */) const noexcept { + return ((OrtLoggingLevel)severity >= logger_.GetLoggingSeverityLevel()); + } + + void Log(logging::Severity severity, + const char* file_path, + int line_number, + const char* func_name, + const char* message) const noexcept { + auto path_string = onnxruntime::ToPathString(file_path); + logger_.LogMessage((OrtLoggingLevel)severity, + path_string.c_str(), + line_number, + func_name, + message); + } + + static const Logger& DefaultLogger() { return *instance_; } + static void CreateDefaultLogger(const OrtLogger* logger) { + instance_ = new Logger(logger); + } + static void DestroyDefaultLogger() { + delete instance_; + instance_ = nullptr; + } + + private: + Ort::Logger logger_; + inline static Logger* instance_ = nullptr; +}; + +namespace detail { +struct LoggerCapture { + LoggerCapture(const Logger& logger, + logging::Severity severity, + const char* category, + logging::DataType dataType, + const CodeLocation& location) : logger_{logger}, + severity_{severity}, + category_{category}, + data_type_{dataType}, + location_{location} {} + + ~LoggerCapture() { + logger_.Log(severity_, location_.file_and_path.c_str(), location_.line_num, + location_.function.c_str(), stream_.str().c_str()); + } + + std::ostream& Stream() noexcept { + return stream_; + } + + const Logger& logger_; + logging::Severity severity_; + const char* category_; + logging::DataType data_type_; + const CodeLocation& location_; + std::ostringstream stream_; +}; + +// Helper functions to dispatch to the correct Capture type based on logger type +inline ::onnxruntime::logging::Capture CreateMessageCapture( + const ::onnxruntime::logging::Logger& logger, + ::onnxruntime::logging::Severity severity, + const char* category, + ::onnxruntime::logging::DataType datatype, + const CodeLocation& location) { + return ::onnxruntime::logging::Capture(logger, severity, category, datatype, location); +} + +inline detail::LoggerCapture CreateMessageCapture( + const Logger& logger, + ::onnxruntime::logging::Severity severity, + const char* category, + ::onnxruntime::logging::DataType datatype, + const CodeLocation& location) { + return detail::LoggerCapture(logger, severity, category, datatype, location); +} + +} // namespace detail +} // namespace adapter +} // namespace ep +} // namespace onnxruntime + +// Undefine and redefine LOGS_DEFAULT +#undef LOGS_DEFAULT_CATEGORY +#define LOGS_DEFAULT_CATEGORY(severity, category) \ + LOGS_CATEGORY(::onnxruntime::ep::adapter::Logger::DefaultLogger(), severity, category) + +#undef CREATE_MESSAGE +#define CREATE_MESSAGE(logger, severity, category, datatype) \ + ::onnxruntime::ep::adapter::detail::CreateMessageCapture(logger, ::onnxruntime::logging::Severity::k##severity, category, datatype, ORT_WHERE) diff --git a/include/onnxruntime/ep/adapter/node.h b/include/onnxruntime/ep/adapter/node.h new file mode 100644 index 0000000000000..b46cc1ebe64d4 --- /dev/null +++ b/include/onnxruntime/ep/adapter/node.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." 
+#endif + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::Node`. +/// +struct Node { + explicit Node(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {} + /** Gets the Node's name. */ + const std::string Name() const noexcept { + return kernel_info_.GetNodeName(); + } + + /** Gets the Node's operator type. */ + const std::string OpType() const noexcept { + return kernel_info_.GetOperatorType(); + } + + /** Gets the since version of the operator. */ + int SinceVersion() const noexcept { + return kernel_info_.GetOperatorSinceVersion(); + } + + private: + const Ort::ConstKernelInfo kernel_info_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/op_kernel.h b/include/onnxruntime/ep/adapter/op_kernel.h new file mode 100644 index 0000000000000..63c3cf428d303 --- /dev/null +++ b/include/onnxruntime/ep/adapter/op_kernel.h @@ -0,0 +1,201 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include +#include + +#include "core/framework/allocator.h" +#include "core/framework/tensor.h" + +#include "node.h" +#include "op_kernel_info.h" +#include "tensor_helper.h" + +namespace onnxruntime { +struct PrePackedWeights; +struct TensorShape; +} // namespace onnxruntime + +namespace onnxruntime { +namespace ep { +namespace adapter { + +struct OpKernelContext; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::OpKernel`. +/// +struct OpKernel { + explicit OpKernel(const OpKernelInfo& info) : op_kernel_info_{info} {} + virtual ~OpKernel() {} + + Node Node() const { + return op_kernel_info_.node(); + } + const OpKernelInfo& Info() const { + return op_kernel_info_; + } + + virtual Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + return Status::OK(); + } + + virtual Status Compute(OpKernelContext* p_op_kernel_context) const = 0; + virtual Status PrePack(const Tensor& tensor, + int input_idx, + AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + return Status::OK(); + } + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernel); + OpKernelInfo op_kernel_info_; +}; + +/// +/// An adapter class partially implementing the facade of `onnxruntime::OpKernelContext`. 
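+///
+/// Illustrative use inside an adapted kernel's Compute (a sketch; assumes a Tensor-typed input and output):
+///   const Tensor* X = context->Input<Tensor>(0);   // lazily wraps the input OrtValue as an unowned Tensor
+///   Tensor* Y = context->Output(0, X->Shape());    // obtains the output via the kernel context and wraps it the same way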
+/// +struct OpKernelContext { + explicit OpKernelContext(OrtKernelContext* context, const OpKernel& op_kernel) : context_{context}, op_kernel_{op_kernel} { + input_tensors_.resize(InputCount()); + output_tensors_.resize(OutputCount()); + } + + template >> + const T* Input(int index) const { + if (index < 0 || static_cast(index) >= input_tensors_.size()) { + return nullptr; + } + if (input_tensors_[index] != nullptr) { + return static_cast(input_tensors_[index].get()); + } + + auto input = context_.GetInput(index); + if (input == nullptr || !input.IsTensor()) { + return nullptr; + } + + input_tensors_[index] = CreateTensorFromApiValue(input); + return static_cast(input_tensors_[index].get()); + } + Tensor* Output(int index, const TensorShape& shape) { + if (index < 0 || static_cast(index) >= output_tensors_.size()) { + return nullptr; + } + if (output_tensors_[index] != nullptr) { + return output_tensors_[index].get(); + } + + auto output = context_.GetOutput(index, shape.GetDims().data(), shape.GetDims().size()); + if (output == nullptr) { + return nullptr; + } + + output_tensors_[index] = CreateTensorFromApiValue(output); + return output_tensors_[index].get(); + } + Tensor* Output(int index, const std::vector& shape) { + return Output(index, TensorShape{shape}); + } + Tensor* Output(int index, const std::initializer_list& shape) { + return Output(index, TensorShape{shape}); + } + [[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const { + return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceCPUAllocator(output); + } + [[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const { + return static_cast(op_kernel_.Info().GetKernelInfo().GetEp())->GetTempSpaceAllocator(output); + } + size_t InputCount() const { + return context_.GetInputCount(); + } + size_t OutputCount() const { + return context_.GetOutputCount(); + } + + private: + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(OpKernelContext); + Ort::KernelContext context_; + const OpKernel& op_kernel_; + mutable std::vector> input_tensors_; + std::vector> output_tensors_; +}; + +/// +/// A bridge class between `onnxruntime::ep::adapter::OpKernel` and `onnxruntime::OrtKernelImpl`. 
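+///
+/// Flow note (descriptive): KernelRegistry::CreateKernel builds the adapter OpKernel via the registered
+/// KernelCreatePtrFn and wraps it in a KernelImpl unless the kernel supplies its own control-flow
+/// OrtKernelImpl; ComputeImpl then forwards to OpKernel::Compute and converts any exception to an OrtStatus.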
+/// +struct KernelImpl : OrtKernelImpl { + explicit KernelImpl(std::unique_ptr impl) + : OrtKernelImpl{}, impl_(std::move(impl)) { + ort_version_supported = ORT_API_VERSION; + Compute = ComputeImpl; + Release = ReleaseImpl; + PrePackWeight = PrePackWeightImpl; + } + + private: + static OrtStatus* ORT_API_CALL ComputeImpl(_In_ OrtKernelImpl* this_ptr, + _In_ OrtKernelContext* context) noexcept { + const auto* kernel_impl = static_cast(this_ptr)->impl_.get(); + OpKernelContext ctx{context, *kernel_impl}; + Status status; + ORT_TRY { + status = kernel_impl->Compute(&ctx); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()); + }); + } + if (status.IsOK()) { + return nullptr; + } else { + return Ort::Status{status.ErrorMessage().c_str(), static_cast(status.Code())}.release(); + } + } + + static void ORT_API_CALL ReleaseImpl(_In_ OrtKernelImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static OrtStatus* ORT_API_CALL PrePackWeightImpl(_In_ OrtKernelImpl* this_ptr, + _In_ const OrtValue* weight, + int input_index, + _In_ OrtAllocator* /* allocator */, + _In_opt_ OrtSharedPrePackedWeightCache* /* prepacked_weight_cache */, + _Out_ bool* is_packed) noexcept { + auto* kernel_impl = static_cast(this_ptr)->impl_.get(); + const auto tensor = CreateTensorFromApiValue(Ort::ConstValue{weight}); + Status status; + ORT_TRY { + status = kernel_impl->PrePack(*tensor.get(), input_index, AllocatorPtr{}, *is_packed, nullptr); + } + ORT_CATCH(const std::exception& ex) { + ORT_HANDLE_EXCEPTION([&]() { + status = ORT_MAKE_STATUS(ONNXRUNTIME, RUNTIME_EXCEPTION, ex.what()); + }); + } + if (!status.IsOK()) { + return Ort::Status{status.ErrorMessage().c_str(), static_cast(status.Code())}.release(); + } + return nullptr; + } + + ~KernelImpl() = default; + + private: + std::unique_ptr impl_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/op_kernel_info.h b/include/onnxruntime/ep/adapter/op_kernel_info.h new file mode 100644 index 0000000000000..df706e5605be5 --- /dev/null +++ b/include/onnxruntime/ep/adapter/op_kernel_info.h @@ -0,0 +1,160 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include + +#include "core/common/status.h" +#include "core/framework/tensor_shape.h" +#include "core/framework/tensor.h" + +#include "node.h" +#include "kernel_def.h" +#include "tensor_helper.h" + +namespace onnxruntime { +struct DataTransferManager; +struct IExecutionProvider; +} // namespace onnxruntime + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// An adapter class partially implementing the facade of `onnxruntime::OpKernelInfo`. +/// +struct OpKernelInfo { + // + // A helper struct to cache kernel info data + // + // Because `KernelCreatePtrFn` is defined to use `const OrtKernelInfo&` as parameter type of the kernel creation function, `OpKernelInfo` has to be copyable. + // This means we cannot store cached data like `constant_input_tensors_` in `OpKernelInfo` directly to avoid ownership issues. + // + // As a workaround, we define this struct `KernelInfoCache` here to represent the cached data. We use a shared pointer to `KernelInfoCache` in `OpKernelInfo` + // to manage the lifetime of the cached data. 
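+  //
+  // In practice, copying an OpKernelInfo is cheap: all copies share one KernelInfoCache (and the constant
+  // input Tensors it pre-creates) through the shared_ptr, so the cache is built once per OrtKernelInfo.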
+ struct KernelInfoCache { + explicit KernelInfoCache(const OrtKernelInfo* kernel_info) : kernel_info_(kernel_info) { + Ort::ConstKernelInfo info{kernel_info}; + const int input_count = info.GetInputCount(); + constant_input_tensors.resize(input_count); + for (int i = 0; i < input_count; ++i) { + int is_constant = 0; + Ort::ConstValue const_input = info.GetTensorConstantInput(i, &is_constant); + if (is_constant && const_input != nullptr && const_input.IsTensor()) { + constant_input_tensors[i] = CreateTensorFromApiValue(const_input); + } + } + } + const OrtKernelInfo* kernel_info_; + std::vector> constant_input_tensors; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(KernelInfoCache); + }; + + explicit OpKernelInfo(const OrtKernelInfo* info) : info_(info), cache_{std::make_shared(info)} { + } + + const DataTransferManager& GetDataTransferManager() const noexcept { + return (static_cast(info_.GetEp()))->GetDataTransferManager(); + } + Node node() const noexcept { + return Node{cache_->kernel_info_}; + } + const IExecutionProvider* GetExecutionProvider() const noexcept { + return (static_cast(info_.GetEp()))->EpImpl(); + } + + KernelDef GetKernelDef() const noexcept { + return KernelDef{cache_->kernel_info_}; + } + + const Ort::ConstKernelInfo GetKernelInfo() const noexcept { + return info_; + } + + int GetInputCount() const noexcept { + return info_.GetInputCount(); + } + + bool TryGetConstantInput(int input_index, const Tensor** constant_input_value) const { + if (input_index < 0 || static_cast(input_index) >= cache_->constant_input_tensors.size()) { + return false; + } + const Tensor* tensor = cache_->constant_input_tensors[input_index].get(); + if (tensor != nullptr) { + *constant_input_value = tensor; + return true; + } + return false; + } + + template + [[nodiscard]] T GetAttrOrDefault(const std::string& name, const T& default_value) const { + T tmp; + return GetAttr(name, &tmp).IsOK() ? tmp : default_value; + } + template + void GetAttrOrDefault(const std::string& name, T* value, const T& default_value) const { + if (!GetAttr(name, value).IsOK()) + *value = default_value; + } + template + [[nodiscard]] T GetAttr(const std::string& name) const { + T value; + ORT_THROW_IF_ERROR(GetAttr(name, &value)); + return value; + } + template + Status GetAttr(const std::string& name, T* value) const { + try { + *value = info_.GetAttribute(name.c_str()); + return Status::OK(); + } catch (const Ort::Exception& ex) { + return Status(onnxruntime::common::ONNXRUNTIME, ex.GetOrtErrorCode(), ex.what()); + } + } + template + Status GetAttrs(const std::string& name, std::vector& values) const { + try { + values = info_.GetAttributes(name.c_str()); + return Status::OK(); + } catch (const Ort::Exception& ex) { + return Status(onnxruntime::common::ONNXRUNTIME, ex.GetOrtErrorCode(), ex.what()); + } + } + + Status GetAttrs(const std::string& name, TensorShapeVector& out) const { + std::vector shape; + Status status = GetAttrs(name, shape); + if (status.IsOK()) { + out.reserve(shape.size()); + out.assign(shape.begin(), shape.end()); + } + return status; + } + + template + [[nodiscard]] std::vector GetAttrsOrDefault(const std::string& name, + const std::vector& default_value = {}) const { + std::vector tmp; + return GetAttrs(name, tmp).IsOK() ? tmp : default_value; + } + [[nodiscard]] TensorShapeVector GetAttrsOrDefault(const std::string& name, + const TensorShapeVector& default_value = {}) const { + TensorShapeVector tmp; + return GetAttrs(name, tmp).IsOK() ? 
tmp : default_value; + } + + private: + const Ort::ConstKernelInfo info_; + std::shared_ptr cache_; +}; + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/adapter/tensor_helper.h b/include/onnxruntime/ep/adapter/tensor_helper.h new file mode 100644 index 0000000000000..4d8ee078d5836 --- /dev/null +++ b/include/onnxruntime/ep/adapter/tensor_helper.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED) +#error "This header should not be included directly. Include ep/_pch.h instead." +#endif + +#include +#include + +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace ep { +namespace adapter { + +/// +/// Create an unowned onnxruntime::Tensor from a tensor OrtValue from C API. +/// +inline std::unique_ptr CreateTensorFromApiValue(const OrtValue* ort_value) { + Ort::ConstValue value{ort_value}; + EP_ENFORCE(value.IsTensor(), "Only tensor OrtValue is supported."); + + auto type_and_shape_info = value.GetTypeInfo().GetTensorTypeAndShapeInfo(); + auto type = type_and_shape_info.GetElementType(); + auto shape_vec = type_and_shape_info.GetShape(); + + auto memory_info = value.GetTensorMemoryInfo(); + MLDataType data_type = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); + + return std::make_unique(data_type, + TensorShape{shape_vec}, + const_cast(value.GetTensorRawData()), + OrtMemoryInfo{ + memory_info.GetAllocatorName(), + memory_info.GetAllocatorType(), + OrtDevice{ + static_cast(memory_info.GetDeviceType()), + static_cast(memory_info.GetMemoryType()), + static_cast(memory_info.GetVendorId()), + static_cast(memory_info.GetDeviceId()), + + }, + memory_info.GetMemoryType()}); +} + +} // namespace adapter +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/api.h b/include/onnxruntime/ep/api.h new file mode 100644 index 0000000000000..b05fb9e6d1cb3 --- /dev/null +++ b/include/onnxruntime/ep/api.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#pragma push_macro("ORT_API_MANUAL_INIT") +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#pragma pop_macro("ORT_API_MANUAL_INIT") + +namespace onnxruntime { +namespace ep { + +struct ApiPtrs { + const OrtApi& ort; + const OrtEpApi& ep; + const OrtModelEditorApi& model_editor; +}; + +namespace detail { +inline std::unique_ptr g_api_ptrs; +} + +/// +/// Get the global instance of ApiPtrs. +/// +inline const ApiPtrs& Api() { + return *detail::g_api_ptrs; +} + +/// +/// Initialize the EP API pointers and global OrtEnv if not already done. +/// +inline void ApiInit(const OrtApiBase* ort_api_base) { + // Manual init for the C++ API + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + const OrtEpApi* ep_api = ort_api->GetEpApi(); + const OrtModelEditorApi* model_editor_api = ort_api->GetModelEditorApi(); + Ort::InitApi(ort_api); + + // Initialize the global API instance + if (!detail::g_api_ptrs) { + detail::g_api_ptrs = std::make_unique( + ApiPtrs{*ort_api, *ep_api, *model_editor_api}); + } +} + +} // namespace ep +} // namespace onnxruntime diff --git a/include/onnxruntime/ep/common.h b/include/onnxruntime/ep/common.h new file mode 100644 index 0000000000000..12118c938820c --- /dev/null +++ b/include/onnxruntime/ep/common.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF(cond, ort_api, msg) \ + do { \ + if ((cond)) { \ + return (ort_api).CreateStatus(ORT_EP_FAIL, (msg)); \ + } \ + } while (0) + +// see ORT_ENFORCE for implementations that also capture a stack trace and work in builds with exceptions disabled +// NOTE: In this simplistic implementation you must provide an argument, even it if's an empty string +#define EP_ENFORCE(condition, ...) \ + do { \ + if (!(condition)) { \ + std::ostringstream oss; \ + oss << "EP_ENFORCE failed: " << #condition << " "; \ + oss << __VA_ARGS__; \ + throw std::runtime_error(oss.str()); \ + } \ + } while (false) + +// Ignores an OrtStatus* while taking ownership of it so that it does not get leaked. +#define IGNORE_ORTSTATUS(status_expr) \ + do { \ + OrtStatus* _status = (status_expr); \ + Ort::Status _ignored{_status}; \ + } while (false) diff --git a/include/onnxruntime/ep/get_capability_utils.h b/include/onnxruntime/ep/get_capability_utils.h new file mode 100644 index 0000000000000..2f6b9dfbe1d5b --- /dev/null +++ b/include/onnxruntime/ep/get_capability_utils.h @@ -0,0 +1,253 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace ep { + +using NodeId = size_t; +constexpr int64_t kSmallInitializerThreshold = 100; + +constexpr inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) { + return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput; +} + +// Get all output nodes that consume an output from the given node. +static OrtStatus* GetOutputNodes(gsl::span node_outputs, std::vector& result) { + try { + std::vector output_nodes; + output_nodes.reserve(node_outputs.size()); // May have more + + // Gather the OrtNode consumers of every output. + for (Ort::ConstValueInfo output : node_outputs) { + if (output == nullptr) continue; // Skip missing optional output + + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); + } + } + + result = std::move(output_nodes); + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } catch (...) { + Ort::Status status("Unknown exception", ORT_EP_FAIL); + return status.release(); + } +} + +// Returns nodes that should be assigned to CPU EP instead of this example EP to avoid costly I/O copies. 
+// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc +OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info, + const OrtLogger& logger, gsl::span tentative_nodes, + /*out*/ std::unordered_set& cpu_preferred_nodes) { + try { + const OrtApi& ort_api = Ort::GetApi(); + const OrtEpApi& ep_api = Ort::GetEpApi(); + Ort::ConstGraph graph{&ort_graph}; + std::vector ordered_nodes = graph.GetNodes(); + + if (ordered_nodes.empty()) { + return nullptr; + } + + std::unordered_map node_id_to_node; + std::unordered_map node_id_to_order_map; + for (size_t i = 0, num_nodes = ordered_nodes.size(); i < num_nodes; i++) { + NodeId node_id = ordered_nodes[i].GetId(); + node_id_to_node[node_id] = ordered_nodes[i]; + node_id_to_order_map[node_id] = i; + } + + // If return false, n1 will be output first; If return true, n2 will be output first + auto greater_order_comp = [&](const NodeId node_id1, const NodeId node_id2) { + return node_id_to_order_map[node_id1] > node_id_to_order_map[node_id2]; + }; + std::priority_queue, decltype(greater_order_comp)> candidates(greater_order_comp); + std::unordered_set cpu_output_args; + + std::unordered_set provider_nodes; + provider_nodes.reserve(tentative_nodes.size()); + + std::unordered_map node_to_kernel; + node_to_kernel.reserve(tentative_nodes.size()); + + for (const OrtNode* ort_node : tentative_nodes) { + Ort::ConstNode node(ort_node); + NodeId node_id = node.GetId(); + + provider_nodes.insert(node_id); + + // Expect at least one registry has a target provider's kernel for this node. + const OrtKernelDef* ort_kernel_def = nullptr; + RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel(&graph_support_info, node, &ort_kernel_def)); + RETURN_IF(ort_kernel_def == nullptr, ort_api, "Must have a registered kernel definition on the target EP"); + + Ort::ConstKernelDef kernel_def(ort_kernel_def); + node_to_kernel.insert({node_id, kernel_def}); + + // Find all the direct consumers of CPU tensors. + std::vector outputs = node.GetOutputs(); + for (size_t out_index = 0; out_index < outputs.size(); out_index++) { + Ort::ConstValueInfo output = outputs[out_index]; + if (output == nullptr) continue; // Skip missing optional output + + bool is_output_on_cpu = MemTypeOnCpuExplicitly(kernel_def.GetOutputMemType(out_index)); + if (is_output_on_cpu) { + cpu_output_args.insert(output); + + auto consumer_infos = output.GetConsumers(); + for (const auto& consumer_info : consumer_infos) { + candidates.push(consumer_info.node.GetId()); + ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_INFO, "Candidate for fallback CPU execution: %s\n", + consumer_info.node.GetName().c_str()); + } + } + } + } + + std::unordered_set visited; + visited.reserve(candidates.size()); + + std::unordered_set cpu_nodes; + cpu_nodes.reserve(candidates.size()); + + // The algo below is trying to identity a subgraph that only depends on cpu tensors. + // Usually it is a subgraph that doing shape calculation based on a GPU tensor, then reshape it back. + // The detail: + // for each candidate, if one of its input is a cpu tensor and the Non-CPU kernel doesn't mark it as cpu input, + // force the node to CPU to avoid memory cpu and add its output to the small cpu tensors. 
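+    //
+    // Illustrative example of that pattern: for GpuTensor -> Shape -> Slice -> Concat -> Reshape, the Shape
+    // output is a small CPU tensor, so Slice and Concat become candidates; if their remaining inputs are also
+    // CPU tensors or small initializers, they are forced onto the CPU EP by the loop below.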
+ while (!candidates.empty()) { + NodeId cur = candidates.top(); + candidates.pop(); + + auto p = visited.insert(cur); + if (!p.second) { + continue; + } + + auto node_iter = node_id_to_node.find(cur); + RETURN_IF(node_iter == node_id_to_node.end(), ort_api, "Unable to get OrtNode for a given node ID"); + Ort::ConstNode node = node_iter->second; + + if (provider_nodes.find(cur) == provider_nodes.end()) { + // Nodes not in provider_nodes are either have EP assigned or no kernel found on target EP. + // we assume these nodes will fallback to CPU, so add all direct consumers of all outputs to candidates. + std::string ep_name = node.GetEpName(); + if (ep_name.empty() || ep_name == "CPUExecutionProvider") { + std::vector outputs = node.GetOutputs(); + + for (Ort::ConstValueInfo output : outputs) { + if (output == nullptr) continue; // Skip missing optional output + cpu_output_args.insert(output); + } + + std::vector output_nodes; + RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes)); + + for (Ort::ConstNode downstream_node : output_nodes) { + candidates.push(downstream_node.GetId()); + } + } + continue; + } + + std::vector inputs = node.GetInputs(); + bool place_in_cpu = true; + + for (size_t i = 0; i < inputs.size(); i++) { + Ort::ConstValueInfo input = inputs[i]; + if (input == nullptr) continue; // Skip missing optional input + + // skip placing on CPU if the data types is float16 or bfloat16 or + // float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz or float4e2m1 + Ort::ConstTypeInfo type_info = input.TypeInfo(); + auto type_shape_info = type_info.GetTensorTypeAndShapeInfo(); + auto elem_type = type_shape_info.GetElementType(); + if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ || + elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1) { + place_in_cpu = false; + break; + } + + bool is_small_initializer = input.IsConstantInitializer() && + type_shape_info.GetElementCount() <= kSmallInitializerThreshold; + + // Allow placing on CPU if it's a small initializer or graph input + if (is_small_initializer || input.IsRequiredGraphInput() || input.IsOptionalGraphInput()) { + continue; + } + + // the input is not a CPU tensor + if (cpu_output_args.find(input) == cpu_output_args.end()) { + place_in_cpu = false; + break; + } + + // input is a CPU tensor, but it's intended to be consumed as CPU input by the target EP + bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetInputMemType(i)); + if (is_input_on_cpu) { + place_in_cpu = false; + break; + } + } + + if (place_in_cpu) { + cpu_nodes.insert(node); + ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_WARNING, + "EP optimization: Force fallback to CPU execution for node %s because the CPU execution path " + "is deemed faster than overhead involved with execution on other EPs capable of executing " + "this node.\n", + node.GetName().c_str()); + + std::vector outputs = node.GetOutputs(); + for (Ort::ConstValueInfo output : outputs) { + if (output == nullptr) continue; // Skip missing optional output + cpu_output_args.insert(output); + } + + std::vector output_nodes; + RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes)); + + for (Ort::ConstNode downstream_node : output_nodes) { + 
candidates.push(downstream_node.GetId()); + } + } + } + + cpu_preferred_nodes = std::move(cpu_nodes); + + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } catch (...) { + Ort::Status status("Unknown exception", ORT_EP_FAIL); + return status.release(); + } +} + +} // namespace ep +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 93d35d39390f5..a647ef3d5e22a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -32,22 +32,22 @@ class AttentionBase { int& past_sequence_length) const; protected: - AttentionBase(const OpKernelInfo& info, bool require_same_hidden_size) { + template + AttentionBase(const KernelInfoType& info, bool require_same_hidden_size) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_embedding_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); - mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - scale_ = info.GetAttrOrDefault("scale", 0.0f); - - if (!info.GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) { + is_unidirectional_ = info.template GetAttrOrDefault("unidirectional", 0) == 1; + do_rotary_ = info.template GetAttrOrDefault("do_rotary", 0) == 1; + rotary_embedding_ = static_cast(info.template GetAttrOrDefault("rotary_embedding_dim", 0)); + mask_filter_value_ = info.template GetAttrOrDefault("mask_filter_value", -10000.0f); + scale_ = info.template GetAttrOrDefault("scale", 0.0f); + if (!info.template GetAttrs("qkv_hidden_sizes", qkv_hidden_sizes_).IsOK()) { qkv_hidden_sizes_.clear(); } - past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL); + past_present_share_buffer_ = info.template GetAttrOrDefault("past_present_share_buffer", 0LL); require_same_hidden_size_ = require_same_hidden_size; } diff --git a/onnxruntime/core/providers/cpu/nn/conv_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_attributes.h index 170f313c8fe80..60cf76b372c13 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_attributes.h @@ -18,9 +18,10 @@ namespace onnxruntime { struct ConvAttributes { using ConvPadVector = InlinedVector; - explicit ConvAttributes(const OpKernelInfo& info) { + template + explicit ConvAttributes(const KernelInfoType& info) { std::string auto_pad_str; - auto status = info.GetAttr("auto_pad", &auto_pad_str); + auto status = info.template GetAttr("auto_pad", &auto_pad_str); if (status.IsOK()) { auto_pad = StringToAutoPadType(auto_pad_str); } @@ -32,8 +33,8 @@ struct ConvAttributes { strides.resize(kernel_shape_.size(), 1); } - gsl::span pads_span; - status = info.GetAttrsAsSpan("pads", pads_span); + std::vector pads_attr; + status = info.GetAttrs("pads", pads_attr); if (!status.IsOK()) { if (kernel_shape_specified) { // If pads are not explicitly provided, fill the container with all zeros @@ -44,7 +45,7 @@ struct ConvAttributes { // Pads are explicitly provided, make sure that auto_pad is NOTSET ORT_ENFORCE(auto_pad == AutoPadType::NOTSET, "A Conv/ConvTranspose node has both 
'auto_pad' and 'pads' attributes"); - pads.assign(pads_span.begin(), pads_span.end()); + pads.assign(pads_attr.begin(), pads_attr.end()); } status = info.GetAttrs("dilations", dilations); @@ -52,7 +53,7 @@ struct ConvAttributes { dilations.resize(kernel_shape_.size(), 1); } - status = info.GetAttr("group", &group); + status = info.template GetAttr("group", &group); if (!status.IsOK()) { group = 1; } @@ -61,9 +62,9 @@ struct ConvAttributes { // TODO: Re-enable when attributes values are guaranteed to be filled. // Make sure empty strides or dilations are defaulted to 1 if necessary std::string auto_pad_str; - ORT_ENFORCE(info.GetAttr("auto_pad", &auto_pad_str).IsOK()); + ORT_ENFORCE(info.template GetAttr("auto_pad", &auto_pad_str).IsOK()); auto_pad = StringToAutoPadType(auto_pad_str); - ORT_ENFORCE(info.GetAttr("group", &group).IsOK()); + ORT_ENFORCE(info.template GetAttr("group", &group).IsOK()); ORT_ENFORCE(info.GetAttrs("kernel_shape", kernel_shape_).IsOK()); ORT_ENFORCE(info.GetAttrs("strides", strides).IsOK()); ORT_ENFORCE(info.GetAttrs("pads", pads).IsOK()); diff --git a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h index 973743d711359..b56b965fad9be 100644 --- a/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/conv_transpose_attributes.h @@ -23,7 +23,8 @@ namespace onnxruntime { struct ConvTransposeAttributes : public ConvAttributes { - explicit ConvTransposeAttributes(const OpKernelInfo& info) + template + explicit ConvTransposeAttributes(const KernelInfoType& info) : ConvAttributes(info), output_padding(info.GetAttrsOrDefault("output_padding")), output_shape(info.GetAttrsOrDefault("output_shape")) { diff --git a/onnxruntime/core/providers/cpu/nn/pool_attributes.h b/onnxruntime/core/providers/cpu/nn/pool_attributes.h index fbbd4273757d5..66deb9eab9bc3 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_attributes.h +++ b/onnxruntime/core/providers/cpu/nn/pool_attributes.h @@ -24,7 +24,8 @@ struct PoolAttributes { // Shared providers don't know about OpNodeProtoHelper PoolAttributes(const OpKernelInfo& info, #else - PoolAttributes(const OpNodeProtoHelper& info, + template + PoolAttributes(const KernelInfoType& info, #endif const std::string& op_name, int start_version) : global_pooling(IsGlobalPooling(op_name)) { @@ -37,7 +38,7 @@ struct PoolAttributes { std::string auto_padding; if (op_name != "MaxUnpool") { - ORT_ENFORCE(info.GetAttr("auto_pad", &auto_padding).IsOK()); + ORT_ENFORCE(info.template GetAttr("auto_pad", &auto_padding).IsOK()); } auto_pad = StringToAutoPadType(auto_padding); @@ -49,7 +50,7 @@ struct PoolAttributes { strides.resize(kernel_shape.size(), 1); } - if (!info.GetAttr("ceil_mode", &ceil_mode).IsOK()) { + if (!info.template GetAttr("ceil_mode", &ceil_mode).IsOK()) { ceil_mode = 0; } @@ -63,7 +64,7 @@ struct PoolAttributes { if (op_name == "AveragePool") { int64_t temp; - ORT_ENFORCE(info.GetAttr("count_include_pad", &temp).IsOK()); + ORT_ENFORCE(info.template GetAttr("count_include_pad", &temp).IsOK()); count_include_pad = (temp != 0); } diff --git a/onnxruntime/core/providers/cpu/nn/pool_base.h b/onnxruntime/core/providers/cpu/nn/pool_base.h index 00dd1b152026d..1caaef3f98b60 100644 --- a/onnxruntime/core/providers/cpu/nn/pool_base.h +++ b/onnxruntime/core/providers/cpu/nn/pool_base.h @@ -102,13 +102,15 @@ class LpPool { class PoolBase { private: - static int GetStartVersion(const OpKernelInfo& info) { + template + static int 
GetStartVersion(const KernelInfoType& info) { return info.node().SinceVersion(); } protected: - PoolBase(const OpKernelInfo& info) - : op_name_(info.GetKernelDef().OpName().rfind("QLinear", 0) != 0 ? info.GetKernelDef().OpName() : info.GetKernelDef().OpName().substr(7)), + template + PoolBase(const KernelInfoType& info) + : op_name_(info.node().OpType().rfind("QLinear", 0) != 0 ? info.node().OpType() : info.node().OpType().substr(7)), pool_attrs_(info, op_name_, GetStartVersion(info)) { } diff --git a/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h b/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h index 5725e85f8e1e4..a0392c8b27366 100644 --- a/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h +++ b/onnxruntime/core/providers/cpu/reduction/reduction_kernel_base.h @@ -11,11 +11,12 @@ namespace onnxruntime { template class ReduceKernelBase { protected: - ReduceKernelBase(const OpKernelInfo& info, optional keepdims_override = {}) { + template + ReduceKernelBase(const KernelInfoType& info, optional keepdims_override = {}) { if (allow_multi_axes) { - axes_ = ToShapeVector(info.GetAttrsOrDefault("axes")); + axes_ = ToShapeVector(info.template GetAttrsOrDefault("axes")); } else { - auto v = info.GetAttrOrDefault("axis", 0); + auto v = info.template GetAttrOrDefault("axis", 0); axes_.push_back(v); } int64_t keepdims = 1; @@ -25,9 +26,9 @@ class ReduceKernelBase { ORT_ENFORCE(info.GetAttr("keepdims", &keepdims).IsOK()); } keepdims_ = (keepdims == 1); - int64_t noop_with_empty_axes = info.GetAttrOrDefault("noop_with_empty_axes", 0); + int64_t noop_with_empty_axes = info.template GetAttrOrDefault("noop_with_empty_axes", 0); noop_with_empty_axes_ = (noop_with_empty_axes == 1); - int64_t select_last_index = info.GetAttrOrDefault("select_last_index", 0); + int64_t select_last_index = info.template GetAttrOrDefault("select_last_index", 0); select_last_index_ = (select_last_index != 0); } diff --git a/onnxruntime/core/providers/cpu/tensor/concat.cc b/onnxruntime/core/providers/cpu/tensor/concat.cc index 732d0cab2ffae..e3d5c0600420f 100644 --- a/onnxruntime/core/providers/cpu/tensor/concat.cc +++ b/onnxruntime/core/providers/cpu/tensor/concat.cc @@ -54,180 +54,7 @@ using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExec Status ConcatBase::PrepareForCompute(OpKernelContext* ctx, const InlinedTensorsVector& input_tensors, Prepare& p) const { - size_t input_count = input_tensors.size(); - - // Must have atleast one input to concat - ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs"); - - TensorShapeVector reference_dims; - size_t reference_rank = 0; - - int reference_tensor_index = 0; - - InlinedVector input_tensor_sizes; - input_tensor_sizes.reserve(input_count); - - bool all_inputs_are_empty = true; - - for (size_t index = 0; index < input_count; ++index) { - const auto* input = input_tensors[index]; - ORT_ENFORCE(input != nullptr, "input count mismatch"); - - // find the first tensor that isn't empty - // to be used as a reference for all - // downstream shape/rank validations of other inputs - const auto& shape = input->Shape(); - const auto num_elements = shape.Size(); - if (num_elements > 0) { - reference_dims = shape.AsShapeVector(); - reference_rank = reference_dims.size(); - reference_tensor_index = onnxruntime::narrow(index); - input_tensor_sizes.push_back(num_elements); - all_inputs_are_empty = false; - break; - } else { - input_tensor_sizes.push_back(0); - } - } - - if (all_inputs_are_empty) { - // Reference 
dim and reference rank can just come from the first input - // No shape/rank validations will be done (as all inputs are empty). - // But the rest of the execution flow (filling in the Prepare instance - p) - // can use this info. - reference_dims = input_tensors[0]->Shape().AsShapeVector(); - reference_rank = reference_dims.size(); - } - - // Cannot concatenate scalars (but they can be stacked) - if (!is_stack_) - ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars"); - - // Handle and fix negative axis - // In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank) - p.axis = static_cast(HandleNegativeAxis(axis_, onnxruntime::narrow(!is_stack_ - ? reference_rank - : reference_rank + 1))); - - // Ensure all of the non concatenated axes match each other - for (size_t index = static_cast(reference_tensor_index) + 1; index < input_count; index++) { - const auto* input = input_tensors[index]; - ORT_ENFORCE(input != nullptr, "input count mismatch"); - const auto& input_shape = input->Shape(); - const auto input_dims = input_shape.GetDims(); - - // Skip shape/rank validation for inputs that are empty. - // The ONNX spec states that all dim values along axes not concatentated on - // need to be the same for all inputs (empty inputs are not explicitly exempted). - // The model in GH issue 8020 has a bunch of Loop nodes all feeding into - // the 'Concat' node and one of these Loops tend to have an iteration - // count of 0 for some inputs. If the iteration count for a Loop is zero, - // we don't execute its subgraph (since the outputs are going to be empty anyway) - // and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape - // to "compose" the shape for these empty tensor(s). - // If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0' - // in that position and due to the "lossy" nature of this process, the inputs' shape - // validation for such empty inputs fail and hence we skip these validations for all - // empty inputs. - // This isn't too bad as we will never use empty inputs while concatenating anyway. - // We just loosen this check to unblock model in GH issue 8020 to complete processing. - if (input_shape.Size() == 0) { - input_tensor_sizes.push_back(0); - } else { - const size_t input_rank = input_dims.size(); - - ORT_ENFORCE(input_rank == reference_rank, - "Ranks of input data are different, cannot concatenate them. 
expected rank: ", - reference_rank, " got: ", input_rank); - - // Ensure all the other (non-concat) axes match - int64_t tensor_size = 1; - for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) { - auto dim_value = input_dims[axis_index]; - tensor_size *= dim_value; - - // In 'concat' mode, the axis to be concatenated may be different - // But in 'stack' mode, all input shapes must be the same and must be validated - if (!is_stack_ && axis_index == p.axis) - continue; - - ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index], - "Non concat axis dimensions must match: Axis ", - axis_index, " has mismatched dimensions of ", dim_value, - " and ", reference_dims[axis_index]); - } - - input_tensor_sizes.push_back(tensor_size); // assign the computed size of the input tensor - } - } - - // Calculate the shape of the output tensor - auto output_dims = reference_dims; - - if (!is_stack_) { // 'Concat' mode - // While concatenating, the rank of the output is the same as the input rank(s) - - // Calculate the size of the concatenated axis - size_t concat_axis_size = 0; - for (size_t index = 0; index < input_count; index++) { - concat_axis_size += onnxruntime::narrow(input_tensors[index]->Shape()[onnxruntime::narrow(p.axis)]); - } - - output_dims[onnxruntime::narrow(p.axis)] = onnxruntime::narrow(concat_axis_size); - } else { // 'Stack' mode - // While stacking, the rank of the output is one more than the input rank(s). - // Stacking may be thought of as adding an unit dimension (of value 1) in the input tensors, - // and concatenating them on thie new axis. - // The value in the corresponding axis of the output will be the number of inputs that are being stacked. - output_dims.insert(output_dims.begin() + p.axis, static_cast(input_count)); - } - - TensorShape output_shape(output_dims); - - // Create output tensor - p.output_tensor = &(*ctx->Output(0, output_shape)); - - // Make note if output tensor is going to be empty - p.output_num_elements = output_shape.Size(); - - // No need to proceed further if output is going to be empty - if (p.output_num_elements == 0) - return Status::OK(); - - // The output_axis_pitch is the number of elements to add to move to the next split axis in the output. - // Can handle stacking as well. - p.output_axis_pitch = 1; - auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1; - for (size_t i = output_rank; i-- > p.axis;) { - p.output_axis_pitch *= output_dims[i]; - } - - // Fill the 'Prepare' struct with available information - p.inputs.reserve(input_count); - for (size_t input_index = 0; input_index < input_count; input_index++) { - const Tensor* data_n_ptr = input_tensors[input_index]; - auto& data_n = *data_n_ptr; - - // Type sanity check (Make sure we are working on homogeneous types) - ORT_RETURN_IF_NOT(data_n.DataType() == p.output_tensor->DataType(), "Data type mismatch"); - - // The input_axis_pitch is the number of elements to add to move to the next split axis in the input - // Can handle stacking as well (as the "new dummy dimension" in the input is of unit value). - // TODO: Minor Optimization possibility: This input_axis_patch will be common across all inputs - // in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating. 
- int64_t input_axis_pitch = 1; - const auto& data_dims = data_n.Shape().GetDims(); - for (size_t i = reference_rank; i-- > p.axis;) { - input_axis_pitch *= data_dims[i]; - } - - p.inputs.push_back({&data_n, input_axis_pitch, input_tensor_sizes[input_index]}); - } - - // Make note if the input Tensors of type 'string' - p.is_string_type = p.inputs[0].tensor->IsDataTypeString(); - - return Status::OK(); + return PrepareForComputeImpl(ctx, input_tensors, p); } namespace { diff --git a/onnxruntime/core/providers/cpu/tensor/concatbase.h b/onnxruntime/core/providers/cpu/tensor/concatbase.h index ad8b69016265d..b9085b2a9318b 100644 --- a/onnxruntime/core/providers/cpu/tensor/concatbase.h +++ b/onnxruntime/core/providers/cpu/tensor/concatbase.h @@ -2,6 +2,9 @@ // Licensed under the MIT License. #pragma once +#ifndef SHARED_PROVIDER +#include "core/providers/common.h" +#endif #include "core/common/inlined_containers.h" namespace onnxruntime { @@ -27,19 +30,199 @@ class ConcatBase { // the core method that will be invoked by the 'Concat' (CPU and GPU) // and 'ConcatFromSequence' kernels using InlinedTensorsVector = InlinedVector; + template + Status PrepareForComputeImpl(KernelContextType* ctx, const InlinedTensorsVector& input_tensors, + Prepare& p) const { + size_t input_count = input_tensors.size(); + + // Must have atleast one input to concat + ORT_RETURN_IF_NOT(input_count >= 1, "Must have 1 or more inputs"); + + TensorShapeVector reference_dims; + size_t reference_rank = 0; + + int reference_tensor_index = 0; + + InlinedVector input_tensor_sizes; + input_tensor_sizes.reserve(input_count); + + bool all_inputs_are_empty = true; + + for (size_t index = 0; index < input_count; ++index) { + const auto* input = input_tensors[index]; + ORT_ENFORCE(input != nullptr, "input count mismatch"); + + // find the first tensor that isn't empty + // to be used as a reference for all + // downstream shape/rank validations of other inputs + const auto& shape = input->Shape(); + const auto num_elements = shape.Size(); + if (num_elements > 0) { + reference_dims = shape.AsShapeVector(); + reference_rank = reference_dims.size(); + reference_tensor_index = onnxruntime::narrow(index); + input_tensor_sizes.push_back(num_elements); + all_inputs_are_empty = false; + break; + } else { + input_tensor_sizes.push_back(0); + } + } + + if (all_inputs_are_empty) { + // Reference dim and reference rank can just come from the first input + // No shape/rank validations will be done (as all inputs are empty). + // But the rest of the execution flow (filling in the Prepare instance - p) + // can use this info. + reference_dims = input_tensors[0]->Shape().AsShapeVector(); + reference_rank = reference_dims.size(); + } + + // Cannot concatenate scalars (but they can be stacked) + if (!is_stack_) + ORT_RETURN_IF_NOT(reference_rank > 0, "Cannot concatenate scalars"); + + // Handle and fix negative axis + // In 'stack' mode, the accepted range depends on the output rank (which is one more than the input rank) + p.axis = static_cast(HandleNegativeAxis(axis_, onnxruntime::narrow(!is_stack_ + ? 
reference_rank + : reference_rank + 1))); + + // Ensure all of the non concatenated axes match each other + for (size_t index = static_cast(reference_tensor_index) + 1; index < input_count; index++) { + const auto* input = input_tensors[index]; + ORT_ENFORCE(input != nullptr, "input count mismatch"); + const auto& input_shape = input->Shape(); + const auto input_dims = input_shape.GetDims(); + + // Skip shape/rank validation for inputs that are empty. + // The ONNX spec states that all dim values along axes not concatentated on + // need to be the same for all inputs (empty inputs are not explicitly exempted). + // The model in GH issue 8020 has a bunch of Loop nodes all feeding into + // the 'Concat' node and one of these Loops tend to have an iteration + // count of 0 for some inputs. If the iteration count for a Loop is zero, + // we don't execute its subgraph (since the outputs are going to be empty anyway) + // and we send an "empty" tensor(s) downstream and use ONNX shape inferred shape + // to "compose" the shape for these empty tensor(s). + // If we encounter symbolic dims in the ONNX shape inferred shape, we place a '0' + // in that position and due to the "lossy" nature of this process, the inputs' shape + // validation for such empty inputs fail and hence we skip these validations for all + // empty inputs. + // This isn't too bad as we will never use empty inputs while concatenating anyway. + // We just loosen this check to unblock model in GH issue 8020 to complete processing. + if (input_shape.Size() == 0) { + input_tensor_sizes.push_back(0); + } else { + const size_t input_rank = input_dims.size(); + + ORT_ENFORCE(input_rank == reference_rank, + "Ranks of input data are different, cannot concatenate them. expected rank: ", + reference_rank, " got: ", input_rank); + + // Ensure all the other (non-concat) axes match + int64_t tensor_size = 1; + for (size_t axis_index = 0; axis_index < reference_rank; ++axis_index) { + auto dim_value = input_dims[axis_index]; + tensor_size *= dim_value; + + // In 'concat' mode, the axis to be concatenated may be different + // But in 'stack' mode, all input shapes must be the same and must be validated + if (!is_stack_ && axis_index == p.axis) + continue; + + ORT_RETURN_IF_NOT(dim_value == reference_dims[axis_index], + "Non concat axis dimensions must match: Axis ", + axis_index, " has mismatched dimensions of ", dim_value, + " and ", reference_dims[axis_index]); + } + + input_tensor_sizes.push_back(tensor_size); // assign the computed size of the input tensor + } + } + + // Calculate the shape of the output tensor + auto output_dims = reference_dims; + + if (!is_stack_) { // 'Concat' mode + // While concatenating, the rank of the output is the same as the input rank(s) + + // Calculate the size of the concatenated axis + size_t concat_axis_size = 0; + for (size_t index = 0; index < input_count; index++) { + concat_axis_size += onnxruntime::narrow(input_tensors[index]->Shape()[onnxruntime::narrow(p.axis)]); + } + + output_dims[onnxruntime::narrow(p.axis)] = onnxruntime::narrow(concat_axis_size); + } else { // 'Stack' mode + // While stacking, the rank of the output is one more than the input rank(s). + // Stacking may be thought of as adding an unit dimension (of value 1) in the input tensors, + // and concatenating them on thie new axis. + // The value in the corresponding axis of the output will be the number of inputs that are being stacked. 
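// A small standalone sketch of the shape rule described above: 'concat' keeps the rank and sums
// the sizes along the concat axis, while 'stack' inserts a new axis whose extent is the number
// of inputs. The shapes and axis below are made up for illustration; this is not ONNX Runtime code.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> ConcatOrStackShape(const std::vector<std::vector<int64_t>>& input_shapes,
                                        std::size_t axis, bool is_stack) {
  std::vector<int64_t> out = input_shapes[0];
  if (is_stack) {
    // 'Stack': output rank = input rank + 1; the new axis holds the input count.
    out.insert(out.begin() + axis, static_cast<int64_t>(input_shapes.size()));
  } else {
    // 'Concat': output rank = input rank; the concat axis is the sum over all inputs.
    int64_t concat_dim = 0;
    for (const auto& shape : input_shapes) concat_dim += shape[axis];
    out[axis] = concat_dim;
  }
  return out;
}

// e.g. two inputs of shape {2, 3}:
//   ConcatOrStackShape({{2, 3}, {2, 3}}, 0, /*is_stack*/ false) -> {4, 3}
//   ConcatOrStackShape({{2, 3}, {2, 3}}, 0, /*is_stack*/ true)  -> {2, 2, 3}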
+ output_dims.insert(output_dims.begin() + p.axis, static_cast(input_count)); + } + + TensorShape output_shape(output_dims); + + // Create output tensor + p.output_tensor = &(*ctx->Output(0, output_shape)); + + // Make note if output tensor is going to be empty + p.output_num_elements = output_shape.Size(); + + // No need to proceed further if output is going to be empty + if (p.output_num_elements == 0) + return Status::OK(); + + // The output_axis_pitch is the number of elements to add to move to the next split axis in the output. + // Can handle stacking as well. + p.output_axis_pitch = 1; + auto output_rank = !is_stack_ ? reference_rank : reference_rank + 1; + for (size_t i = output_rank; i-- > p.axis;) { + p.output_axis_pitch *= output_dims[i]; + } + + // Fill the 'Prepare' struct with available information + p.inputs.reserve(input_count); + for (size_t input_index = 0; input_index < input_count; input_index++) { + const Tensor* data_n_ptr = input_tensors[input_index]; + auto& data_n = *data_n_ptr; + + // Type sanity check (Make sure we are working on homogeneous types) + ORT_RETURN_IF_NOT(data_n.DataType() == p.output_tensor->DataType(), "Data type mismatch"); + + // The input_axis_pitch is the number of elements to add to move to the next split axis in the input + // Can handle stacking as well (as the "new dummy dimension" in the input is of unit value). + // TODO: Minor Optimization possibility: This input_axis_patch will be common across all inputs + // in 'ConcatFromSequence' (stack mode). They have to be computed for each input only while concatenating. + int64_t input_axis_pitch = 1; + const auto& data_dims = data_n.Shape().GetDims(); + for (size_t i = reference_rank; i-- > p.axis;) { + input_axis_pitch *= data_dims[i]; + } + + p.inputs.push_back({&data_n, input_axis_pitch, input_tensor_sizes[input_index]}); + } + + // Make note if the input Tensors of type 'string' + p.is_string_type = p.inputs[0].tensor->IsDataTypeString(); + + return Status::OK(); + } + Status PrepareForCompute(OpKernelContext* ctx, const InlinedTensorsVector& input_tensors, Prepare& p) const; protected: - ConcatBase(const OpKernelInfo& info, bool is_sequence_op = false) { - if (!info.GetAttr("axis", &axis_).IsOK()) { + template + ConcatBase(const KernelInfoType& info, bool is_sequence_op = false) { + if (!info.template GetAttr("axis", &axis_).IsOK()) { ORT_ENFORCE(false, "Must have valid 'axis' attribute"); } is_sequence_op_ = is_sequence_op; if (is_sequence_op) { // Only ConcatFromSequence supports stacking - is_stack_ = info.GetAttrOrDefault("new_axis", 0) == 0 ? false : true; + is_stack_ = info.template GetAttrOrDefault("new_axis", 0) == 0 ? 
false : true; } } Status ComputeImpl(Prepare& p, OpKernelContext* ctx) const; diff --git a/onnxruntime/core/providers/cpu/tensor/gather.cc b/onnxruntime/core/providers/cpu/tensor/gather.cc index 38a16ee83c86b..b13fcd4135f67 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather.cc @@ -57,30 +57,7 @@ ONNX_CPU_OPERATOR_KERNEL( Gather); Status GatherBase::PrepareForCompute(OpKernelContext* context, Prepare& p) const { - p.input_tensor = context->Input(0); - const TensorShape& input_data_shape = p.input_tensor->Shape(); - p.indices_tensor = context->Input(1); - const TensorShape& indices_shape = p.indices_tensor->Shape(); - - const auto input_rank = input_data_shape.NumDimensions(); - p.axis = HandleNegativeAxis(axis_, narrow(input_rank)); - - std::vector shape; - shape.reserve(input_rank - 1 + indices_shape.NumDimensions()); - - // replace the dimension for p.axis with the shape from the indices - for (int64_t i = 0; i < p.axis; ++i) - shape.push_back(input_data_shape[narrow(i)]); - - for (const auto dim : indices_shape.GetDims()) - shape.push_back(dim); - - for (int64_t i = p.axis + 1; i < static_cast(input_rank); ++i) - shape.push_back(input_data_shape[narrow(i)]); - - p.output_tensor = context->Output(0, TensorShape(std::move(shape))); - - return Status::OK(); + return PrepareForComputeImpl(context, p); } template diff --git a/onnxruntime/core/providers/cpu/tensor/gatherbase.h b/onnxruntime/core/providers/cpu/tensor/gatherbase.h index 195b67553a87b..1f5e85c554a78 100644 --- a/onnxruntime/core/providers/cpu/tensor/gatherbase.h +++ b/onnxruntime/core/providers/cpu/tensor/gatherbase.h @@ -2,6 +2,11 @@ // Licensed under the MIT License. #pragma once +#ifndef SHARED_PROVIDER +#include "core/providers/common.h" +#include "core/framework/tensor.h" +#endif + namespace onnxruntime { class GatherBase { @@ -13,11 +18,40 @@ class GatherBase { int64_t axis; }; + template + Status PrepareForComputeImpl(KernelContextType* context, Prepare& p) const { + p.input_tensor = context->template Input(0); + const TensorShape& input_data_shape = p.input_tensor->Shape(); + p.indices_tensor = context->template Input(1); + const TensorShape& indices_shape = p.indices_tensor->Shape(); + + const auto input_rank = input_data_shape.NumDimensions(); + p.axis = HandleNegativeAxis(axis_, narrow(input_rank)); + + std::vector shape; + shape.reserve(input_rank - 1 + indices_shape.NumDimensions()); + + // replace the dimension for p.axis with the shape from the indices + for (int64_t i = 0; i < p.axis; ++i) + shape.push_back(input_data_shape[narrow(i)]); + + for (const auto dim : indices_shape.GetDims()) + shape.push_back(dim); + + for (int64_t i = p.axis + 1; i < static_cast(input_rank); ++i) + shape.push_back(input_data_shape[narrow(i)]); + + p.output_tensor = context->Output(0, TensorShape(std::move(shape))); + + return Status::OK(); + } + Status PrepareForCompute(OpKernelContext* context, Prepare& p) const; protected: - GatherBase(const OpKernelInfo& info) { - ORT_ENFORCE(info.GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); + template + GatherBase(const KernelInfoType& info) { + ORT_ENFORCE(info.template GetAttr("axis", &axis_).IsOK(), "Missing/Invalid 'axis' attribute value"); } private: diff --git a/onnxruntime/core/providers/cpu/tensor/pad.cc b/onnxruntime/core/providers/cpu/tensor/pad.cc index 84addee6997cc..286317b1e50dd 100644 --- a/onnxruntime/core/providers/cpu/tensor/pad.cc +++ b/onnxruntime/core/providers/cpu/tensor/pad.cc @@ 
-312,58 +312,9 @@ Status PadBase::HandleDimValueZero(const Mode& mode, const TensorShape& input_sh return Status::OK(); } -static void ComputePadWithAxes( - gsl::span pads_tensor_raw_data, - std::function get_axis, - size_t axes_size, - size_t data_rank, - PadsVector& pads) { - for (size_t i = 0; i < axes_size; ++i) { - const size_t axis = onnxruntime::narrow(HandleNegativeAxis(get_axis(i), data_rank)); - pads[axis] = pads_tensor_raw_data[i]; // xi_begin - pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end - } -} - void PadBase::ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, PadsVector& pads) { - pads.reserve(2 * data_rank); - const Tensor* axes_tensor = ctx.Input(3); - if (axes_tensor) { - const size_t num_axes_dims = axes_tensor->Shape().NumDimensions(); - ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor "); - - const int64_t num_axes = axes_tensor->Shape().Size(); - ORT_ENFORCE(pads_data.size() == narrow(2 * num_axes), - "Pads tensor size should be equal to twice the number of explicitly provided axes."); - - pads.resize(2 * data_rank, 0); - if (axes_tensor->IsDataType()) { - auto axes_data = axes_tensor->DataAsSpan(); - ComputePadWithAxes( - pads_data, - [axes_data](size_t idx) -> int64_t { - return axes_data[idx]; - }, - axes_data.size(), - data_rank, - pads); - } else if (axes_tensor->IsDataType()) { - auto axes_data = axes_tensor->DataAsSpan(); - ComputePadWithAxes( - pads_data, - [axes_data](size_t idx) { - return axes_data[idx]; - }, - axes_data.size(), - data_rank, - pads); - } - } else { - ORT_ENFORCE(pads_data.size() == 2 * data_rank, - "Pads tensor size should be equal to twice the input dimension count "); - pads.assign(pads_data.begin(), pads_data.end()); - } + ComputePadsImpl(ctx, data_rank, pads_data, pads); } // Flatten no padding inner most Axis, so one memcpy cover multiple Axis. 
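// A standalone sketch of the pads/axes expansion that the removed code above (and its templated
// replacement in padbase.h) performs: when an 'axes' input is present, the pads input holds
// begin/end pairs only for the listed axes, and those pairs are scattered into a full 2 * rank
// vector laid out as all begin pads followed by all end pads. Values here are illustrative only.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<int64_t> ExpandPads(const std::vector<int64_t>& pads_for_axes,  // size == 2 * axes.size()
                                const std::vector<int64_t>& axes,           // assumed already non-negative
                                std::size_t data_rank) {
  std::vector<int64_t> pads(2 * data_rank, 0);
  for (std::size_t i = 0; i < axes.size(); ++i) {
    const auto axis = static_cast<std::size_t>(axes[i]);
    pads[axis] = pads_for_axes[i];                            // begin pad for this axis
    pads[data_rank + axis] = pads_for_axes[axes.size() + i];  // end pad for this axis
  }
  return pads;
}

// e.g. rank 3, axes = {2}, pads_for_axes = {1, 4}  ->  {0, 0, 1, 0, 0, 4}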
diff --git a/onnxruntime/core/providers/cpu/tensor/padbase.h b/onnxruntime/core/providers/cpu/tensor/padbase.h index e2ab6ff6c8fb1..89c157690395f 100644 --- a/onnxruntime/core/providers/cpu/tensor/padbase.h +++ b/onnxruntime/core/providers/cpu/tensor/padbase.h @@ -5,6 +5,11 @@ #include "core/common/inlined_containers.h" +#ifndef SHARED_PROVIDER +#include "core/providers/common.h" +#include "core/framework/tensor.h" +#endif + namespace onnxruntime { enum class Mode : int { @@ -43,6 +48,48 @@ class PadBase { /// input rank /// pads data from pads input /// resulting pads + template + static void ComputePadsImpl(KernelContextType& ctx, size_t data_rank, gsl::span pads_data, + PadsVector& pads) { + pads.reserve(2 * data_rank); + const Tensor* axes_tensor = ctx.template Input(3); + if (axes_tensor) { + const size_t num_axes_dims = axes_tensor->Shape().NumDimensions(); + ORT_ENFORCE(num_axes_dims == 1, "Axes tensor should be a 1D tensor "); + + const int64_t num_axes = axes_tensor->Shape().Size(); + ORT_ENFORCE(pads_data.size() == narrow(2 * num_axes), + "Pads tensor size should be equal to twice the number of explicitly provided axes."); + + pads.resize(2 * data_rank, 0); + if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) -> int64_t { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } else if (axes_tensor->IsDataType()) { + auto axes_data = axes_tensor->DataAsSpan(); + ComputePadWithAxes( + pads_data, + [axes_data](size_t idx) { + return axes_data[idx]; + }, + axes_data.size(), + data_rank, + pads); + } + } else { + ORT_ENFORCE(pads_data.size() == 2 * data_rank, + "Pads tensor size should be equal to twice the input dimension count "); + pads.assign(pads_data.begin(), pads_data.end()); + } + } + static void ComputePads(OpKernelContext& ctx, size_t data_rank, gsl::span pads_data, PadsVector& pads); @@ -131,7 +178,8 @@ class PadBase { size_t inner_no_pad_size, PadsVector& reshaped_pad); protected: - PadBase(const OpKernelInfo& info) : value_(info.GetAttrOrDefault("value", 0.f)) { + template + PadBase(const KernelInfoType& info) : value_(info.GetAttrOrDefault("value", 0.f)) { std::string mode; if (info.GetAttr("mode", &mode).IsOK()) { if (mode == "constant") @@ -148,19 +196,16 @@ class PadBase { const auto& kernel_def = info.GetKernelDef(); - int start_ver, end_ver; - kernel_def.SinceVersion(&start_ver, &end_ver); - // kMSDomain contrib kernel AND OnnxDomain start version >= 11 => DynamicPad - if (start_ver >= 11 || kernel_def.Domain() == kMSDomain) { + if (info.node().SinceVersion() >= 11 || kernel_def.Domain() == kMSDomain) { is_dynamic_ = true; } if (!is_dynamic_) { - gsl::span pads_span; - if (!info.GetAttrsAsSpan("pads", pads_span).IsOK()) + std::vector pads_attr; + if (!info.GetAttrs("pads", pads_attr).IsOK()) ORT_THROW("Invalid 'pads' attribute value"); - pads_.assign(pads_span.begin(), pads_span.end()); + pads_.assign(pads_attr.begin(), pads_attr.end()); // Separate out any negative pads_ into the slices_ array slices_.resize(pads_.size(), 0); for (size_t index = 0; index < pads_.size(); index++) { @@ -174,6 +219,19 @@ class PadBase { ~PadBase() = default; + static void ComputePadWithAxes( + gsl::span pads_tensor_raw_data, + std::function get_axis, + size_t axes_size, + size_t data_rank, + PadsVector& pads) { + for (size_t i = 0; i < axes_size; ++i) { + const size_t axis = onnxruntime::narrow(HandleNegativeAxis(get_axis(i), data_rank)); + pads[axis] = pads_tensor_raw_data[i]; 
// xi_begin + pads[data_rank + axis] = pads_tensor_raw_data[axes_size + i]; // xi_end + } + } + Mode mode_{Mode::Constant}; PadsVector pads_; // After construction, only >=0 values are in here PadsVector slices_; // All of the negative padding values are separated out into slices_ diff --git a/onnxruntime/core/providers/cpu/tensor/split.h b/onnxruntime/core/providers/cpu/tensor/split.h index 038b5bca15d17..b9e699b180124 100644 --- a/onnxruntime/core/providers/cpu/tensor/split.h +++ b/onnxruntime/core/providers/cpu/tensor/split.h @@ -22,8 +22,9 @@ class SplitBase { std::vector& split_sizes) const; protected: - SplitBase(const OpKernelInfo& info, uint32_t opset) : opset_{opset} { - axis_ = info.GetAttrOrDefault("axis", 0); + template + SplitBase(const KernelInfoType& info, uint32_t opset) : opset_{opset} { + axis_ = info.template GetAttrOrDefault("axis", 0); size_t num_inputs = info.GetInputCount(); if (num_inputs == 1) { @@ -36,7 +37,7 @@ class SplitBase { } if (opset_ >= 18) { - num_outputs_ = info.GetAttrOrDefault("num_outputs", -1); + num_outputs_ = info.template GetAttrOrDefault("num_outputs", -1); // the ONNX type/shape inferencing handles the check that num_outputs is > 0 // ORT_ENFORCE(num_outputs_ != 0, "Invalid value in 'num_outputs' attribute of 0."); diff --git a/onnxruntime/core/providers/cpu/tensor/squeeze.h b/onnxruntime/core/providers/cpu/tensor/squeeze.h index ef3a3050a49e6..6ce7b14f1d518 100644 --- a/onnxruntime/core/providers/cpu/tensor/squeeze.h +++ b/onnxruntime/core/providers/cpu/tensor/squeeze.h @@ -15,7 +15,8 @@ namespace onnxruntime { class SqueezeBase { protected: - explicit SqueezeBase(const OpKernelInfo& info) { + template + explicit SqueezeBase(const KernelInfoType& info) { TensorShapeVector axes; size_t numInputs = info.GetInputCount(); if (numInputs == 1) { diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.h b/onnxruntime/core/providers/cpu/tensor/transpose.h index 54d3584ba0dad..3076c9d574634 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.h +++ b/onnxruntime/core/providers/cpu/tensor/transpose.h @@ -37,9 +37,10 @@ class TransposeBase { concurrency::ThreadPool* tp = nullptr); protected: - TransposeBase(const OpKernelInfo& info) { + template + TransposeBase(const KernelInfoType& info) { std::vector temp_perm; - Status status = info.GetAttrs("perm", temp_perm); + Status status = info.template GetAttrs("perm", temp_perm); if (status.IsOK()) { size_t rank = temp_perm.size(); perm_.resize(temp_perm.size()); diff --git a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h index 6960f8838ffde..5a8a318923da5 100644 --- a/onnxruntime/core/providers/cpu/tensor/unsqueeze.h +++ b/onnxruntime/core/providers/cpu/tensor/unsqueeze.h @@ -51,7 +51,8 @@ class UnsqueezeBase { } protected: - UnsqueezeBase(const OpKernelInfo& info) { + template + UnsqueezeBase(const KernelInfoType& info) { size_t num_inputs = info.GetInputCount(); if (num_inputs == 1) { // axes must be a valid attribute ORT_ENFORCE(info.GetAttrs("axes", axes_).IsOK(), "Missing/Invalid 'axes' attribute value"); diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index 4c393f8ae6574..9088ff618f771 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -128,27 +128,28 @@ class UpsampleBase { InlinedVector& scales) const; protected: - explicit UpsampleBase(const OpKernelInfo& info) + template + explicit 
UpsampleBase(const KernelInfoType& info) : scales_cached_(false), roi_cached_(false), use_extrapolation_(false) { const auto& node = info.node(); auto opset = node.SinceVersion(); is_resize_ = (opset >= 10); std::string mode; - ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); + ORT_ENFORCE(info.template GetAttr("mode", &mode).IsOK()); mode_ = StringToUpsampleMode(mode); auto input_count = info.GetInputCount(); if (input_count == 1) { // opset < 10 std::vector scales; - ORT_THROW_IF_ERROR(info.GetAttrs("scales", scales)); + ORT_THROW_IF_ERROR(info.template GetAttrs("scales", scales)); ORT_THROW_IF_ERROR(ScalesValidation(scales, mode_)); scales_.assign(scales.cbegin(), scales.cend()); scales_cached_ = true; } if (opset >= 18) { - antialias_ = info.GetAttrOrDefault("antialias", 0) == 0 ? false : true; + antialias_ = info.template GetAttrOrDefault("antialias", 0) == 0 ? false : true; if (antialias_) { ORT_ENFORCE((UpsampleMode::LINEAR == mode_ || UpsampleMode::CUBIC == mode_), @@ -156,21 +157,21 @@ class UpsampleBase { } // The attribute is absent in opset < 18, but the default value as if stretch. - std::string keep_aspect_ratio_policy = info.GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); + std::string keep_aspect_ratio_policy = info.template GetAttrOrDefault("keep_aspect_ratio_policy", "stretch"); keep_aspect_ratio_policy_ = StringToKeepAspectRatioPolicy(keep_aspect_ratio_policy); // guard against unit tests that can add an attribute - auto axes = info.GetAttrsOrDefault("axes"); + auto axes = info.template GetAttrsOrDefault("axes"); axes_.assign(axes.cbegin(), axes.cend()); } - extrapolation_value_ = info.GetAttrOrDefault("extrapolation_value", 0.0f); + extrapolation_value_ = info.template GetAttrOrDefault("extrapolation_value", 0.0f); // Coordinate transformation mode attr was introduced in version 11. // before that asymmetric mode was the only available transformation mode std::string coordinate_transform_mode_name = opset > 10 - ? info.GetAttrOrDefault("coordinate_transformation_mode", "half_pixel") + ? info.template GetAttrOrDefault("coordinate_transformation_mode", "half_pixel") : "asymmetric"; coordinate_transform_mode_ = StringToCoordinateTransformationMode(coordinate_transform_mode_name); @@ -184,13 +185,13 @@ class UpsampleBase { use_extrapolation_ = need_roi_input_ = (coordinate_transform_mode_ == TF_CROP_AND_RESIZE); std::string nearest_mode_name = (mode_ == NN && opset >= 11) - ? info.GetAttrOrDefault("nearest_mode", "round_prefer_floor") + ? info.template GetAttrOrDefault("nearest_mode", "round_prefer_floor") : ""; nearest_mode_ = StringToNearestMode(nearest_mode_name); get_nearest_pixel_ = GetNearestPixelFromOriginal(nearest_mode_); - cubic_coeff_a_ = info.GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA); - exclude_outside_ = info.GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true; + cubic_coeff_a_ = info.template GetAttrOrDefault("cubic_coeff_a", antialias_constants::kCubicCoeffA); + exclude_outside_ = info.template GetAttrOrDefault("exclude_outside", 0) == 0 ? false : true; if ((exclude_outside_ == 1 && mode_ != CUBIC) && (antialias_ == false || mode_ != LINEAR)) { ORT_THROW( @@ -218,8 +219,22 @@ class UpsampleBase { if (scales_input_idx_ > 0) { const Tensor* scale; bool get_scale = info.TryGetConstantInput(scales_input_idx_, &scale); - auto x_shape = node.InputDefs()[0]->Shape(); - int64_t rank = x_shape ? 
x_shape->dim_size() : -1; + int64_t rank = -1; + if constexpr (std::is_same_v) { + auto x_shape = node.InputDefs()[0]->Shape(); + if (x_shape != nullptr) { + rank = x_shape->dim_size(); + } + } else { + int is_const; + auto tensor = info.GetKernelInfo().GetTensorConstantInput(0, &is_const); + if (is_const) { + auto type_and_shape_info = tensor.GetTensorTypeAndShapeInfo(); + if (type_and_shape_info.HasShape()) { + rank = static_cast(type_and_shape_info.GetShape().size()); + } + } + } if (get_scale && scale->Shape().Size() > 0 && ((opset < 18) || (rank > 0 && opset >= 18))) { ORT_THROW_IF_ERROR(ParseScalesData(scale, scales_, rank)); scales_cached_ = true; diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h index 5b694a7a2e3f1..5e54973ea60dc 100644 --- a/onnxruntime/core/providers/webgpu/compute_context.h +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -100,7 +100,11 @@ class ComputeContextBase { // Get the logger. // inline const logging::Logger& Logger() const { +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) return *ep_.GetLogger(); +#else + return ep_.GetEpLogger(); +#endif } // diff --git a/onnxruntime/core/providers/webgpu/controlflow/if.cc b/onnxruntime/core/providers/webgpu/controlflow/if.cc index 233d1f760383f..b802c88eaee5d 100644 --- a/onnxruntime/core/providers/webgpu/controlflow/if.cc +++ b/onnxruntime/core/providers/webgpu/controlflow/if.cc @@ -3,6 +3,10 @@ #include "core/providers/webgpu/controlflow/if.h" +#if !defined(BUILD_WEBGPU_EP_STATIC_LIB) +#include "core/framework/error_code_helper.h" +#endif + using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -68,10 +72,20 @@ ONNX_OPERATOR_KERNEL_EX(If, .TypeConstraint("V", DataTypeImpl::AllFixedSizeTensorTypes()), If); +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) Status If::Compute(OpKernelContext* ctx) const { // call the base CPU version. 
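// The attribute-parsing base classes touched by this patch are templated on the kernel-info type
// so a single constructor can be built from the internal OpKernelInfo or from a plugin-EP kernel
// info, with `info.template GetAttr...` for dependent names and `if constexpr` choosing the
// info-specific path (as in UpsampleBase above). A minimal sketch of that pattern with two
// made-up info types; these are not the real ORT classes:
#include <cstdint>
#include <iostream>
#include <string>
#include <type_traits>

struct InternalInfo {
  template <typename T>
  T GetAttrOrDefault(const std::string& /*name*/, T default_value) const { return default_value; }
  int RankFromGraph() const { return 4; }  // stands in for reading the node's input shape
};

struct PluginInfo {
  template <typename T>
  T GetAttrOrDefault(const std::string& /*name*/, T default_value) const { return default_value; }
  int RankFromConstantInput() const { return 3; }  // stands in for querying a constant input
};

struct ResizeLikeBase {
  template <typename KernelInfoType>
  explicit ResizeLikeBase(const KernelInfoType& info) {
    // 'template' keyword is required because GetAttrOrDefault is a dependent name.
    antialias_ = info.template GetAttrOrDefault<int64_t>("antialias", 0) != 0;

    if constexpr (std::is_same_v<KernelInfoType, InternalInfo>) {
      rank_ = info.RankFromGraph();
    } else {
      rank_ = info.RankFromConstantInput();
    }
  }
  bool antialias_ = false;
  int rank_ = -1;
};

int main() {
  ResizeLikeBase a{InternalInfo{}};
  ResizeLikeBase b{PluginInfo{}};
  std::cout << a.rank_ << ' ' << b.rank_ << '\n';  // prints "4 3"
}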
return onnxruntime::If::Compute(ctx); } +#else +Status If::CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) { + return ToStatusAndRelease(ep::Api().ep.CreateIfKernel(info, impl)); +} + +Status If::Compute(OpKernelContext* ctx) const { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "If operator should be handled by ORT core."); +} +#endif } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/controlflow/if.h b/onnxruntime/core/providers/webgpu/controlflow/if.h index 0755c5d33d7a3..193598ad85b38 100644 --- a/onnxruntime/core/providers/webgpu/controlflow/if.h +++ b/onnxruntime/core/providers/webgpu/controlflow/if.h @@ -10,6 +10,8 @@ namespace onnxruntime { namespace webgpu { +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) + // Use the CPU implementation for the logic class If final : public onnxruntime::If { public: @@ -18,5 +20,16 @@ class If final : public onnxruntime::If { Status Compute(OpKernelContext* ctx) const override; }; +#else + +class If final : public OpKernel { + public: + If(const OpKernelInfo& info) : OpKernel(info) {} + + Status CreateControlFlowKernelImpl(const OrtKernelInfo* info, OrtKernelImpl** impl) override; + Status Compute(OpKernelContext* ctx) const override; +}; +#endif + } // namespace webgpu -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index 6d66a7308f1de..792e6cbc05d3f 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -13,32 +13,45 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); } -common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - size_t bytes = src.SizeInBytes(); +common::Status DataTransfer::CopyTensorImpl(void const* src_data, + bool src_is_gpu, + void* dst_data, + bool dst_is_gpu, + size_t bytes) const { if (bytes > 0) { - void const* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; - - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { + if (dst_is_gpu) { + if (src_is_gpu) { // copy from GPU to GPU buffer_manager_.MemCpy(static_cast(const_cast(src_data)), - static_cast(dst_data), bytes); + static_cast(dst_data), + bytes); } else { // copy from CPU to GPU - buffer_manager_.Upload(const_cast(src_data), static_cast(dst_data), bytes); + buffer_manager_.Upload(const_cast(src_data), + static_cast(dst_data), + bytes); } - } else /* if (src_device.Type() == OrtDevice::GPU) */ { + } else { // copy from GPU to CPU - buffer_manager_.Download(static_cast(const_cast(src_data)), dst_data, bytes); + buffer_manager_.Download(static_cast(const_cast(src_data)), + dst_data, + bytes); } } return Status::OK(); } +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + void const* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + return CopyTensorImpl(src_data, + src.Location().device.Type() == OrtDevice::GPU, + dst_data, + dst.Location().device.Type() == OrtDevice::GPU, + src.SizeInBytes()); +} + } // namespace webgpu } // namespace onnxruntime diff --git 
a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h index 0adf380149acf..29f70341989f7 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.h +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -20,6 +20,12 @@ class DataTransfer : public IDataTransfer { common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; + common::Status CopyTensorImpl(void const* src_data, + bool src_is_gpu, + void* dst_data, + bool dst_is_gpu, + size_t bytes) const; + private: const BufferManager& buffer_manager_; }; diff --git a/onnxruntime/core/providers/webgpu/ep/api.cc b/onnxruntime/core/providers/webgpu/ep/api.cc new file mode 100644 index 0000000000000..526f8ae2faf10 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/api.cc @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#define ORT_API_MANUAL_INIT +#include "onnxruntime_cxx_api.h" +#undef ORT_API_MANUAL_INIT + +#include + +#include "core/providers/webgpu/ep/factory.h" + +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define EXPORT_SYMBOL __attribute__((visibility("default"))) +#else +#define EXPORT_SYMBOL +#endif + +namespace onnxruntime { +namespace webgpu { +void CleanupWebGpuContexts(); +void CleanupKernelRegistries(); +} // namespace webgpu +} // namespace onnxruntime + +namespace google { +namespace protobuf { +void ShutdownProtobufLibrary(); +} // namespace protobuf +} // namespace google + +extern "C" { +// +// Public symbols +// +EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base, + const OrtLogger* default_logger, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + // Manual init for the C++ API + onnxruntime::ep::ApiInit(ort_api_base); + + if (max_factories < 1) { + return onnxruntime::ep::Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + // Initialize the global default logger + ::onnxruntime::ep::adapter::Logger::CreateDefaultLogger(default_logger); + + // Factory could use registration_name or define its own EP name. + std::unique_ptr factory = std::make_unique(); + + factories[0] = factory.release(); + *num_factories = 1; + + return nullptr; +} + +EXPORT_SYMBOL OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + // STEP.1 - Release the factory + delete static_cast(factory); + + // STEP.2 - Clean up cached kernel registries + onnxruntime::webgpu::CleanupKernelRegistries(); + + // STEP.3 - Clean up WebGPU contexts + onnxruntime::webgpu::CleanupWebGpuContexts(); + + // STEP.4 - Destroy the global default logger wrapper + ::onnxruntime::ep::adapter::Logger::DestroyDefaultLogger(); + + // STEP.5 - Shutdown protobuf library + google::protobuf::ShutdownProtobufLibrary(); + + return nullptr; +} + +} // extern "C" diff --git a/onnxruntime/core/providers/webgpu/ep/ep.cc b/onnxruntime/core/providers/webgpu/ep/ep.cc new file mode 100644 index 0000000000000..fb07674d42434 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/ep.cc @@ -0,0 +1,260 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
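// The OrtEp callbacks defined in this file follow a common shape: run the body in a try block
// and translate any C++ exception into an OrtStatus*. A generic wrapper expressing that pattern;
// the helper name is made up, and only the Ort::Status usage mirrors the code in this file:
#include <exception>
#include <utility>

#include "onnxruntime_cxx_api.h"

template <typename Fn>
OrtStatus* InvokeAndCaptureStatus(Fn&& fn) noexcept {
  try {
    std::forward<Fn>(fn)();  // the body may throw Ort::Exception or std::exception
    return nullptr;          // nullptr means success in the ORT C API
  } catch (const Ort::Exception& ex) {
    return Ort::Status(ex).release();  // preserves the original error code
  } catch (const std::exception& ex) {
    return Ort::Status(ex.what(), ORT_EP_FAIL).release();
  } catch (...) {
    return Ort::Status("Unknown exception", ORT_EP_FAIL).release();
  }
}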
+ +#include "ep.h" + +#include "factory.h" + +#include "core/framework/run_options.h" +#include "core/framework/kernel_registry.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/plugin_ep/ep_kernel_registration.h" +#include "core/providers/webgpu/webgpu_execution_provider.h" + +#include "ep/get_capability_utils.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +using onnxruntime::ep::Api; + +// Constructor +Ep::Ep(IExecutionProvider* impl, Factory& factory, const OrtLogger& logger, const Config& config) + : onnxruntime::ep::adapter::Ep{impl, config.cpu_allocator, config.device_allocator}, + factory_{factory}, + logger_{logger}, + config_{config} { + ort_version_supported = ORT_API_VERSION; + + // Initialize the execution provider's function table + GetName = GetNameImpl; + GetCapability = GetCapabilityImpl; + GetKernelRegistry = GetKernelRegistryImpl; + Compile = nullptr; // Per-kernel EP does not use Compile + ReleaseNodeComputeInfos = nullptr; + GetPreferredDataLayout = GetPreferredDataLayoutImpl; + ShouldConvertDataLayoutForOp = ShouldConvertDataLayoutForOpImpl; + SetDynamicOptions = nullptr; // Not implemented + OnRunStart = OnRunStartImpl; + OnRunEnd = OnRunEndImpl; + CreateAllocator = CreateAllocatorImpl; + CreateSyncStreamForDevice = nullptr; // Not stream aware + GetCompiledModelCompatibilityInfo = nullptr; // Not a compiled EP + IsConcurrentRunSupported = IsConcurrentRunSupportedImpl; +} + +// OrtEp interface implementations +const char* ORT_API_CALL Ep::GetNameImpl(const OrtEp* this_ptr) noexcept { + const auto* ep = static_cast(this_ptr); + return ep->factory_.GetName(&ep->factory_); +} + +OrtStatus* ORT_API_CALL Ep::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept { + try { + auto& ep = *static_cast(static_cast(this_ptr)->EpImpl()); + Ort::ConstGraph ort_graph{graph}; + + // Get all nodes in the graph + std::vector all_nodes = ort_graph.GetNodes(); + + if (all_nodes.empty()) { + return nullptr; // No nodes to process + } + + std::vector candidate_nodes; + + // For each node, check if we have a registered kernel for it + for (const auto& node : all_nodes) { + if (node.GetEpName() == kWebGpuExecutionProvider) { + candidate_nodes.push_back(node); + continue; + } + + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(Api().ep.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def)); + + if (kernel_def == nullptr) { + LOGS(ep.GetEpLogger(), INFO) << "webgpu kernel not found in registries for Op type: " + << node.GetOperatorType() << " node name: " << node.GetName(); + continue; + } + + auto cpu_node_names = ep.GetForceCpuNodeNames(); + if (std::find(cpu_node_names.begin(), + cpu_node_names.end(), + node.GetName()) != cpu_node_names.end()) { + LOGS(ep.GetEpLogger(), INFO) << "Force CPU execution for node: " << node.GetName(); + continue; + } + + // + // The following code checks if the node is really supported by webgpu EP + // + +#define FALLBACK_TO_CPU_IF_EXIST_INPUT(idx) \ + if (inputs.size() > idx && inputs[idx] != nullptr) { \ + continue; \ + } + +#define FALLBACK_TO_CPU_IF_EXIST_OUTPUT(idx) \ + if (outputs.size() > idx && outputs[idx] != nullptr) { \ + continue; \ + } + + // Check for Attention + if (node.GetOperatorType() == "Attention" && node.GetDomain() == kMSDomain) { + const auto& inputs = node.GetInputs(); + const auto& outputs = node.GetOutputs(); + + // Current implementation does not support mask_index(input[3]), 
past(input[4]) and past_seq_len(input[6]) + FALLBACK_TO_CPU_IF_EXIST_INPUT(3); + FALLBACK_TO_CPU_IF_EXIST_INPUT(4); + FALLBACK_TO_CPU_IF_EXIST_INPUT(6); + + // Current implementation does not support present(output[1]) + FALLBACK_TO_CPU_IF_EXIST_OUTPUT(1); + + // If attribute past_present_share_buffer is set, fallback to CPU + bool has_past_present_share_buffer = false; + for (const auto& attr : node.GetAttributes()) { + if (attr.GetName() == "past_present_share_buffer") { + int64_t val = 0; + RETURN_IF_ERROR(attr.GetValue(val)); + if (val != 0) { + has_past_present_share_buffer = true; + } + } + } + if (has_past_present_share_buffer) { + continue; + } + } + + candidate_nodes.push_back(node); + } + + std::unordered_set cpu_preferred_nodes; + RETURN_IF_ERROR(onnxruntime::ep::GetCpuPreferredNodes(*ort_graph, + *graph_support_info, + static_cast(this_ptr)->GetOrtLogger(), + candidate_nodes, + cpu_preferred_nodes)); + + for (const auto& node : candidate_nodes) { + if (cpu_preferred_nodes.count(node) == 0) { + RETURN_IF_ERROR(Api().ep.EpGraphSupportInfo_AddSingleNode(graph_support_info, node)); + } + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } + + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept { + try { + *kernel_registry = nullptr; + + // For the WebGPU EP, delegate to the CreateKernelRegistry function + // which properly constructs a registry using only public APIs + auto* ep = static_cast(this_ptr); + const char* ep_name = ep->factory_.GetName(&ep->factory_); + + auto& webgpu_ep = *ep->EpImpl(); + + *kernel_registry = *onnxruntime::webgpu::GetKernelRegistry(webgpu_ep.IsGraphCaptureEnabled()).get(); + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); + } +} + +OrtStatus* ORT_API_CALL Ep::GetPreferredDataLayoutImpl(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout) noexcept { + // Delegate to the underlying WebGPU EP's GetPreferredLayout() + // DataLayout enum values map 1:1 to OrtEpDataLayout (NCHW=0, NHWC=1) + auto* ep = static_cast(this_ptr); + *preferred_data_layout = static_cast(ep->EpImpl()->GetPreferredLayout()); + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert) noexcept { + // DataLayout enum values map 1:1 to OrtEpDataLayout (NCHW=0, NHWC=1) + auto* ep = static_cast(this_ptr); + auto result = ep->EpImpl()->ShouldConvertDataLayoutForOp(domain, op_type, + static_cast(target_data_layout)); + if (result.has_value()) { + *should_convert = result.value() ? 
1 : 0; + } else { + *should_convert = -1; + } + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::OnRunStartImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options) noexcept { + onnxruntime::RunOptions options{}; + // currently only option "gpu_graph_id" is used + auto graph_annotation_str = Api().ort.GetRunConfigEntry(run_options, kOrtRunOptionsConfigCudaGraphAnnotation); + if (graph_annotation_str != nullptr) { + options.config_options.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, graph_annotation_str); + } + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->OnRunStart(options); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::OnRunEndImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream) noexcept { + auto* ep = static_cast(this_ptr); + auto status = ep->EpImpl()->OnRunEnd(sync_stream, {}); + if (!status.IsOK()) { + return Api().ort.CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::IsConcurrentRunSupportedImpl(_In_ OrtEp* /*this_ptr*/, _Out_ bool* is_concurrent_run_supported) noexcept { + *is_concurrent_run_supported = false; + return nullptr; +} + +OrtStatus* ORT_API_CALL Ep::CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept { + auto* ep = static_cast(this_ptr); + Ort::ConstMemoryInfo ort_memory_info{memory_info}; + if (ort_memory_info.GetAllocatorType() == OrtReadOnlyAllocator) { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, ep->config_.initializer_allocator); + } else { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, ep->config_.device_allocator); + } + return nullptr; +} + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/ep.h b/onnxruntime/core/providers/webgpu/ep/ep.h new file mode 100644 index 0000000000000..af9d73e23a3ff --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/ep.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include "core/providers/webgpu/webgpu_execution_provider.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +class Factory; + +/// +/// A bridge class between the EP API and the WebGPU EP implementation. +/// +class Ep : public onnxruntime::ep::adapter::Ep { + public: + struct Config { + AllocatorPtr cpu_allocator; + AllocatorPtr device_allocator; + AllocatorPtr initializer_allocator; + }; + + // Do not use a std::unique_ptr for impl_ because this requires the actual type definition. 
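+  // `impl` is the concrete WebGpuExecutionProvider instance created by the factory; `config`
+  // carries the CPU, device and initializer allocators that the factory shares with each EP.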
+ Ep(IExecutionProvider* impl, Factory& factory, const OrtLogger& logger, const Config& config); + + inline const OrtLogger& GetOrtLogger() const noexcept { + return logger_; + } + + private: + static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL GetKernelRegistryImpl( + _In_ OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtKernelRegistry** kernel_registry) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept; + + static OrtStatus* ORT_API_CALL GetPreferredDataLayoutImpl(_In_ OrtEp* this_ptr, + _Out_ OrtEpDataLayout* preferred_data_layout) noexcept; + + static OrtStatus* ORT_API_CALL ShouldConvertDataLayoutForOpImpl(_In_ OrtEp* this_ptr, + _In_z_ const char* domain, + _In_z_ const char* op_type, + _In_ OrtEpDataLayout target_data_layout, + _Outptr_ int* should_convert) noexcept; + + static OrtStatus* ORT_API_CALL OnRunStartImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options) noexcept; + + static OrtStatus* ORT_API_CALL OnRunEndImpl(_In_ OrtEp* this_ptr, + _In_ const OrtRunOptions* run_options, + _In_ bool sync_stream) noexcept; + + static OrtStatus* ORT_API_CALL IsConcurrentRunSupportedImpl(_In_ OrtEp* this_ptr, + _Out_ bool* is_concurrent_run_supported) noexcept; + + Factory& factory_; + const OrtLogger& logger_; + Config config_{}; +}; + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/factory.cc b/onnxruntime/core/providers/webgpu/ep/factory.cc new file mode 100644 index 0000000000000..6a88f6efe1fd3 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/factory.cc @@ -0,0 +1,216 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
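+
+// OrtEpFactory bridge for the WebGPU plugin EP: enumerates GPU devices, creates and releases
+// Ep instances, and provides the allocator and data-transfer implementations.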
+ +#include "factory.h" +#include "ep.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" + +#include "core/framework/execution_provider.h" +#include "core/framework/config_options.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/providers/webgpu/webgpu_context.h" +#include "core/providers/webgpu/allocator.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +using onnxruntime::ep::Api; + +// Constructor +Factory::Factory() : OrtEpFactory{}, + config_{ + CPUAllocator::DefaultInstance(), // CPU allocator + std::make_shared(WebGpuContextFactory::DefaultContext().BufferManager(), false), // default device allocator + std::make_shared(WebGpuContextFactory::DefaultContext().InitializerBufferManager(), true), // initializer device allocator + }, + default_memory_info_{WEBGPU_BUFFER, OrtMemoryInfoDeviceType_GPU, + 0, // vendor id + 0, // device id + OrtDeviceMemoryType_DEFAULT, + 0, // alignment + OrtDeviceAllocator}, + readonly_memory_info_{WEBGPU_BUFFER, OrtMemoryInfoDeviceType_GPU, + 0, // vendor id + 0, // device id + OrtDeviceMemoryType_DEFAULT, + 0, // alignment + OrtReadOnlyAllocator} { + ort_version_supported = ORT_API_VERSION; + + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; // TODO + ReleaseAllocator = ReleaseAllocatorImpl; // TODO + CreateDataTransfer = CreateDataTransferImpl; // TODO +} + +// Static C API implementations + +const char* ORT_API_CALL Factory::GetNameImpl(const OrtEpFactory* this_ptr) noexcept { + return kWebGpuExecutionProvider; +} + +const char* ORT_API_CALL Factory::GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { + return "Microsoft"; +} + +uint32_t ORT_API_CALL Factory::GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + return 0; +} + +const char* ORT_API_CALL Factory::GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + return "0.1.0"; +} + +OrtStatus* ORT_API_CALL Factory::GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + auto factory = static_cast(this_ptr); + + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (Api().ort.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? 
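+      // One OrtEpDevice is created per GPU hardware device, with both the default and the
+      // read-only (initializer) WEBGPU_BUFFER memory infos registered for it.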
+ OrtEpDevice* ep_device = nullptr; + ORT_API_RETURN_IF_ERROR(Api().ep.CreateEpDevice(this_ptr, + &device, nullptr, nullptr, + &ep_device)); + ORT_API_RETURN_IF_ERROR(Api().ep.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_)); + ORT_API_RETURN_IF_ERROR(Api().ep.EpDevice_AddAllocatorInfo(ep_device, factory->readonly_memory_info_)); + ep_devices[num_ep_devices++] = ep_device; + } + } + + return nullptr; +} + +OrtStatus* ORT_API_CALL Factory::CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept { + if (num_devices == 0) { + return Api().ort.CreateStatus(ORT_INVALID_ARGUMENT, "No hardware devices provided to create WebGPU EP."); + } + + OrtKeyValuePairs* session_config_entries = nullptr; + ORT_API_RETURN_IF_ERROR(Api().ort.GetSessionOptionsConfigEntries(session_options, &session_config_entries)); + Ort::KeyValuePairs session_config_entries_holder(session_config_entries); // allow automatic release + + auto config_options = ConfigOptions{}; + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + Api().ort.GetKeyValuePairs(session_config_entries, &keys, &values, &num_entries); + for (size_t i = 0; i < num_entries; ++i) { + auto status = config_options.AddConfigEntry(keys[i], values[i]); + if (!status.IsOK()) { + return Api().ort.CreateStatus((OrtErrorCode)status.Code(), status.ErrorMessage().c_str()); + } + } + + try { + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(config_options); + auto webgpu_ep = webgpu_ep_factory->CreateProvider(*session_options, *logger); + auto webgpu_ep_impl = static_cast(webgpu_ep.release()); + webgpu_ep_impl->SetEpLogger(logger); + int device_id = webgpu_ep_impl->GetDeviceId(); + auto& webgpu_context = WebGpuContextFactory::GetContext(device_id); + auto factory = static_cast(this_ptr); + *ep = new Ep(webgpu_ep_impl, *factory, *logger, factory->config_); + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } +} + +void ORT_API_CALL Factory::ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { + delete static_cast(ep); +} + +OrtStatus* ORT_API_CALL Factory::CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept { + auto factory = static_cast(this_ptr); + Ort::ConstMemoryInfo ort_memory_info{memory_info}; + if (ort_memory_info.GetAllocatorType() == OrtReadOnlyAllocator) { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, factory->config_.initializer_allocator); + } else { + *allocator = new onnxruntime::ep::adapter::Allocator(memory_info, factory->config_.device_allocator); + } + return nullptr; +} + +void ORT_API_CALL Factory::ReleaseAllocatorImpl(OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept { + onnxruntime::ep::adapter::Allocator* ptr = static_cast(allocator); + delete ptr; +} + +OrtStatus* ORT_API_CALL Factory::CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + try { + *data_transfer = OrtWebGpuCreateDataTransfer(); // TODO(fs-eire): pass context id if needed + return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } +} + +bool ORT_API_CALL Factory::IsStreamAwareImpl(const OrtEpFactory* 
this_ptr) noexcept { + return false; // Default: not stream aware +} + +OrtStatus* ORT_API_CALL Factory::CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept { + *stream = nullptr; + return Api().ort.CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); +} + +OrtStatus* ORT_API_CALL Factory::ValidateCompiledModelCompatibilityInfoImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; +} + +OrtStatus* ORT_API_CALL Factory::SetEnvironmentOptionsImpl( + OrtEpFactory* this_ptr, + const OrtKeyValuePairs* options) noexcept { + return nullptr; // Default implementation does nothing +} + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/ep/factory.h b/onnxruntime/core/providers/webgpu/ep/factory.h new file mode 100644 index 0000000000000..cd886e19b56d7 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/ep/factory.h @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "ep.h" + +namespace onnxruntime { +namespace webgpu { +namespace ep { + +/// +/// A bridge class between the EP API and the WebGPU EP Factory implementation. +/// +class Factory : public OrtEpFactory { + private: + // Static C API implementations + static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVendorImpl(const OrtEpFactory* this_ptr) noexcept; + static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept; + + static OrtStatus* ORT_API_CALL CreateEpImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* logger, + OrtEp** ep) noexcept; + + static void ORT_API_CALL ReleaseEpImpl(OrtEpFactory* this_ptr, OrtEp* ep) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl( + OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept; + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* this_ptr, OrtAllocator* allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl( + OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept; + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl( + OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl( + OrtEpFactory* this_ptr, + const 
OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept; + + static OrtStatus* ORT_API_CALL SetEnvironmentOptionsImpl( + OrtEpFactory* this_ptr, + const OrtKeyValuePairs* options) noexcept; + + Ep::Config config_; + Ort::MemoryInfo default_memory_info_; + Ort::MemoryInfo readonly_memory_info_; // used for initializers + + public: + Factory(); + ~Factory() = default; +}; + +} // namespace ep +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/tensor/cast.cc b/onnxruntime/core/providers/webgpu/tensor/cast.cc index daf4aa323c12e..2ceee741b4a49 100644 --- a/onnxruntime/core/providers/webgpu/tensor/cast.cc +++ b/onnxruntime/core/providers/webgpu/tensor/cast.cc @@ -113,7 +113,7 @@ template KernelCreateInfo CreateCastKernelInfo(bool enable_graph_capture) { const auto& type_constraints = CastOpTypeConstraints(enable_graph_capture); - KernelCreateFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreatePtrFn kernel_create_fn = [](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { out = std::make_unique(info); return Status::OK(); }; diff --git a/onnxruntime/core/providers/webgpu/tensor/concat.cc b/onnxruntime/core/providers/webgpu/tensor/concat.cc index 283a9e5fe8262..fa73c7d15ebdf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/concat.cc +++ b/onnxruntime/core/providers/webgpu/tensor/concat.cc @@ -98,7 +98,7 @@ Status Concat::ComputeInternal(ComputeContext& context) const { } Prepare prepare; - ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), input_tensors, prepare)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(&context.KernelContext(), input_tensors, prepare)); if (prepare.output_num_elements == 0) { return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/tensor/gather.cc b/onnxruntime/core/providers/webgpu/tensor/gather.cc index 39d07991f3c5a..8793e89e7b255 100644 --- a/onnxruntime/core/providers/webgpu/tensor/gather.cc +++ b/onnxruntime/core/providers/webgpu/tensor/gather.cc @@ -41,7 +41,7 @@ Status GatherProgram::GenerateShaderCode(ShaderHelper& shader) const { Status Gather::ComputeInternal(ComputeContext& context) const { Prepare p; - ORT_RETURN_IF_ERROR(PrepareForCompute(&context.KernelContext(), p)); + ORT_RETURN_IF_ERROR(PrepareForComputeImpl(&context.KernelContext(), p)); uint32_t data_size = onnxruntime::narrow(p.output_tensor->Shape().Size()); if (data_size == 0) { return Status::OK(); diff --git a/onnxruntime/core/providers/webgpu/tensor/pad.cc b/onnxruntime/core/providers/webgpu/tensor/pad.cc index 0e77ec46bbddb..7a576c4b53ecf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/pad.cc +++ b/onnxruntime/core/providers/webgpu/tensor/pad.cc @@ -49,7 +49,7 @@ Status Pad::ComputeInternal(ComputeContext& context) const { const auto pads_data = pads_tensor->DataAsSpan(); // Compute Pads by applying axes if specified otherwise copy the supplied pads. 
- PadBase::ComputePads(context.KernelContext(), data_rank, pads_data, pads); + PadBase::ComputePadsImpl(context.KernelContext(), data_rank, pads_data, pads); // Separate out any negative pads into the slices array PadBase::SeparateNegativeToSlices(pads, slices); diff --git a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc index b211d48dab1c9..e5e6e4d61eed5 100644 --- a/onnxruntime/core/providers/webgpu/tensor/shape_op.cc +++ b/onnxruntime/core/providers/webgpu/tensor/shape_op.cc @@ -3,11 +3,74 @@ #include "core/providers/webgpu/webgpu_kernel.h" #include "core/providers/webgpu/webgpu_supported_types.h" -#include "core/providers/cpu/tensor/shape_op.h" +// #include "core/providers/cpu/tensor/shape_op.h" namespace onnxruntime { namespace webgpu { +#ifndef SHARED_PROVIDER +#include "core/common/common.h" +#include "core/common/narrow.h" +#include "core/framework/op_kernel.h" +#endif + +#include +#include + +class Shape final : public OpKernel { + public: + Shape(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("start", &start_index_, 0); + + if (start_index_ != 0) { + // "start" is provided and is non-default (default is 0) + needs_slicing_ = true; + } + + if (info.GetAttr("end", &end_index_).IsOK()) { + needs_slicing_ = true; + } + } + + // Takes a tensor as input and outputs an 1D int64 tensor + // containing the shape of the input tensor. + Status Compute(OpKernelContext* context) const override { + const auto* input = context->Input(0); + const TensorShape& input_shape = input->Shape(); + + int64_t rank = gsl::narrow_cast(input_shape.NumDimensions()); + + if (!needs_slicing_) { // vanilla use of Shape (no slicing) + Tensor* output = context->Output(0, {rank}); + input_shape.CopyDims(output->MutableData(), static_cast(rank)); + } else { // slicing is needed + int64_t true_start = start_index_; + int64_t true_end = end_index_; + + // Deal with negative(s) and clamp + true_start = true_start < 0 ? true_start + rank : true_start; + true_start = true_start < 0 ? 0 : ((true_start > rank) ? rank : true_start); + + true_end = true_end < 0 ? true_end + rank : true_end; + true_end = true_end < 0 ? 0 : ((true_end > rank) ? rank : true_end); + + auto slice_length = true_end - true_start; + Tensor* output = context->Output(0, {slice_length < 0 ? 
0 : slice_length}); + + if (slice_length > 0) { + input_shape.CopyDims(output->MutableData(), onnxruntime::narrow(true_start), onnxruntime::narrow(slice_length)); + } + } + + return Status::OK(); + } + + private: + bool needs_slicing_ = false; + int64_t start_index_ = 0; + int64_t end_index_ = std::numeric_limits::max(); +}; + ONNX_OPERATOR_VERSIONED_KERNEL_EX( Shape, kOnnxDomain, diff --git a/onnxruntime/core/providers/webgpu/tensor/upsample.cc b/onnxruntime/core/providers/webgpu/tensor/upsample.cc index fb406883ba4ba..8f51ed45004bf 100644 --- a/onnxruntime/core/providers/webgpu/tensor/upsample.cc +++ b/onnxruntime/core/providers/webgpu/tensor/upsample.cc @@ -90,7 +90,7 @@ Status Upsample::ComputeInternal(ComputeContext& context) const { InlinedVector scales_array(input_dims.size()); // opset < 10 - if (OpKernel::Node().InputDefs().size() == 1) { + if (OpKernel::Node().SinceVersion() < 10) { scales_array = scales_; // Compute output shape from scales attributes and input dims ComputeOutputShape(scales_array, input_dims, output_dims); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index ee6b7707384e2..992d04c2a263a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -426,8 +426,8 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 17, ScatterElements); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ScatterElements); -std::unique_ptr RegisterKernels(bool enable_graph_capture = false) { - auto kernel_registry = std::make_unique(); +std::unique_ptr RegisterKernels(bool enable_graph_capture) { + auto kernel_registry = std::make_unique(); static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list becoming empty after ops-reducing @@ -797,6 +797,51 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals return kernel_registry; } +#if !defined(BUILD_WEBGPU_EP_STATIC_LIB) + +namespace { +std::shared_ptr g_kernel_registry; +std::shared_ptr g_graph_capture_kernel_registry; +} // namespace + +void CleanupKernelRegistries() { + g_kernel_registry.reset(); + g_graph_capture_kernel_registry.reset(); +} +#endif + +std::shared_ptr GetKernelRegistry(bool enable_graph_capture) { + // kernel registry variables are defined differently based on build configuration + // + // - When building as a static library, use static local variable. This is because + // we don't have a reliable way to explicitly destroy the kernel registry after + // use. + // + // - When building as a shared library, use global variables. The cleanup will be performed + // when `ReleaseEpFactory` is called. 
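+  //
+  // (For the shared library build, CleanupKernelRegistries() above is what resets these globals.)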
+ if (enable_graph_capture) { +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) + static std::shared_ptr registry = RegisterKernels(true); + return registry; +#else + if (g_graph_capture_kernel_registry == nullptr) { + g_graph_capture_kernel_registry = RegisterKernels(true); + } + return g_graph_capture_kernel_registry; +#endif + } else { +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) + static std::shared_ptr registry = RegisterKernels(false); + return registry; +#else + if (g_kernel_registry == nullptr) { + g_kernel_registry = RegisterKernels(false); + } + return g_kernel_registry; +#endif + } +} + } // namespace webgpu using namespace webgpu; @@ -840,6 +885,7 @@ std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { }; } +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) std::vector> WebGpuExecutionProvider::GetCapability( const onnxruntime::GraphViewer& graph, const IKernelLookup& kernel_lookup, @@ -930,16 +976,7 @@ std::vector> WebGpuExecutionProvider::GetCapa } return result; } - -std::shared_ptr WebGpuExecutionProvider::GetKernelRegistry() const { - if (enable_graph_capture_) { - static std::shared_ptr registry = webgpu::RegisterKernels(true); - return registry; - } else { - static std::shared_ptr registry = webgpu::RegisterKernels(false); - return registry; - } -} +#endif // defined(BUILD_WEBGPU_EP_STATIC_LIB) std::unique_ptr WebGpuExecutionProvider::GetDataTransfer() const { return std::make_unique(BufferManager()); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index bf0963f67cf1e..26ba16560bc07 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -4,6 +4,11 @@ #pragma once +#include +#include +#include +#include + #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" #include "core/graph/constants.h" @@ -28,6 +33,9 @@ class GpuBufferAllocator; // Forward declare CapturedCommandInfo which is now defined in webgpu_context.h struct CapturedCommandInfo; + +// The actual implementation of kernel registration. 
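+// Shared by the static-library WebGpuExecutionProvider::GetKernelRegistry() override and by the
+// plugin EP's GetKernelRegistry implementation in ep/ep.cc.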
+std::shared_ptr GetKernelRegistry(bool enable_graph_capture); } // namespace webgpu struct WebGpuExecutionProviderConfig { @@ -42,13 +50,17 @@ class WebGpuExecutionProvider : public IExecutionProvider { WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& config); ~WebGpuExecutionProvider() override; +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) std::vector> GetCapability( const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/, const GraphOptimizerRegistry& /* graph_optimizer_registry */, IResourceAccountant* /* resource_accountant */) const override; - std::shared_ptr GetKernelRegistry() const override; + std::shared_ptr GetKernelRegistry() const override { + return webgpu::GetKernelRegistry(enable_graph_capture_); + } +#endif std::unique_ptr GetDataTransfer() const override; #if defined(__wasm__) std::unique_ptr GetExternalDataLoader() const override; @@ -81,6 +93,16 @@ class WebGpuExecutionProvider : public IExecutionProvider { Status ReplayGraph(int graph_annotation_id) override; webgpu::BufferManager& BufferManager() const; AllocatorPtr PrepackAllocator() const { return prepack_allocator_; } + std::span GetForceCpuNodeNames() const { return force_cpu_node_names_; } + +#if !defined(BUILD_WEBGPU_EP_STATIC_LIB) + inline onnxruntime::ep::adapter::Logger& GetEpLogger() const { + return *ep_logger_; + } + inline void SetEpLogger(const OrtLogger* logger) { + ep_logger_ = std::make_unique(logger); + } +#endif private: bool IsGraphCaptureAllowed() const; @@ -109,6 +131,10 @@ class WebGpuExecutionProvider : public IExecutionProvider { // Allocator for prepacked weights (uses buffers without mapping) AllocatorPtr prepack_allocator_; + +#if !defined(BUILD_WEBGPU_EP_STATIC_LIB) + std::unique_ptr ep_logger_; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index cd791e31dcc2f..5db1db45ba84b 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -258,11 +258,11 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( // WebGPU DataTransfer implementation wrapper for the C API with lazy initialization struct WebGpuDataTransferImpl : OrtDataTransferImpl { - WebGpuDataTransferImpl(const OrtApi& ort_api_in) + WebGpuDataTransferImpl(const OrtApi& ort_api_in, int context_id) : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()}, data_transfer_{nullptr}, - context_id_{0}, // Always use context 0 for Environment's data transfer + context_id_{context_id}, // Always use context 0 for Environment's data transfer init_mutex_{} { ort_version_supported = ORT_API_VERSION; CanCopy = CanCopyImpl; // OrtDataTransferImpl::CanCopy callback @@ -301,9 +301,9 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { // If both are GPU, they must have the same device ID if (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) { - uint64_t src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); - uint64_t dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); - if (src_device_id != dst_device_id) { + int src_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); + int dst_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); + if (src_device_id != impl.context_id_ || dst_device_id != impl.context_id_) { return false; // Cannot copy between 
different devices } } @@ -346,9 +346,19 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { // Now perform the actual tensor copy for (size_t idx = 0; idx < num_tensors; ++idx) { - const OrtValue* src_tensor = src_tensors[idx]; - OrtValue* dst_tensor = dst_tensors[idx]; - auto status = impl.data_transfer_->CopyTensor(src_tensor->Get(), *dst_tensor->GetMutable()); + Ort::ConstValue src_value{src_tensors[idx]}; + const void* src_data = src_value.GetTensorRawData(); + size_t size = src_value.GetTensorSizeInBytes(); + bool src_is_gpu = src_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU; + + Ort::UnownedValue dst_value{dst_tensors[idx]}; + void* dst_data = dst_value.GetTensorMutableRawData(); + bool dst_is_gpu = dst_value.GetTensorMemoryInfo().GetDeviceType() == OrtMemoryInfoDeviceType_GPU; + auto status = impl.data_transfer_->CopyTensorImpl(src_data, + src_is_gpu, + dst_data, + dst_is_gpu, + src_value.GetTensorSizeInBytes()); if (!status.IsOK()) { return OrtApis::CreateStatus(ORT_RUNTIME_EXCEPTION, status.ErrorMessage().c_str()); } @@ -377,14 +387,18 @@ struct WebGpuDataTransferImpl : OrtDataTransferImpl { std::mutex init_mutex_; // Protects lazy initialization }; -OrtDataTransferImpl* OrtWebGpuCreateDataTransfer() { +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id /* = 0 */) { +#if defined(BUILD_WEBGPU_EP_STATIC_LIB) // Validate API version is supported const OrtApi* api = OrtApis::GetApi(ORT_API_VERSION); if (!api) { // API version not supported - return nullptr to indicate failure return nullptr; } - return new WebGpuDataTransferImpl(*api); + return new WebGpuDataTransferImpl(*api, context_id); +#else + return new WebGpuDataTransferImpl(onnxruntime::ep::Api().ort, context_id); +#endif } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h index 021e33ef25309..876a2e11d791a 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory_creator.h @@ -22,6 +22,6 @@ struct WebGpuProviderFactoryCreator { // C API to create data transfer for WebGPU EP with lazy initialization // Context will be determined from tensors during the first CopyTensors call // Caller takes ownership of the returned OrtDataTransferImpl* -OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(); +OrtDataTransferImpl* OrtWebGpuCreateDataTransfer(int context_id = 0); } // namespace onnxruntime diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 1c4e7800b7d2e..edfd9c193545a 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -174,67 +174,10 @@ class FuseExecutionProvider : public IExecutionProvider { }; namespace test { -static void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, - const std::vector& expected_values); static constexpr const ORTCHAR_T* MODEL_URI = ORT_TSTR("testdata/mul_1.onnx"); static constexpr const ORTCHAR_T* MODEL_URI_NO_OPSET = ORT_TSTR("testdata/mul_1.noopset.onnx"); // static const std::string MODEL_URI = "./testdata/squeezenet/model.onnx"; // TODO enable this after we've weights? 
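+// CreateMatMulModel, RunModelWithBindingMatMul and the IOBinding tests were moved to
+// test/providers/io_binding_test.cc; VerifyOutputs now lives in test/unittest_util/framework_test_utils.h.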
-static void CreateMatMulModel(std::unique_ptr& p_model, ProviderType provider_type) { - std::unordered_map domain_to_version; - domain_to_version[onnxruntime::kOnnxDomain] = 7; - // Generate the input & output def lists - std::vector model_specific_functions; - p_model = std::make_unique("test", true, ModelMetaData(), PathString(), - IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, - model_specific_functions, DefaultLoggingManager().DefaultLogger(), - ModelOptions(true, true)); - onnxruntime::Graph& graph = p_model->MainGraph(); - - TypeProto tensor_float; - tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); - - std::vector input_defs; - auto& input_arg_a = graph.GetOrCreateNodeArg("A", &tensor_float); - input_defs.push_back(&input_arg_a); - - auto& input_arg_b = graph.GetOrCreateNodeArg("B", &tensor_float); - input_defs.push_back(&input_arg_b); - - std::vector output_defs; - auto& output_arg = graph.GetOrCreateNodeArg("Y", &tensor_float); - output_defs.push_back(&output_arg); - - // Create a simple model - auto& node = graph.AddNode("node1", "MatMul", "MatMul", input_defs, output_defs, nullptr, onnxruntime::kOnnxDomain); - if (provider_type == kCpuExecutionProvider) { - node.SetExecutionProviderType(provider_type); - } else { -#if defined(USE_CUDA) || defined(USE_WEBGPU) - node.SetExecutionProviderType(provider_type); -#endif - } - Status status = graph.Resolve(); - ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); -} - -template -void VerifyOutputs(const Tensor& tensor, const std::vector& expected_dims, - const std::vector& expected_values) { - TensorShape expected_shape(expected_dims); - ASSERT_EQ(expected_shape, tensor.Shape()); - const std::vector found(tensor.Data(), - tensor.Data() + expected_values.size()); - ASSERT_EQ(expected_values, found); -} - -void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, - const std::vector& expected_values) { - ASSERT_EQ(1u, fetches.size()); - auto& rtensor = fetches.front().Get(); - VerifyOutputs(rtensor, expected_dims, expected_values); -} - void RunModel(InferenceSession& session_object, const RunOptions& run_options, bool is_preallocate_output_vec = false) { @@ -272,174 +215,6 @@ void RunModel(InferenceSession& session_object, VerifyOutputs(fetches, expected_dims_mul_y, expected_values_mul_y); } -void RunModelWithBindingMatMul(InferenceSession& session_object, - const RunOptions& run_options, - ProviderType bind_provider_type, - bool is_preallocate_output_vec, - ProviderType allocation_provider, - IExecutionProvider* gpu_provider, - OrtDevice* output_device, - bool enable_graph_capture) { - std::unique_ptr io_binding; - Status st = session_object.NewIOBinding(&io_binding); - ASSERT_TRUE(st.IsOK()); - - // bind a value to A with input that will produce invalid output in order to test replacement of a feed - std::vector values_mul_x_tmp = {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}; - std::vector dims_mul_x_A_tmp = {3, 4}; - std::vector values_mul_x = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; - std::vector dims_mul_x_A = {3, 4}; - std::vector dims_mul_x_B = {4, 3}; - - auto cpu_alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; - onnxruntime::AllocatorPtr gpu_alloc = nullptr; - if (allocation_provider == kWebGpuExecutionProvider) { - // Use session_object.GetAllocator to get the OrtAllocator for WebGPU. 
- // Otherwise, gpu_provider->CreatePreferredAllocators() will create a new OrtAllocator which will go to the create UMA path. - // And it can't be used for copying buffer to buffer since the target buffer is still in mapped state. - OrtMemoryInfo mem_info(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)); - gpu_alloc = session_object.GetAllocator(mem_info); - } else if (allocation_provider == kCudaExecutionProvider) { - gpu_alloc = gpu_provider->CreatePreferredAllocators()[0]; - } - if (enable_graph_capture) { - // For graph capture, all inputs/outputs should be in preallocated gpu memory. - ASSERT_TRUE(is_preallocate_output_vec); - OrtValue input_ml_value_A_cpu; - CreateMLValue(cpu_alloc, dims_mul_x_A, values_mul_x, &input_ml_value_A_cpu); - auto& cpu_tensor_a = input_ml_value_A_cpu.Get(); - Tensor gpu_tensor_a(cpu_tensor_a.DataType(), cpu_tensor_a.Shape(), gpu_alloc); - st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a, gpu_tensor_a); - ASSERT_TRUE(st.IsOK()); - OrtValue input_ml_value_A; - Tensor::InitOrtValue(std::move(gpu_tensor_a), input_ml_value_A); - - OrtValue input_ml_value_B_cpu; - CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B_cpu); - auto& cpu_tensor_b = input_ml_value_B_cpu.Get(); - Tensor gpu_tensor_b(cpu_tensor_b.DataType(), cpu_tensor_b.Shape(), gpu_alloc); - st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_b, gpu_tensor_b); - ASSERT_TRUE(st.IsOK()); - OrtValue input_ml_value_B; - Tensor::InitOrtValue(std::move(gpu_tensor_b), input_ml_value_B); - - ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); - ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); - } else { - auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); - OrtValue input_tmp; - CreateMLValue(input_allocator, dims_mul_x_A_tmp, values_mul_x_tmp, &input_tmp); - ASSERT_STATUS_OK(io_binding->BindInput("A", input_tmp)); - const void* tmp_A = io_binding->GetInputs()[0].Get().DataRaw(); // location of data post binding - - // prepare inputs - /* - 0 1 2 3 0 1 2 - 4 5 6 7 3 4 5 - 8 9 10 11 6 7 8 - 9 10 11 - */ - // bind one input to cpu allocator from bind_provider_type, and another on user provided CPU memory - // so both code pathes are covered - OrtValue input_ml_value_A; - CreateMLValue(input_allocator, dims_mul_x_A, values_mul_x, &input_ml_value_A); - - OrtValue input_ml_value_B; - CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B); - - ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); - ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); - - // check location of 'A' post-binding has changed to validate that the previous value was replaced - ASSERT_TRUE(io_binding->GetInputs()[0].Get().DataRaw() != tmp_A); - } - // prepare outputs - std::vector expected_output_dims = {3, 3}; - OrtValue output_ml_value; - if (is_preallocate_output_vec) { - if (allocation_provider == kCpuExecutionProvider) { - AllocateMLValue(cpu_alloc, expected_output_dims, &output_ml_value); - } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { - AllocateMLValue(gpu_alloc, expected_output_dims, &output_ml_value); - } else { - ORT_THROW("Unsupported provider"); - } - } - - if (output_device) { - // output should be allocated on specified device (if not preallocated here) - ASSERT_STATUS_OK(io_binding->BindOutput("Y", *output_device)); - } else { - 
ASSERT_STATUS_OK(io_binding->BindOutput("Y", output_ml_value)); - } - - ASSERT_TRUE(io_binding->SynchronizeInputs().IsOK()); - - // prepare expected inputs and outputs - std::vector expected_values_mul_y = {42, 48, 54, 114, 136, 158, 186, 224, 262}; - std::vector expected_values_mul_y_2 = {174, 216, 258, 102, 128, 154, 30, 40, 50}; - - // Now run - ASSERT_STATUS_OK(session_object.Run(run_options, *io_binding)); - - if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || - (output_device && output_device->Type() == OrtDevice::GPU)) { -#if defined(USE_CUDA) || defined(USE_WEBGPU) - // in this case we need to copy the tensor from cuda to cpu - std::vector& outputs = io_binding->GetOutputs(); - ASSERT_EQ(1u, outputs.size()); - auto& rtensor = outputs.front().Get(); - auto element_type = rtensor.DataType(); - auto& shape = rtensor.Shape(); - Tensor cpu_tensor(element_type, shape, cpu_alloc); -#ifdef USE_CUDA - st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); -#endif -#ifdef USE_WEBGPU - st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); -#endif - ASSERT_TRUE(st.IsOK()); - OrtValue ml_value; - Tensor::InitOrtValue(std::move(cpu_tensor), ml_value); - VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); -#endif - } else { - if (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { - ASSERT_STATUS_OK(gpu_provider->Sync()); - } - VerifyOutputs(io_binding->GetOutputs(), expected_output_dims, expected_values_mul_y); - } - - if (enable_graph_capture) { - // Update input_a's value. Run again. Replay the captured graph - OrtValue input_a2; - CreateMLValue(cpu_alloc, dims_mul_x_A_tmp, values_mul_x_tmp, &input_a2); - auto& cpu_tensor_a2 = input_a2.Get(); - st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a2, const_cast(io_binding->GetInputs()[0].Get())); - ASSERT_TRUE(st.IsOK()); - - st = session_object.Run(run_options, *io_binding.get()); - - std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; - ASSERT_TRUE(st.IsOK()); - - // Copy the tensor from gpu to cpu - std::vector& outputs = io_binding->GetOutputs(); - ASSERT_EQ(1u, outputs.size()); - auto& rtensor = outputs.front().Get(); - auto element_type = rtensor.DataType(); - auto& shape = rtensor.Shape(); - std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); - st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); - ASSERT_TRUE(st.IsOK()); - OrtValue ml_value; - ml_value.Init(cpu_tensor.release(), - DataTypeImpl::GetType(), - DataTypeImpl::GetType()->GetDeleteFunc()); - VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y_2); - } -} - TEST(InferenceSessionTests, NoTimeout) { SessionOptions so; @@ -1006,110 +781,6 @@ TEST(InferenceSessionTests, TestRegisterExecutionProvider) { RunModel(session_object, run_options); } -static void TestBindHelper(const std::string& log_str, - ProviderType bind_provider_type, - ProviderType run_provider_type, - bool preallocate_output, - ProviderType allocation_provider = kCpuExecutionProvider, - OrtDevice* output_device = nullptr, - bool enable_graph_capture = false) { - SessionOptions so; - - so.session_logid = "InferenceSessionTests." 
+ log_str; - so.session_log_verbosity_level = 1; // change to 1 for detailed logging - InferenceSession session_object{so, GetEnvironment()}; - IExecutionProvider* gpu_provider{}; - - if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kWebGpuExecutionProvider) { -#ifdef USE_CUDA - { - auto provider = DefaultCudaExecutionProvider(); - gpu_provider = provider.get(); - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); - } -#endif -#ifdef USE_WEBGPU - { - ConfigOptions config_options{}; - ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kEnableGraphCapture, - enable_graph_capture ? webgpu::options::kEnableGraphCapture_ON : webgpu::options::kEnableGraphCapture_OFF) - .IsOK()); - auto provider = WebGpuExecutionProviderWithOptions(config_options); - gpu_provider = provider.get(); - ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); - } -#endif - } - - std::unique_ptr p_model; - CreateMatMulModel(p_model, run_provider_type); - - std::string s1; - p_model->ToProto().SerializeToString(&s1); - std::stringstream sstr(s1); - ASSERT_STATUS_OK(session_object.Load(sstr)); - ASSERT_STATUS_OK(session_object.Initialize()); - - RunOptions run_options; - run_options.run_log_verbosity_level = so.session_log_verbosity_level; - run_options.run_tag = so.session_logid; - - RunModelWithBindingMatMul(session_object, - run_options, - bind_provider_type, - preallocate_output, - allocation_provider, - gpu_provider, - output_device, - enable_graph_capture); -} - -TEST(InferenceSessionTests, TestBindCpu) { - TestBindHelper("TestBindCpu", - kCpuExecutionProvider, - kCpuExecutionProvider, - false /* don't preallocate output */); -} - -TEST(InferenceSessionTests, TestIOBindingReuse) { - SessionOptions so; - InferenceSession session_object(so, GetEnvironment()); - std::unique_ptr p_model; - CreateMatMulModel(p_model, kCpuExecutionProvider); - - std::string s1; - p_model->ToProto().SerializeToString(&s1); - std::stringstream sstr(s1); - ASSERT_TRUE(session_object.Load(sstr).IsOK()); - ASSERT_STATUS_OK(session_object.Initialize()); - std::unique_ptr io_binding; - Status st = session_object.NewIOBinding(&io_binding); - ASSERT_TRUE(st.IsOK()); - - OrtValue ml_value1; - const std::vector v1{2.f}; - const int64_t shape[] = {1}; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], shape, v1, &ml_value1); - ASSERT_STATUS_OK(io_binding->BindOutput("foo", ml_value1)); - ASSERT_TRUE(io_binding->GetOutputs().size() == 1); - auto span = io_binding->GetOutputs()[0].Get().DataAsSpan(); - ASSERT_TRUE(static_cast(span.size()) == v1.size()); - for (size_t i = 0; i < v1.size(); ++i) { - ASSERT_TRUE(v1[i] == span[i]); - } - - OrtValue ml_value2; - const std::vector v2{3.f}; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], shape, v2, &ml_value2); - ASSERT_STATUS_OK(io_binding->BindOutput("foo", ml_value2)); - ASSERT_TRUE(io_binding->GetOutputs().size() == 1); - span = io_binding->GetOutputs()[0].Get().DataAsSpan(); - ASSERT_TRUE(static_cast(span.size()) == v2.size()); - for (size_t i = 0; i < v2.size(); ++i) { - ASSERT_TRUE(v2[i] == span[i]); - } -} - TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { SessionOptions so; @@ -1148,67 +819,6 @@ TEST(InferenceSessionTests, InvalidInputTypeOfTensorElement) { ASSERT_TRUE(!st.IsOK()); } -#if defined(USE_CUDA) || defined(USE_WEBGPU) -#if USE_CUDA -constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider; -#elif USE_WEBGPU -constexpr const 
char* kGpuExecutionProvider = kWebGpuExecutionProvider; -#endif - -TEST(InferenceSessionTests, TestBindCuda) { - TestBindHelper("TestBindCuda", - kGpuExecutionProvider, - kGpuExecutionProvider, - false /* don't preallocate output */); -} - -TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCuda) { - TestBindHelper("TestBindCudaPreallocateOutputOnCuda", - kGpuExecutionProvider, - kGpuExecutionProvider, - true /* preallocate output on GPU */, - kGpuExecutionProvider); -} - -TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu) { - TestBindHelper("TestBindCudaPreallocateOutputOnCpu", - kGpuExecutionProvider, - kGpuExecutionProvider, - true /* preallocate output on CPU */, - kCpuExecutionProvider); -} - -TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) { - TestBindHelper("TestBindCudaPreallocateOutputOnCpu2", - kGpuExecutionProvider, - kCpuExecutionProvider, - true /* preallocate output on CPU */, - kCpuExecutionProvider); -} -#ifndef USE_WEBGPU -TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { - OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); - - TestBindHelper("TestBindCudaPreallocateOutputOnCuda", - kGpuExecutionProvider, - kGpuExecutionProvider, - false /* preallocate output on GPU */, - kGpuExecutionProvider, - &device /* specify output device */); -} -#else -TEST(InferenceSessionTests, TestGraphCapture) { - TestBindHelper("TestGraphCapture", - kGpuExecutionProvider, - kGpuExecutionProvider, - true /* preallocate output on GPU */, - kGpuExecutionProvider, - nullptr, - true /* enable graph capture*/); -} -#endif // !USE_WEBGPU -#endif - TEST(InferenceSessionTests, ModelWithoutOpset) { SessionOptions so; diff --git a/onnxruntime/test/providers/io_binding_test.cc b/onnxruntime/test/providers/io_binding_test.cc new file mode 100644 index 0000000000000..125b51d6d4c3a --- /dev/null +++ b/onnxruntime/test/providers/io_binding_test.cc @@ -0,0 +1,402 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
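+
+// IOBinding tests moved from inference_session_test.cc: binding MatMul inputs and outputs on CPU
+// and GPU (CUDA/WebGPU), preallocated outputs, output-device binding, and WebGPU graph capture.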
+ +#include "core/graph/onnx_protobuf.h" +#include "core/session/inference_session.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/model.h" +#include "core/framework/tensorprotoutils.h" +#include "core/session/IOBinding.h" + +#include "test/unittest_util/framework_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "test/test_environment.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace test { + +static void CreateMatMulModel(std::unique_ptr& p_model, ProviderType provider_type) { + std::unordered_map domain_to_version; + domain_to_version[onnxruntime::kOnnxDomain] = 7; + // Generate the input & output def lists + std::vector model_specific_functions; + p_model = std::make_unique("test", true, ModelMetaData(), PathString(), + IOnnxRuntimeOpSchemaRegistryList(), domain_to_version, + model_specific_functions, DefaultLoggingManager().DefaultLogger(), + ModelOptions(true, true)); + onnxruntime::Graph& graph = p_model->MainGraph(); + + TypeProto tensor_float; + tensor_float.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT); + + std::vector input_defs; + auto& input_arg_a = graph.GetOrCreateNodeArg("A", &tensor_float); + input_defs.push_back(&input_arg_a); + + auto& input_arg_b = graph.GetOrCreateNodeArg("B", &tensor_float); + input_defs.push_back(&input_arg_b); + + std::vector output_defs; + auto& output_arg = graph.GetOrCreateNodeArg("Y", &tensor_float); + output_defs.push_back(&output_arg); + + // Create a simple model + auto& node = graph.AddNode("node1", "MatMul", "MatMul", input_defs, output_defs, nullptr, onnxruntime::kOnnxDomain); + if (provider_type == kCpuExecutionProvider) { + node.SetExecutionProviderType(provider_type); + } else { +#if defined(USE_CUDA) || defined(USE_WEBGPU) + node.SetExecutionProviderType(provider_type); +#endif + } + Status status = graph.Resolve(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); +} + +void RunModelWithBindingMatMul(InferenceSession& session_object, + const RunOptions& run_options, + ProviderType bind_provider_type, + bool is_preallocate_output_vec, + ProviderType allocation_provider, + IExecutionProvider* gpu_provider, + OrtDevice* output_device, + bool enable_graph_capture) { + std::unique_ptr io_binding; + Status st = session_object.NewIOBinding(&io_binding); + ASSERT_TRUE(st.IsOK()); + + // bind a value to A with input that will produce invalid output in order to test replacement of a feed + std::vector values_mul_x_tmp = {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}; + std::vector dims_mul_x_A_tmp = {3, 4}; + std::vector values_mul_x = {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f}; + std::vector dims_mul_x_A = {3, 4}; + std::vector dims_mul_x_B = {4, 3}; + + auto cpu_alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + onnxruntime::AllocatorPtr gpu_alloc = nullptr; + if (allocation_provider == kWebGpuExecutionProvider) { + // Use session_object.GetAllocator to get the OrtAllocator for WebGPU. + // Otherwise, gpu_provider->CreatePreferredAllocators() will create a new OrtAllocator which will go to the create UMA path. + // And it can't be used for copying buffer to buffer since the target buffer is still in mapped state. 
+ OrtMemoryInfo mem_info(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0)); + gpu_alloc = session_object.GetAllocator(mem_info); + } else if (allocation_provider == kCudaExecutionProvider) { + gpu_alloc = gpu_provider->CreatePreferredAllocators()[0]; + } + if (enable_graph_capture) { + // For graph capture, all inputs/outputs should be in preallocated gpu memory. + ASSERT_TRUE(is_preallocate_output_vec); + OrtValue input_ml_value_A_cpu; + CreateMLValue(cpu_alloc, dims_mul_x_A, values_mul_x, &input_ml_value_A_cpu); + auto& cpu_tensor_a = input_ml_value_A_cpu.Get(); + Tensor gpu_tensor_a(cpu_tensor_a.DataType(), cpu_tensor_a.Shape(), gpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a, gpu_tensor_a); + ASSERT_TRUE(st.IsOK()); + OrtValue input_ml_value_A; + Tensor::InitOrtValue(std::move(gpu_tensor_a), input_ml_value_A); + + OrtValue input_ml_value_B_cpu; + CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B_cpu); + auto& cpu_tensor_b = input_ml_value_B_cpu.Get(); + Tensor gpu_tensor_b(cpu_tensor_b.DataType(), cpu_tensor_b.Shape(), gpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_b, gpu_tensor_b); + ASSERT_TRUE(st.IsOK()); + OrtValue input_ml_value_B; + Tensor::InitOrtValue(std::move(gpu_tensor_b), input_ml_value_B); + + ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); + ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); + } else { + auto input_allocator = io_binding->GetCPUAllocator(bind_provider_type); + OrtValue input_tmp; + CreateMLValue(input_allocator, dims_mul_x_A_tmp, values_mul_x_tmp, &input_tmp); + ASSERT_STATUS_OK(io_binding->BindInput("A", input_tmp)); + const void* tmp_A = io_binding->GetInputs()[0].Get().DataRaw(); // location of data post binding + + // prepare inputs + /* + 0 1 2 3 0 1 2 + 4 5 6 7 3 4 5 + 8 9 10 11 6 7 8 + 9 10 11 + */ + // bind one input to cpu allocator from bind_provider_type, and another on user provided CPU memory + // so both code pathes are covered + OrtValue input_ml_value_A; + CreateMLValue(input_allocator, dims_mul_x_A, values_mul_x, &input_ml_value_A); + + OrtValue input_ml_value_B; + CreateMLValue(cpu_alloc, dims_mul_x_B, values_mul_x, &input_ml_value_B); + + ASSERT_STATUS_OK(io_binding->BindInput("A", input_ml_value_A)); + ASSERT_STATUS_OK(io_binding->BindInput("B", input_ml_value_B)); + + // check location of 'A' post-binding has changed to validate that the previous value was replaced + ASSERT_TRUE(io_binding->GetInputs()[0].Get().DataRaw() != tmp_A); + } + // prepare outputs + std::vector expected_output_dims = {3, 3}; + OrtValue output_ml_value; + if (is_preallocate_output_vec) { + if (allocation_provider == kCpuExecutionProvider) { + AllocateMLValue(cpu_alloc, expected_output_dims, &output_ml_value); + } else if (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { + AllocateMLValue(gpu_alloc, expected_output_dims, &output_ml_value); + } else { + ORT_THROW("Unsupported provider"); + } + } + + if (output_device) { + // output should be allocated on specified device (if not preallocated here) + ASSERT_STATUS_OK(io_binding->BindOutput("Y", *output_device)); + } else { + ASSERT_STATUS_OK(io_binding->BindOutput("Y", output_ml_value)); + } + + ASSERT_TRUE(io_binding->SynchronizeInputs().IsOK()); + + // prepare expected inputs and outputs + std::vector expected_values_mul_y = {42, 48, 54, 114, 136, 158, 186, 224, 262}; + 
std::vector expected_values_mul_y_2 = {174, 216, 258, 102, 128, 154, 30, 40, 50}; + + // Now run + ASSERT_STATUS_OK(session_object.Run(run_options, *io_binding)); + + if ((is_preallocate_output_vec && (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider)) || + (output_device && output_device->Type() == OrtDevice::GPU)) { +#if defined(USE_CUDA) || defined(USE_WEBGPU) + // in this case we need to copy the tensor from cuda to cpu + std::vector& outputs = io_binding->GetOutputs(); + ASSERT_EQ(1u, outputs.size()); + auto& rtensor = outputs.front().Get(); + auto element_type = rtensor.DataType(); + auto& shape = rtensor.Shape(); + Tensor cpu_tensor(element_type, shape, cpu_alloc); +#ifdef USE_CUDA + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); +#endif +#ifdef USE_WEBGPU + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, cpu_tensor); +#endif + ASSERT_TRUE(st.IsOK()); + OrtValue ml_value; + Tensor::InitOrtValue(std::move(cpu_tensor), ml_value); + VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y); +#endif + } else { + if (allocation_provider == kCudaExecutionProvider || allocation_provider == kWebGpuExecutionProvider) { + ASSERT_STATUS_OK(gpu_provider->Sync()); + } + VerifyOutputs(io_binding->GetOutputs(), expected_output_dims, expected_values_mul_y); + } + + if (enable_graph_capture) { + // Update input_a's value. Run again. Replay the captured graph + OrtValue input_a2; + CreateMLValue(cpu_alloc, dims_mul_x_A_tmp, values_mul_x_tmp, &input_a2); + auto& cpu_tensor_a2 = input_a2.Get(); + st = gpu_provider->GetDataTransfer()->CopyTensor(cpu_tensor_a2, const_cast(io_binding->GetInputs()[0].Get())); + ASSERT_TRUE(st.IsOK()); + + st = session_object.Run(run_options, *io_binding.get()); + + std::cout << "Run returned status: " << st.ErrorMessage() << std::endl; + ASSERT_TRUE(st.IsOK()); + + // Copy the tensor from gpu to cpu + std::vector& outputs = io_binding->GetOutputs(); + ASSERT_EQ(1u, outputs.size()); + auto& rtensor = outputs.front().Get(); + auto element_type = rtensor.DataType(); + auto& shape = rtensor.Shape(); + std::unique_ptr cpu_tensor = std::make_unique(element_type, shape, cpu_alloc); + st = gpu_provider->GetDataTransfer()->CopyTensor(rtensor, *cpu_tensor.get()); + ASSERT_TRUE(st.IsOK()); + OrtValue ml_value; + ml_value.Init(cpu_tensor.release(), + DataTypeImpl::GetType(), + DataTypeImpl::GetType()->GetDeleteFunc()); + VerifyOutputs({ml_value}, expected_output_dims, expected_values_mul_y_2); + } +} + +static void TestBindHelper(const std::string& log_str, + ProviderType bind_provider_type, + ProviderType run_provider_type, + bool preallocate_output, + ProviderType allocation_provider = kCpuExecutionProvider, + OrtDevice* output_device = nullptr, + bool enable_graph_capture = false) { + SessionOptions so; + + so.session_logid = "InferenceSessionTests." + log_str; + so.session_log_verbosity_level = 1; // change to 1 for detailed logging + InferenceSession session_object{so, GetEnvironment()}; + IExecutionProvider* gpu_provider{}; + + if (bind_provider_type == kCudaExecutionProvider || bind_provider_type == kWebGpuExecutionProvider) { +#ifdef USE_CUDA + { + auto provider = DefaultCudaExecutionProvider(); + gpu_provider = provider.get(); + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider))); + } +#endif +#ifdef USE_WEBGPU + { + ConfigOptions config_options{}; + ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kEnableGraphCapture, + enable_graph_capture ? 
+                                                enable_graph_capture ? webgpu::options::kEnableGraphCapture_ON : webgpu::options::kEnableGraphCapture_OFF)
+                      .IsOK());
+      auto provider = WebGpuExecutionProviderWithOptions(config_options);
+      gpu_provider = provider.get();
+      ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(provider)));
+    }
+#endif
+  }
+
+  std::unique_ptr<Model> p_model;
+  CreateMatMulModel(p_model, run_provider_type);
+
+  std::string s1;
+  p_model->ToProto().SerializeToString(&s1);
+  std::stringstream sstr(s1);
+  ASSERT_STATUS_OK(session_object.Load(sstr));
+  ASSERT_STATUS_OK(session_object.Initialize());
+
+  RunOptions run_options;
+  run_options.run_log_verbosity_level = so.session_log_verbosity_level;
+  run_options.run_tag = so.session_logid;
+
+  RunModelWithBindingMatMul(session_object,
+                            run_options,
+                            bind_provider_type,
+                            preallocate_output,
+                            allocation_provider,
+                            gpu_provider,
+                            output_device,
+                            enable_graph_capture);
+}
+
+TEST(InferenceSessionTests, TestBindCpu) {
+  TestBindHelper("TestBindCpu",
+                 kCpuExecutionProvider,
+                 kCpuExecutionProvider,
+                 false /* don't preallocate output */);
+}
+
+TEST(InferenceSessionTests, TestIOBindingReuse) {
+  SessionOptions so;
+  InferenceSession session_object(so, GetEnvironment());
+  std::unique_ptr<Model> p_model;
+  CreateMatMulModel(p_model, kCpuExecutionProvider);
+
+  std::string s1;
+  p_model->ToProto().SerializeToString(&s1);
+  std::stringstream sstr(s1);
+  ASSERT_TRUE(session_object.Load(sstr).IsOK());
+  ASSERT_STATUS_OK(session_object.Initialize());
+  std::unique_ptr<IOBinding> io_binding;
+  Status st = session_object.NewIOBinding(&io_binding);
+  ASSERT_TRUE(st.IsOK());
+
+  OrtValue ml_value1;
+  const std::vector<float> v1{2.f};
+  const int64_t shape[] = {1};
+  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], shape, v1, &ml_value1);
+  ASSERT_STATUS_OK(io_binding->BindOutput("foo", ml_value1));
+  ASSERT_TRUE(io_binding->GetOutputs().size() == 1);
+  auto span = io_binding->GetOutputs()[0].Get<Tensor>().DataAsSpan<float>();
+  ASSERT_TRUE(static_cast<size_t>(span.size()) == v1.size());
+  for (size_t i = 0; i < v1.size(); ++i) {
+    ASSERT_TRUE(v1[i] == span[i]);
+  }
+
+  OrtValue ml_value2;
+  const std::vector<float> v2{3.f};
+  CreateMLValue<float>(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], shape, v2, &ml_value2);
+  ASSERT_STATUS_OK(io_binding->BindOutput("foo", ml_value2));
+  ASSERT_TRUE(io_binding->GetOutputs().size() == 1);
+  span = io_binding->GetOutputs()[0].Get<Tensor>().DataAsSpan<float>();
+  ASSERT_TRUE(static_cast<size_t>(span.size()) == v2.size());
+  for (size_t i = 0; i < v2.size(); ++i) {
+    ASSERT_TRUE(v2[i] == span[i]);
+  }
+}
+
+#if defined(USE_CUDA) || defined(USE_WEBGPU)
+#if USE_CUDA
+constexpr const char* kGpuExecutionProvider = kCudaExecutionProvider;
+#elif USE_WEBGPU
+constexpr const char* kGpuExecutionProvider = kWebGpuExecutionProvider;
+#endif
+
+TEST(InferenceSessionTests, TestBindCuda) {
+  TestBindHelper("TestBindCuda",
+                 kGpuExecutionProvider,
+                 kGpuExecutionProvider,
+                 false /* don't preallocate output */);
+}
+
+TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCuda) {
+  TestBindHelper("TestBindCudaPreallocateOutputOnCuda",
+                 kGpuExecutionProvider,
+                 kGpuExecutionProvider,
+                 true /* preallocate output on GPU */,
+                 kGpuExecutionProvider);
+}
+
+TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu) {
+  TestBindHelper("TestBindCudaPreallocateOutputOnCpu",
+                 kGpuExecutionProvider,
+                 kGpuExecutionProvider,
+                 true /* preallocate output on CPU */,
+                 kCpuExecutionProvider);
+}
+
+TEST(InferenceSessionTests, TestBindCudaPreallocateOutputOnCpu2) {
TestBindHelper("TestBindCudaPreallocateOutputOnCpu2", + kGpuExecutionProvider, + kCpuExecutionProvider, + true /* preallocate output on CPU */, + kCpuExecutionProvider); +} +#ifndef USE_WEBGPU +TEST(InferenceSessionTests, TestBindCudaSpecifyOutputDeviceOnCuda) { + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); + + TestBindHelper("TestBindCudaPreallocateOutputOnCuda", + kGpuExecutionProvider, + kGpuExecutionProvider, + false /* preallocate output on GPU */, + kGpuExecutionProvider, + &device /* specify output device */); +} +#else +TEST(InferenceSessionTests, TestGraphCapture) { + TestBindHelper("TestGraphCapture", + kGpuExecutionProvider, + kGpuExecutionProvider, + true /* preallocate output on GPU */, + kGpuExecutionProvider, + nullptr, + true /* enable graph capture*/); +} +#endif // !USE_WEBGPU +#endif + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/unittest_util/framework_test_utils.h b/onnxruntime/test/unittest_util/framework_test_utils.h index 9c5893948ff1b..4870b4e5c5d7f 100644 --- a/onnxruntime/test/unittest_util/framework_test_utils.h +++ b/onnxruntime/test/unittest_util/framework_test_utils.h @@ -10,6 +10,7 @@ #include "core/framework/ort_value.h" #include +#include #ifdef USE_CUDA #include "core/providers/providers.h" @@ -128,5 +129,22 @@ inline int OpCount(const OpCountMap& op_count_map, const std::string& op_type) { void SparseIndicesChecker(const ONNX_NAMESPACE::TensorProto& indices_proto, gsl::span expected_indicies); #endif // DISABLE_SPARSE_TENSORS +template +void VerifyOutputs(const Tensor& tensor, const std::vector& expected_dims, + const std::vector& expected_values) { + TensorShape expected_shape(expected_dims); + ASSERT_EQ(expected_shape, tensor.Shape()); + const std::vector found(tensor.Data(), + tensor.Data() + expected_values.size()); + ASSERT_EQ(expected_values, found); +} + +inline void VerifyOutputs(const std::vector& fetches, const std::vector& expected_dims, + const std::vector& expected_values) { + ASSERT_EQ(1u, fetches.size()); + auto& rtensor = fetches.front().Get(); + VerifyOutputs(rtensor, expected_dims, expected_values); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc index fd2cf2f712628..3257cfc99e245 100644 --- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc +++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.cc @@ -167,7 +167,7 @@ void Shutdown() { g_plugin_ep_infrastructure_state.reset(); } -std::unique_ptr MakeEp(const logging::Logger* logger) { +std::unique_ptr MakeEp(const logging::Logger* logger, const ConfigOptions* config_options) { if (!IsInitialized()) { return nullptr; } @@ -182,6 +182,13 @@ std::unique_ptr MakeEp(const logging::Logger* logger) { StrMapToKeyValueCstrVectors(state.config.default_ep_options, default_ep_option_key_cstrs, default_ep_option_value_cstrs); + if (config_options != nullptr) { + for (const auto& [key, value] : config_options->configurations) { + default_ep_option_key_cstrs.push_back(key.c_str()); + default_ep_option_value_cstrs.push_back(value.c_str()); + } + } + OrtSessionOptions ort_session_options{}; ORT_THROW_IF_ERROR(AddEpOptionsToSessionOptions(state.selected_c_ep_devices, default_ep_option_key_cstrs, diff --git a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h index 680045be9330c..091946da8fc26 100644 --- 
--- a/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h
+++ b/onnxruntime/test/unittest_util/test_dynamic_plugin_ep.h
@@ -17,6 +17,7 @@ namespace onnxruntime {
 
 struct IExecutionProviderFactory;
 class IExecutionProvider;
+struct ConfigOptions;
 
 namespace logging {
 class Logger;
@@ -74,7 +75,7 @@ bool IsInitialized();
 void Shutdown();
 
 // Returns a dynamic plugin EP `IExecutionProvider` instance, or `nullptr` if uninitialized.
-std::unique_ptr<IExecutionProvider> MakeEp(const logging::Logger* logger = nullptr);
+std::unique_ptr<IExecutionProvider> MakeEp(const logging::Logger* logger = nullptr, const ConfigOptions* config_options = nullptr);
 
 // Gets the dynamic plugin EP name, or `std::nullopt` if uninitialized.
 std::optional<std::string> GetEpName();
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 4bc300fc7263a..6936eddaa129f 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -14,8 +14,13 @@
 #ifdef USE_CUDA
 #include "core/providers/cuda/cuda_provider_options.h"
 #endif
+#if defined(USE_WEBGPU)
+#include "core/graph/constants.h"
+#include "core/session/abi_session_options_impl.h"
+#endif
 #include "core/session/onnxruntime_cxx_api.h"
 #include "test/util/include/providers.h"
+#include "test/unittest_util/test_dynamic_plugin_ep.h"
 
 namespace onnxruntime {
 
@@ -282,19 +287,30 @@ std::unique_ptr<IExecutionProvider> DefaultXnnpackExecutionProvider() {
 }
 
 std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider(bool is_nhwc) {
-#if defined(USE_WEBGPU) && defined(BUILD_WEBGPU_EP_STATIC_LIB)
+#if defined(USE_WEBGPU)
   ConfigOptions config_options{};
+#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
+  size_t config_entry_key_offset = 0;
+#else
+  // used to remove the EP prefix from the config entry keys
+  size_t config_entry_key_offset = OrtSessionOptions::GetProviderOptionPrefix(kWebGpuExecutionProvider).length();
+#endif
+
   // Disable storage buffer cache
-  ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode,
+  ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kStorageBufferCacheMode + config_entry_key_offset,
                                             webgpu::options::kBufferCacheMode_Disabled)
                   .IsOK());
   if (!is_nhwc) {
     // Enable NCHW support
-    ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreferredLayout,
+    ORT_ENFORCE(config_options.AddConfigEntry(webgpu::options::kPreferredLayout + config_entry_key_offset,
                                               webgpu::options::kPreferredLayout_NCHW)
                     .IsOK());
   }
+#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
   return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();
+#else
+  return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options);
+#endif
 #else
   ORT_UNUSED_PARAMETER(is_nhwc);
   return nullptr;
@@ -302,8 +318,12 @@ std::unique_ptr<IExecutionProvider> DefaultWebGpuExecutionProvider(bool is_nhwc)
 }
 
 std::unique_ptr<IExecutionProvider> WebGpuExecutionProviderWithOptions(const ConfigOptions& config_options) {
-#if defined(USE_WEBGPU) && defined(BUILD_WEBGPU_EP_STATIC_LIB)
+#if defined(USE_WEBGPU)
+#if defined(BUILD_WEBGPU_EP_STATIC_LIB)
   return WebGpuProviderFactoryCreator::Create(config_options)->CreateProvider();
+#else
+  return dynamic_plugin_ep_infra::MakeEp(nullptr, &config_options);
+#endif
 #else
   ORT_UNUSED_PARAMETER(config_options);
   return nullptr;