Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] WebGPU EP [skip ci] #21904

Closed
wants to merge 167 commits into from
Closed
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
167 commits
Select commit Hold shift + click to select a range
4037bd4
[WIP] WebGPU EP initial commit
fs-eire Aug 28, 2024
9c36250
update C-API
fs-eire Aug 28, 2024
3a0756d
fix build break
fs-eire Aug 28, 2024
5199e98
add an empty symbols.txt file
fs-eire Aug 28, 2024
1c68dbd
fix an error in doc
fs-eire Aug 29, 2024
7db03de
remove string_join.h in favor of absl::StrJoin
fs-eire Aug 29, 2024
6a373c2
fix DLL copy
fs-eire Aug 29, 2024
ee42bba
update doc: require --skip_tests
fs-eire Aug 29, 2024
5fac202
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Aug 29, 2024
3f46e5c
update dawn version
fs-eire Aug 29, 2024
9f61279
disable Tint tests
fs-eire Aug 29, 2024
6bb6335
fix one build break in Linux
fs-eire Aug 29, 2024
d839dbc
remove unused variables
fs-eire Aug 30, 2024
b70943d
make webgpu build on linux and known to most tools (#21937)
guschmue Aug 30, 2024
c33ac2e
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Aug 30, 2024
8437267
revert type of ShaderVariable::rank_ to int
fs-eire Aug 30, 2024
3caf032
output Impl() for variables
fs-eire Aug 30, 2024
84494c4
code formatting
fs-eire Aug 30, 2024
aa70163
better format of Uniform
fs-eire Aug 30, 2024
d772db7
revise document
fs-eire Aug 30, 2024
6ef3dad
more build fix for linux
fs-eire Aug 31, 2024
a56f6c3
apply formatter
fs-eire Aug 31, 2024
12cd79d
simple test runner
fs-eire Aug 31, 2024
14c8966
Program macros update - allow extend
fs-eire Aug 31, 2024
4fff35f
fix BucketCacheManager
fs-eire Sep 1, 2024
4fd8ad1
add a method to get logger from ComputeContext
fs-eire Sep 1, 2024
3bd92ad
add verbose log for cache key
fs-eire Sep 1, 2024
6a1bbfe
revise suite test
fs-eire Sep 1, 2024
947aee1
device lost handler
fs-eire Sep 1, 2024
99b2578
add '-a' and '-t' to test runner
fs-eire Sep 1, 2024
aa7b3f5
atol/rtol 0.0001 -> 0.001
fs-eire Sep 1, 2024
e659acd
Fix uniform
fs-eire Sep 2, 2024
6ad89c5
add some unary ops
fs-eire Sep 2, 2024
8361fc3
various of fixes
fs-eire Sep 2, 2024
c89159d
fix workgroup_size, cache key stringnify and indices type
fs-eire Sep 3, 2024
5ea5936
shape_uniforms preparation
fs-eire Sep 3, 2024
7d83054
allow uniforms of input/output shape/stride being added automatically
fs-eire Sep 3, 2024
7a64cc7
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 3, 2024
1d53ac8
fix build (linux)
fs-eire Sep 3, 2024
4d52602
fix stride
fs-eire Sep 3, 2024
3761aad
fix "{res_name}_bi2o_{name}"
fs-eire Sep 3, 2024
351da84
Add Expand operator (#21933)
qjia7 Sep 3, 2024
0b7ce77
support onnxruntime_test_all
fs-eire Sep 3, 2024
33726b1
reflect change in WebGpuProviderFactoryCreator::Create signature (#21…
guschmue Sep 3, 2024
50ea9eb
compare the content of WEBGPU_BUFFER, not the address (#21967)
guschmue Sep 3, 2024
d6f6148
fix tanh
fs-eire Sep 3, 2024
626edaf
support size==0 for element wise operators
fs-eire Sep 4, 2024
8913da1
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 4, 2024
bacc54c
use shared ComputeBroadcastOutputShape()
fs-eire Sep 4, 2024
7ecc5bb
add workgroup_idx
fs-eire Sep 4, 2024
ae836b1
expose name for shader variable
fs-eire Sep 4, 2024
243078b
add uniform for 1D variable
fs-eire Sep 5, 2024
4d48d28
fix GetElementAt with uniform
fs-eire Sep 5, 2024
dbe673b
document update folder
fs-eire Sep 5, 2024
38f182e
fix adapter/device creating: add toggles
fs-eire Sep 5, 2024
eb80f7c
more strict shape&stride usage check
fs-eire Sep 6, 2024
39d5509
fix vector realloc
fs-eire Sep 6, 2024
cd961c3
simplify cache hint interface.
fs-eire Sep 6, 2024
ddc2fbb
revise expand
fs-eire Sep 6, 2024
e8be835
revise unary
fs-eire Sep 6, 2024
bd7d592
Elu/Relu/LeakyRelu/ThresholdedRelu/Gelu
fs-eire Sep 6, 2024
eecac18
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 6, 2024
601e50f
remove unused field in class Gelu
fs-eire Sep 6, 2024
8f36da2
remove out-of-dated comments
fs-eire Sep 6, 2024
72ebd85
Clip
fs-eire Sep 7, 2024
a3244ae
fix rank in shader helper
fs-eire Sep 7, 2024
5a2ae8c
fix shader variable
fs-eire Sep 9, 2024
aa54ff8
move components number from variable to program
fs-eire Sep 9, 2024
969384d
mark components in cache key
fs-eire Sep 9, 2024
6b82486
Add FastGelu op (#21991)
qjia7 Sep 10, 2024
2b3e7c2
use 'set/add' as prefix for some functions
fs-eire Sep 10, 2024
ef0d53b
remove unnecessary cache hint for FastGelu
fs-eire Sep 10, 2024
c4ca47f
revise unary - expose consts in header
fs-eire Sep 10, 2024
8806d57
use path for header file
fs-eire Sep 10, 2024
0568e2b
a few revises to the code (#22047)
fs-eire Sep 10, 2024
b7a9c0e
use OrtMutex
fs-eire Sep 11, 2024
f65ade9
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 11, 2024
d4a963d
[webgpu-native] Add transpose op (#21986)
axinging Sep 11, 2024
8b61532
PushErrorScope and PopErrorScope
fs-eire Sep 11, 2024
dce0f18
placeholder for setting proc table
fs-eire Sep 12, 2024
8978d89
Revert "placeholder for setting proc table"
fs-eire Sep 12, 2024
43ccaf4
allow setting "ValidationMode"
fs-eire Sep 12, 2024
eae4c3f
make shape/stride correct when component != 1
fs-eire Sep 13, 2024
b8c369d
expose number of components
fs-eire Sep 13, 2024
c3086d6
Fix build errors
skottmckay Sep 13, 2024
c5cf2ab
[WebGPU EP] Support Shape operator (#22095)
satyajandhyala Sep 14, 2024
0bc714f
[webgpu EP] Binary operators (#22112)
fs-eire Sep 17, 2024
4421676
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 17, 2024
2e91a8b
use f32 for pow anyway
fs-eire Sep 17, 2024
87f9edb
Cast operator
fs-eire Sep 17, 2024
19ee9f3
do not use virtual function for getting ProgramMetadata
fs-eire Sep 17, 2024
d9f7f19
reshape, squeeze and unsqueeze
fs-eire Sep 18, 2024
07675cf
fix Cast and Clip
fs-eire Sep 18, 2024
dfab322
[webgpu-native] Add where op (#22014)
axinging Sep 20, 2024
cb9f3a4
fix linux build break
fs-eire Sep 20, 2024
207be92
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Sep 24, 2024
929725e
expose KernelContext
fs-eire Sep 25, 2024
c5e5af3
revise fast gelu
fs-eire Sep 25, 2024
82cd59e
expose Rank in IndicesHelper
fs-eire Sep 25, 2024
2393dbf
fix: move inline impl to .h
fs-eire Sep 25, 2024
9bdbd85
add const modifier
fs-eire Sep 25, 2024
0101ce8
remove toggle "disable_workgroup_init"
fs-eire Sep 25, 2024
3896706
set backend type to D3D12 since we always uses dxc (win).
fs-eire Sep 25, 2024
f02e85a
update build configurations to webgpu EP (#22047)
fs-eire Sep 25, 2024
e5233ce
enable build pipeline on Windows for WebGPU
fs-eire Sep 26, 2024
0f7a5f6
[webgpu native] Add RotaryEmbedding op (#22194)
axinging Sep 27, 2024
41f6ff3
[webgpu native] Add transpose shared (#22098)
axinging Sep 27, 2024
b1b5e1f
[webgpu-native] Add gather (#22183)
qjia7 Sep 27, 2024
92a08e2
[Native-WebGPU] Add Concat (#22225)
satyajandhyala Sep 27, 2024
8da1f7a
[webgpu-native] Add MatmulNBits (#22150)
qjia7 Sep 27, 2024
f9b6b7c
[WebGPU-Native] Tile Operator (#22239)
prathikr Sep 30, 2024
c1ae1fd
use Abseil OStringStream in WebGPU EP string concat (#22241)
fs-eire Sep 30, 2024
b574f2c
Range
fs-eire Sep 30, 2024
14ea5db
webgpu: support MultiHeadAttention operator (#22144)
xhcao Sep 30, 2024
c70441e
[webgpu-native] support for webgpu layernorms (#22249)
guschmue Oct 1, 2024
468c720
nodejs binding support webgpu
fs-eire Oct 1, 2024
cbf106e
fix where
fs-eire Oct 1, 2024
bce7a98
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 1, 2024
5086c7c
revert some changes that are not necessary
fs-eire Oct 1, 2024
fe7d3e4
revise perftest help msg
fs-eire Oct 1, 2024
d219bb7
[webgpu-native] Fix a few build errors on Linux (#22286)
snnn Oct 1, 2024
7f7d6da
format
fs-eire Oct 1, 2024
c561ed6
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 1, 2024
27640e3
fix issues for e2e phi3 (#22287)
guschmue Oct 1, 2024
4129cd6
fix perf problem: force Flush by end of session
fs-eire Oct 3, 2024
5fd65e7
Uniform buffer mode: LazyRelease -> Simple
fs-eire Oct 3, 2024
dcf2062
nodejs binding support IO binding for webgpu
fs-eire Oct 3, 2024
481111b
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 4, 2024
b84401d
fix matmul test after conflict resolve
fs-eire Oct 4, 2024
08434d2
a few build fixes
fs-eire Oct 4, 2024
1b01583
fix build break in android build
fs-eire Oct 4, 2024
130dc9b
fix duplicate "it"
fs-eire Oct 4, 2024
646a744
always disable DAWN_ENABLE_SPIRV_VALIDATION
fs-eire Oct 4, 2024
da6406b
Enable OBJC/OBJCXX for all projects if necessary
fs-eire Oct 5, 2024
53ff621
minimal webgpu io-binding support for python (#22334)
guschmue Oct 7, 2024
74b5131
reset dispatch count
fs-eire Oct 8, 2024
4f4efcb
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 9, 2024
0a8f872
remove unnecessary initialization options in test
fs-eire Oct 9, 2024
3f104fb
support ORT profiling in node.js
fs-eire Oct 9, 2024
8261ca6
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 11, 2024
e7d05ba
[webgpu-native] support webgpu profiling (#22255)
qjia7 Oct 11, 2024
613ad6d
check ValidationMode for push/pop error scope
fs-eire Oct 11, 2024
f4bb64e
[WebGPU EP] Remove unused variable. (#22412)
edgchen1 Oct 11, 2024
cf8c478
fix typo
fs-eire Oct 14, 2024
be95c28
add "WebGPU:" prefix for config entries
fs-eire Oct 14, 2024
e76fe1f
nodejs support EP config
fs-eire Oct 14, 2024
c7c1e82
support session options "optimizedModelFilePath" and "extra"
fs-eire Oct 15, 2024
cae9dab
allow to set a list for specifying force CPU nodes (#22431)
fs-eire Oct 16, 2024
2beeaf5
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 16, 2024
2c27a59
fix bug in Gather
fs-eire Oct 17, 2024
de6cf6b
Fix Unsqueeze
fs-eire Oct 17, 2024
59298ff
fix supports_device() in python interface
fs-eire Oct 17, 2024
f18a65f
fix linux build
fs-eire Oct 17, 2024
cc340e2
enable pybind
fs-eire Oct 17, 2024
8553aec
fix mac build
fs-eire Oct 17, 2024
289145c
add test_dynamicquantizelinear_expanded_cpu
fs-eire Oct 17, 2024
f2ee91c
remove temp tests
fs-eire Oct 17, 2024
3b462d6
add power preference
fs-eire Oct 18, 2024
0d49662
exclude dx11
fs-eire Oct 18, 2024
ddef640
WIN32 macro -> _WIN32
fs-eire Oct 18, 2024
d312b38
[webgpu-native] opt matmulnbits (#22472)
qjia7 Oct 21, 2024
89b4549
webgpu ep support dawn proc table (#22509)
fs-eire Oct 21, 2024
dc01790
[webgpu-native] Fix CI errors on MacOS (#22535)
qjia7 Oct 22, 2024
26ff482
Merge remote-tracking branch 'origin/main' into fs-eire/webgpu-ep
fs-eire Oct 22, 2024
872149b
replace OrtMutex
fs-eire Oct 22, 2024
a00f5b8
allow build flag --use_external_dawn (#22552)
fs-eire Oct 24, 2024
d4764a8
update dawn patch
fs-eire Oct 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ option(onnxruntime_TVM_USE_LLVM "Build TVM with LLVM. Set customized path to llv
option(onnxruntime_TVM_USE_HASH "Build ipp-crypto library for support hash algorithm. It is defined for TVM only")
option(onnxruntime_USE_XNNPACK "Build with XNNPACK support. Provides an alternative math library on ARM, WebAssembly and x86." OFF)
option(onnxruntime_USE_WEBNN "Build with WebNN support. Enable hardware acceleration in web browsers." OFF)
option(onnxruntime_USE_WEBGPU "Build with WebGPU support. Enable WebGPU via C/C++ interface." OFF)

# Options related to reducing the binary size produced by the build
# XNNPACK EP requires the internal NHWC contrib ops to be available, so this option must be OFF when onnxruntime_USE_XNNPACK is ON
Expand Down Expand Up @@ -907,6 +908,11 @@ if (onnxruntime_USE_WEBNN)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBNN=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES webnn)
endif()
if (onnxruntime_USE_WEBGPU)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_WEBGPU=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_WEBGPU=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES webgpu)
endif()
if (onnxruntime_USE_CANN)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CANN=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CANN=1)
Expand Down
1 change: 1 addition & 0 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d839
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/204da9c522cebec5220bba52cd3542ebcaf99e7a.zip;1827348efd47831c13074245274d41b7cae8a557
directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e
cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.5.2.zip;11071a47594b20f00af09aad83e0d5203ccf6029
dawn;https://github.com/google/dawn/archive/511eb80847afe6bded34ec491a38d5d78ba2d604.zip;c493f5aca5586f6634e25d0121c85df71189fb99
12 changes: 12 additions & 0 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,18 @@ if (onnxruntime_USE_COREML)
FetchContent_Populate(coremltools)
endif()

if (onnxruntime_USE_WEBGPU)
FetchContent_Declare(
dawn
URL ${DEP_URL_dawn}
URL_HASH SHA1=${DEP_SHA1_dawn}
)
set(DAWN_FETCH_DEPENDENCIES ON)
set(DAWN_ENABLE_INSTALL ON)
set(TINT_BUILD_TESTS OFF)
onnxruntime_fetchcontent_makeavailable(dawn)
endif()

message("Finished fetching external dependencies")

set(onnxruntime_LINK_DIRS )
Expand Down
7 changes: 7 additions & 0 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ endif()
if(onnxruntime_USE_WEBNN)
set(PROVIDERS_WEBNN onnxruntime_providers_webnn)
endif()
if(onnxruntime_USE_WEBGPU)
set(PROVIDERS_WEBGPU onnxruntime_providers_webgpu)
endif()
if (onnxruntime_USE_CANN)
set(PROVIDERS_CANN onnxruntime_providers_cann)
endif()
Expand Down Expand Up @@ -151,6 +154,10 @@ if (onnxruntime_USE_WEBNN)
include(onnxruntime_providers_webnn.cmake)
endif()

if (onnxruntime_USE_WEBGPU)
include(onnxruntime_providers_webgpu.cmake)
endif()

if (onnxruntime_USE_NNAPI_BUILTIN)
include(onnxruntime_providers_nnapi.cmake)
endif()
Expand Down
7 changes: 6 additions & 1 deletion cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ file(GLOB_RECURSE onnxruntime_js_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/js/*.cc"
)

file(GLOB_RECURSE onnxruntime_webgpu_contrib_ops_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/webgpu/*.cc"
)

file(GLOB onnxruntime_providers_common_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/*.cc"
Expand All @@ -60,7 +65,7 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
)
endif()
set(onnxruntime_cpu_neural_speed_srcs
set(onnxruntime_cpu_neural_speed_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
Expand Down
37 changes: 37 additions & 0 deletions cmake/onnxruntime_providers_webgpu.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

if (onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_EXTENDED_MINIMAL_BUILD)
message(FATAL_ERROR "WebGPU EP can not be used in a basic minimal build. Please build with '--minimal_build extended'")
endif()

# find_package(Dawn REQUIRED)

add_compile_definitions(USE_WEBGPU=1)
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
add_definitions(-DENABLE_WEBASSEMBLY_THREADS=1)
endif()
file(GLOB_RECURSE onnxruntime_providers_webgpu_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/webgpu/*.cc"
# "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.h"
# "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc"
)
if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_webgpu_contrib_ops_cc_srcs})
list(APPEND onnxruntime_providers_webgpu_cc_srcs ${onnxruntime_webgpu_contrib_ops_cc_srcs})
endif()

source_group(TREE ${REPO_ROOT} FILES ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)

# Copy webgpu_dawn.dll to the output directory
add_custom_command(
TARGET onnxruntime_providers_webgpu
POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different "$<TARGET_FILE:dawn::webgpu_dawn>" "$<TARGET_FILE_DIR:onnxruntime_providers_webgpu>"
VERBATIM )

set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_webnn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@

add_dependencies(onnxruntime_providers_webnn onnx ${onnxruntime_EXTERNAL_DEPENDENCIES})
set_target_properties(onnxruntime_providers_webnn PROPERTIES FOLDER "ONNXRuntime")
set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_webnn PROPERTIES LINKER_LANGUAGE CXX)
12 changes: 12 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,10 @@ if(onnxruntime_USE_JSEP)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_js)
endif()

if(onnxruntime_USE_WEBGPU)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu)
endif()

if(onnxruntime_USE_RKNPU)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_rknpu)
endif()
Expand Down Expand Up @@ -598,6 +602,7 @@ set(ONNXRUNTIME_TEST_LIBS
${PROVIDERS_NNAPI}
${PROVIDERS_VSINPU}
${PROVIDERS_JS}
${PROVIDERS_WEBGPU}
${PROVIDERS_QNN}
${PROVIDERS_SNPE}
${PROVIDERS_RKNPU}
Expand Down Expand Up @@ -658,6 +663,13 @@ if(onnxruntime_USE_JSEP)
list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_js)
endif()

if(onnxruntime_USE_WEBGPU)
list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/webgpu/*)
list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_webgpu)
list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_webgpu)
list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_webgpu)
endif()

# QNN EP tests require CPU EP op implementations for accuracy evaluation, so disable on minimal
# or reduced op builds.
if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD)
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ constexpr const char* kSnpeExecutionProvider = "SNPEExecutionProvider";
constexpr const char* kTvmExecutionProvider = "TvmExecutionProvider";
constexpr const char* kXnnpackExecutionProvider = "XnnpackExecutionProvider";
constexpr const char* kWebNNExecutionProvider = "WebNNExecutionProvider";
constexpr const char* kWebGpuExecutionProvider = "WebGpuExecutionProvider";
constexpr const char* kCannExecutionProvider = "CANNExecutionProvider";
constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kVSINPUExecutionProvider = "VSINPUExecutionProvider";
Expand Down
63 changes: 63 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,38 @@ typedef struct OrtMIGraphXProviderOptions {
bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false
} OrtMIGraphXProviderOptions;

/** \brief WebGPU Execution Provider Options
*
* When a user wants to use WebGPU as the execution provider, there are 2 ways to specify the WebGPU device:
*
* 1. Use the default WebGPU device. The default WebGPU device is managed by WebGPU EP internally. The user doesn't
* need to provide any device information in this case. All the fields should be set to nullptr or 0.
*
* 2. Use a custom WebGPU device. The user should create their own handles of `WGPUInstance`, `WGPUAdapter`, and
* `WGPUDevice` and use arbitrary number in [1..65536) as the device id. The user should provide the handles
* and the device id in the options.
*
* When specifying an existing Device ID, the user should provide the handles of `WGPUInstance`, `WGPUAdapter`, and
* `WGPUDevice` in the options. The device id should be the same as the one used previously.
*
* It's user's responsibility to manage the lifecycle of the handles and ensure the handles are valid during the
* lifetime of the inference session.
*
* About DawnProcTable:
*
* When using an ONNX Runtime build that is not directly linked dawn during the build, a pointer to the runtime memory
* address of the DawnProcTable should be provided. Otherwise, keep it as nullptr.
*
* \see OrtApi::SessionOptionsAppendExecutionProvider_WGPU
*/
typedef struct OrtWGPUProviderOptions {
int device_id; // WebGPU device id.
void* instance_handle; // WebGPU instance handle.
void* adapter_handle; // WebGPU adapter handle.
void* device_handle; // WebGPU device handle.
void* dawn_proc_table; // DawnProcTable pointer.
} OrtWGPUProviderOptions;

/** \brief OpenVINO Provider Options
*
* \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO
Expand Down Expand Up @@ -4667,6 +4699,37 @@ struct OrtApi {
_In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array,
_In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths,
size_t num_external_initializer_files);

/** \brief Append WebGPU execution provider to session options
*
* If WebGPU is not available, this function will return failure.
*
* \param[in] options
* \param[in] wgpu_options - specify the WebGPU provider options.
* \param[in] string_options_keys - keys to configure the string options
* \param[in] string_options_values - values to configure the string options
* \param[in] num_keys - number of keys passed in
*
* Supported keys are listed as below. All entries are optional.
*
* | Key | Possible Values | Default Value |
* | ------------------------------ | ---------------------------------------------- | -------------- |
* | "preferredLayout" | "NHWC" or "NCHW" | "NHWC" |
* | "enableGraphCapture" | "1" or "0" | "0" |
* | "storageBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "bucket" |
* | "uniformBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "lazyRelease" |
* | "queryResolveBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" |
* | "defaultBufferCacheMode" | "disabled", "lazyRelease", "simple", "bucket" | "disabled" |
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.20.
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_WGPU,
fs-eire marked this conversation as resolved.
Show resolved Hide resolved
_In_ OrtSessionOptions* options, _In_ const OrtWGPUProviderOptions* wgpu_options,
_In_reads_(num_keys) const char* const* string_options_keys,
_In_reads_(num_keys) const char* const* string_options_values,
_In_ size_t num_keys);
};

/*
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,9 @@
SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_WGPU
SessionOptionsImpl& AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options,
const std::unordered_map<std::string, std::string>& string_options = {});

Check warning on line 895 in include/onnxruntime/core/session/onnxruntime_cxx_api.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api.h:895: Lines should be <= 120 characters long [whitespace/line_length] [2]
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
const std::unordered_map<std::string, std::string>& provider_options = {});
Expand Down
19 changes: 19 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,25 @@
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_WGPU(const OrtWGPUProviderOptions& wgpu_options,

Check warning on line 842 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline.h:842: Lines should be <= 120 characters long [whitespace/line_length] [2]
const std::unordered_map<std::string, std::string>& string_options) {

Check warning on line 843 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline.h:843: Lines should be <= 120 characters long [whitespace/line_length] [2]
auto num_entries = string_options.size();
std::vector<const char*> keys, values;
if (num_entries > 0) {
keys.reserve(num_entries);
values.reserve(num_entries);

for (const auto& entry : string_options) {
keys.push_back(entry.first.c_str());
values.push_back(entry.second.c_str());
}
}

ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_WGPU(this->p_, &wgpu_options, keys.data(), values.data(), num_entries));

Check warning on line 856 in include/onnxruntime/core/session/onnxruntime_cxx_inline.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline.h:856: Lines should be <= 120 characters long [whitespace/line_length] [2]
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options) {
ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_CANN(this->p_, &provider_options));
Expand Down
70 changes: 70 additions & 0 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/webgpu/webgpu_contrib_kernels.h"

#include "core/framework/op_kernel.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention);
// LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, SkipSimplifiedLayerNormalization);

// template <>
// KernelCreateInfo BuildKernelCreateInfo<void>() {
// KernelCreateInfo info;
// return info;
// }

Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,

Check warning on line 38 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:38: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,

Check warning on line 39 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:39: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,

Check warning on line 40 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:40: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,

Check warning on line 41 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:41: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,

Check warning on line 42 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:42: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,

Check warning on line 43 in onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc:43: Lines should be <= 120 characters long [whitespace/line_length] [2]
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
};

for (auto& function_table_entry : function_table) {
KernelCreateInfo info = function_table_entry();
if (info.kernel_def != nullptr) { // filter disabled entries where type is void
ORT_RETURN_IF_ERROR(kernel_registry.Register(std::move(info)));
}
}
return Status::OK();
}

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"
#include "core/framework/kernel_registry.h"

namespace onnxruntime {
namespace contrib {
namespace webgpu {

Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry);

} // namespace webgpu
} // namespace contrib
} // namespace onnxruntime
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/get_execution_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,14 @@ constexpr ProviderInfo kProvidersInPriorityOrder[] =
true,
#else
false,
#endif
},
{
kWebGpuExecutionProvider,
#ifdef USE_WEBGPU
true,
#else
false,
#endif
},
{
Expand Down
Loading
Loading