diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 03ea773a25130..bc2d8117930bc 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -11,4 +11,4 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - uses: gradle/wrapper-validation-action@v1 + - uses: gradle/wrapper-validation-action@v2 diff --git a/.github/workflows/publish-csharp-apidocs.yml b/.github/workflows/publish-csharp-apidocs.yml index c03399f4693be..5bc21595bf882 100644 --- a/.github/workflows/publish-csharp-apidocs.yml +++ b/.github/workflows/publish-csharp-apidocs.yml @@ -37,7 +37,7 @@ jobs: wget https://github.com/dotnet/docfx/releases/download/v${DOCFXVERSION}/docfx-linux-x64-v${DOCFXVERSION}.zip -O build/docfx/docfx.zip unzip build/docfx/docfx.zip -d build/docfx - name: Install NuGet - uses: nuget/setup-nuget@v1 + uses: nuget/setup-nuget@v2 - name: Build Documentation run: | build/docfx/docfx metadata csharp/ApiDocs/docfx.json diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c94e3fa5bcb8c..181f3fb17d332 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -13,7 +13,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v9.0.0 + - uses: actions/stale@v8 with: # Comma separated list of labels that can be assigned to issues to exclude them from being marked as stale exempt-issue-labels: contributions welcome, feature request, regression diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 90fe8276ea9c7..c9be4aa65d0cc 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -117,8 +117,7 @@ option(onnxruntime_CROSS_COMPILING "Cross compiling onnx runtime" OFF) option(onnxruntime_GCOV_COVERAGE "Compile with options necessary to run code coverage" OFF) option(onnxruntime_DONT_VECTORIZE "Do not vectorize operations in Eigen" OFF) -#It's preferred to turn it OFF when onnxruntime is dynamically linked to PROTOBUF. But Tensort always required the full version of protobuf. -cmake_dependent_option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF "NOT onnxruntime_USE_TENSORRT" ON) +option(onnxruntime_USE_FULL_PROTOBUF "Link to libprotobuf instead of libprotobuf-lite when this option is ON" OFF) option(tensorflow_C_PACKAGE_PATH "Path to tensorflow C package installation dir") option(onnxruntime_ENABLE_LANGUAGE_INTEROP_OPS "Enable operator implemented in language other than cpp" OFF) option(onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS "Dump debug information about node inputs and outputs when executing the model." OFF) @@ -1601,7 +1600,7 @@ if (UNIX AND onnxruntime_USE_NCCL) else() set(onnxruntime_USE_NCCL OFF) set(onnxruntime_USE_MPI OFF) -message( WARNING "MPI and NCCL disabled on Win build." ) + message( WARNING "MPI and NCCL are disabled because build is on Windows or USE_NCCL is set to OFF." ) endif() if (onnxruntime_USE_MPI) diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index c6c9d8f4894c5..7e7819ac31a19 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -66,11 +66,7 @@ if(onnxruntime_USE_CUDA) set(PROVIDERS_CUDA onnxruntime_providers_cuda) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) - else() - set(PROVIDERS_COREML onnxruntime_providers_coreml) - endif() + set(PROVIDERS_COREML onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_NNAPI_BUILTIN) set(PROVIDERS_NNAPI onnxruntime_providers_nnapi) diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake index 2ca4a22aca7d2..c9f35e5337f9b 100644 --- a/cmake/onnxruntime_providers_coreml.cmake +++ b/cmake/onnxruntime_providers_coreml.cmake @@ -7,6 +7,27 @@ endif() add_compile_definitions(USE_COREML=1) +# Check if we can build the coremltools code for creating an mlpackage with an mlprogram. +# The coremltools source requires std::filesystem::path which is only available from iOS 13 on. +set(_enable_ML_PROGRAM ON) +if (IOS AND CMAKE_OSX_DEPLOYMENT_TARGET VERSION_LESS 13.0) + message(WARNING "CoreML ML Program is not supported on iOS < 13.0. Excluding ML Program support from build.") + set(_enable_ML_PROGRAM OFF) +elseif(LINUX) + # uuid-dev is required. we don't bother installing on CIs as it's really for manual developer testing. + find_library(LibUUID_LIBRARY NAMES uuid) + find_path(LibUUID_INCLUDE_DIR NAMES uuid/uuid.h) + if (NOT LibUUID_INCLUDE_DIR) + message(STATUS "uuid/uuid.h was not found as is required for ML Program support. " + "Run `sudo apt install uuid-dev` if you need to test ML Program related CoreML EP code. ") + set(_enable_ML_PROGRAM OFF) + endif() +endif() + +if (_enable_ML_PROGRAM) + add_compile_definitions(COREML_ENABLE_MLPROGRAM=1) +endif() + # Compile CoreML proto definition to ${CMAKE_CURRENT_BINARY_DIR}/coreml_proto set(COREML_PROTO_ROOT ${coremltools_SOURCE_DIR}/mlmodel/format) file(GLOB coreml_proto_srcs "${COREML_PROTO_ROOT}/*.proto") @@ -19,8 +40,8 @@ target_compile_definitions(coreml_proto PUBLIC $) set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility=hidden") set_target_properties(coreml_proto PROPERTIES COMPILE_FLAGS "-fvisibility-inlines-hidden") -set(_src_sub_dir "coreml_proto/") +set(_src_sub_dir "coreml_proto/") onnxruntime_protobuf_generate( APPEND_PATH GEN_SRC_SUB_DIR ${_src_sub_dir} @@ -55,6 +76,10 @@ file(GLOB_RECURSE onnxruntime_providers_shared_utils_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/shared/utils/utils.cc" ) +file(GLOB onnxruntime_providers_coreml_public_headers CONFIGURE_DEPENDS + "${ONNXRUNTIME_INCLUDE_DIR}/core/providers/coreml/*.h" +) + file(GLOB onnxruntime_providers_coreml_cc_srcs_top CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/coreml/*.h" @@ -67,15 +92,38 @@ file(GLOB_RECURSE "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.h" "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/*.cc" ) -if (NOT CMAKE_SYSTEM_NAME STREQUAL "Darwin" AND NOT CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(REMOVE_ITEM onnxruntime_providers_coreml_cc_srcs_nested - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.h" - "${ONNXRUNTIME_ROOT}/core/providers/coreml/builders/model_builder.cc" + +if(_enable_ML_PROGRAM) + # Add helpers to create mlpackage weights. limit to just the files we need to minimize the changes to make them + # build on Windows and Linux. + file(GLOB + onnxruntime_providers_coreml_milblob_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/*.cpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Util/*.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/BlobDataType.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageFormat.hpp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/FileWriter.?pp" + "${coremltools_SOURCE_DIR}/mlmodel/src/MILBlob/Blob/StorageWriter.?pp" + ) + + # Add helpers to create mlpackage + file(GLOB + onnxruntime_providers_coreml_modelpackage_cc_srcs CONFIGURE_DEPENDS + "${coremltools_SOURCE_DIR}/modelpackage/src/ModelPackage.?pp" + "${coremltools_SOURCE_DIR}/modelpackage/src/Utils/JsonMap.?pp" ) + + set(coremltools_srcs + ${onnxruntime_providers_coreml_milblob_cc_srcs} + ${onnxruntime_providers_coreml_modelpackage_cc_srcs} + ) + + source_group(TREE ${coremltools_SOURCE_DIR} PREFIX coremltools FILES ${coremltools_srcs}) endif() # Add CoreML objective c++ source code -if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") +if (APPLE) file(GLOB onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" @@ -83,26 +131,79 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm" ) +else() + # add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies + # by using stub implementations on non-Apple platforms. + file(GLOB + onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils_stub.cc" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h" + "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model_stub.cc" + ) endif() set(onnxruntime_providers_coreml_cc_srcs ${onnxruntime_providers_coreml_cc_srcs_top} ${onnxruntime_providers_coreml_cc_srcs_nested} ${onnxruntime_providers_shared_utils_cc_srcs} + ${onnxruntime_providers_coreml_objcc_srcs} ) -source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_providers_coreml_cc_srcs}) +source_group(TREE ${ONNXRUNTIME_INCLUDE_DIR} FILES ${onnxruntime_providers_coreml_public_headers}) + onnxruntime_add_static_library(onnxruntime_providers_coreml - ${onnxruntime_providers_coreml_cc_srcs} ${onnxruntime_providers_coreml_objcc_srcs} + ${onnxruntime_providers_coreml_public_headers} + ${onnxruntime_providers_coreml_cc_srcs} + ${coremltools_srcs} ) + onnxruntime_add_include_to_target(onnxruntime_providers_coreml - onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface + onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 + safeint_interface ) -if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) - target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto "-framework Foundation" "-framework CoreML") - add_dependencies(onnxruntime_providers_coreml coreml_proto) + +onnxruntime_add_include_to_target(onnxruntime_providers_coreml coreml_proto) +target_link_libraries(onnxruntime_providers_coreml PRIVATE coreml_proto) +add_dependencies(onnxruntime_providers_coreml coreml_proto) + +if (APPLE) + target_compile_definitions(onnxruntime_providers_coreml PRIVATE __APPLE__) endif() + +if (_enable_ML_PROGRAM) + # Setup coremltools fp16 and json dependencies for creating an mlpackage. + # + # These are also used by external/xnnpack.cmake. fp16 depends on psimd + FetchContent_Declare(psimd URL ${DEP_URL_psimd} URL_HASH SHA1=${DEP_SHA1_psimd}) + onnxruntime_fetchcontent_makeavailable(psimd) + set(PSIMD_SOURCE_DIR ${psimd_SOURCE_DIR}) + FetchContent_Declare(fp16 URL ${DEP_URL_fp16} URL_HASH SHA1=${DEP_SHA1_fp16}) + set(FP16_BUILD_TESTS OFF CACHE INTERNAL "") + set(FP16_BUILD_BENCHMARKS OFF CACHE INTERNAL "") + onnxruntime_fetchcontent_makeavailable(fp16) + + # need to tweak the include paths to match what the coreml source code expects + target_include_directories(onnxruntime_providers_coreml PRIVATE + ${fp16_SOURCE_DIR}/include + ${nlohmann_json_SOURCE_DIR}/single_include/nlohmann + ${coremltools_SOURCE_DIR} + ${coremltools_SOURCE_DIR}/mlmodel/src/ + ${coremltools_SOURCE_DIR}/modelpackage/src/ + ) + + add_dependencies(onnxruntime_providers_coreml nlohmann_json::nlohmann_json fp16) + + if (LINUX) + target_link_libraries(onnxruntime_providers_coreml PRIVATE uuid) + endif() +endif() + +if (APPLE) + target_link_libraries(onnxruntime_providers_coreml PRIVATE "-framework Foundation" "-framework CoreML") +endif() + add_dependencies(onnxruntime_providers_coreml ${onnxruntime_EXTERNAL_DEPENDENCIES}) set_target_properties(onnxruntime_providers_coreml PROPERTIES CXX_STANDARD_REQUIRED ON) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index d485abe6bb1a6..85a9bf50460d3 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -44,12 +44,7 @@ set(contrib_ops_excluded_files "bert/packed_multihead_attention.cc" "bert/packed_multihead_attention_impl.h" "bert/packed_multihead_attention_impl.cu" - "diffusion/group_norm.cc" "diffusion/group_norm_impl.cu" - "diffusion/group_norm_impl.h" - "diffusion/group_norm_impl_kernel.cuh" - "diffusion/group_norm_common_base.h" - "diffusion/group_norm_common_base.cc" "diffusion/nhwc_conv.cc" "math/gemm_float8.cc" "math/gemm_float8.cu" diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 308caad296831..3ed695327c183 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -567,11 +567,7 @@ if(onnxruntime_USE_ROCM) endif() if(onnxruntime_USE_COREML) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) - else() - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_ACL) @@ -676,15 +672,9 @@ endif() if(onnxruntime_USE_COREML) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/coreml/*) - if (CMAKE_SYSTEM_NAME STREQUAL "Darwin" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) - else() - list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml) - list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml) - endif() + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_coreml coreml_proto) + list(APPEND onnxruntime_test_providers_libs onnxruntime_providers_coreml coreml_proto) endif() if(onnxruntime_USE_XNNPACK) diff --git a/cmake/winml.cmake b/cmake/winml.cmake index 268ee3960e75a..57cecd3e66adb 100644 --- a/cmake/winml.cmake +++ b/cmake/winml.cmake @@ -827,6 +827,7 @@ if (winml_is_inbox) get_target_property(compile_options ${target} COMPILE_OPTIONS) get_target_property(include_directories ${target} INCLUDE_DIRECTORIES) get_target_property(link_libraries ${target} LINK_LIBRARIES) + get_target_property(link_flags ${target} LINK_FLAGS) get_target_property(link_options ${target} LINK_OPTIONS) add_library(${new_target} SHARED ${sources}) @@ -835,6 +836,7 @@ if (winml_is_inbox) target_compile_options(${new_target} PRIVATE ${compile_options}) target_include_directories(${new_target} PRIVATE ${include_directories}) target_link_libraries(${new_target} PRIVATE ${link_libraries}) + set_property(TARGET ${new_target} PROPERTY LINK_FLAGS "${link_flags}") target_link_options(${new_target} PRIVATE ${link_options}) endfunction() diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e7b537d6894c8..f523e97293427 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -461,7 +461,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -2252,7 +2252,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5154,7 +5154,7 @@ This version of the operator has been available since version 1 of the 'com.micr
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : I
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : I
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5743,12 +5743,14 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
beginning_timestamp_token_id : int
+
The id of the first timestamp
decoder : graph (required)
Decoder subgraph to execute in a loop.
decoder_output_cross_qk : int
If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.
decoder_start_token_id : int
-
The id of the token that indicates decoding starts.
+
The id of the token that indicates decoding starts (i.e. the start of transcription token id)
early_stopping : int
early stop or not
encoder : graph
@@ -5761,10 +5763,18 @@ This version of the operator has been available since version 1 of the 'com.micr
Must be 2 for whisper
no_repeat_ngram_size : int
no repeat ngrams size
-
no_speech_token : int
+
no_speech_token_id : int
The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.
+
no_timestamps_token_id : int
+
The id of the token that indicates no timestamps
pad_token_id : int (required)
The id of the padding token
+
start_of_lm_token_id : int
+
The id of the token that indicates LM starts
+
transcribe_token_id : int
+
The id of the transcribe task
+
translate_token_id : int
+
The id of the translate task
vocab_size : int
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
@@ -5783,11 +5793,11 @@ This version of the operator has been available since version 1 of the 'com.micr
num_return_sequences : I
The number of returned sequences in the batch. Shape is (1)
length_penalty (optional) : T
-
Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+
Exponential penalty to the length. Default value 1.0 means no penalty. Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. Shape is (1,)
repetition_penalty (optional) : T
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
vocab_mask (optional) : M
-
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)
prefix_vocab_mask (optional) : M
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
attention_mask (optional) : I
@@ -5797,7 +5807,7 @@ This version of the operator has been available since version 1 of the 'com.micr
logits_processor (optional) : I
Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)
cross_qk_layer_head (optional) : I
-
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
+
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
temperature (optional) : T
@@ -5812,11 +5822,11 @@ This version of the operator has been available since version 1 of the 'com.micr
sequences_scores (optional) : T
Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
scores (optional) : T
-
Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
Processed beam scores for each vocabulary token at each generation step. Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
cross_qk (optional) : V
-
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers,B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F].If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
+
Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]
non_speech_probs (optional) : T
-
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token.Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph.The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]
+
For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. The shape of non_speech_probs is [B]
#### Type Constraints diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 2ea557b7d61fe..8ff2135c6b1f6 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -765,7 +765,7 @@ Do not modify directly.* |Sigmoid|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Sign|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(double), tensor(float), tensor(float16)| +|SimplifiedLayerNormalization|*in* X:**T**
*in* scale:**V**
*out* Y:**V**
*out* inv_std_var:**U**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**U** = tensor(double), tensor(float)
**V** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(double), tensor(float), tensor(float16)| |Size|*in* data:**T**
*out* size:**T1**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| @@ -784,7 +784,7 @@ Do not modify directly.* |||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h index 03715eb5b78b2..55abb90b981f5 100644 --- a/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h +++ b/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h @@ -28,9 +28,12 @@ enum COREMLFlags { // dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES = 0x008, + // Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + COREML_FLAG_CREATE_MLPROGRAM = 0x010, + // Keep COREML_FLAG_LAST at the end of the enum definition // And assign the last COREMLFlag to it - COREML_FLAG_LAST = COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES, + COREML_FLAG_LAST = COREML_FLAG_CREATE_MLPROGRAM, }; #ifdef __cplusplus diff --git a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java index eb124decf75f3..cec3fadf446ca 100644 --- a/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java +++ b/java/src/main/java/ai/onnxruntime/providers/CoreMLFlags.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2023, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2024, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime.providers; @@ -14,7 +14,18 @@ public enum CoreMLFlags implements OrtFlags { /** Enables CoreML on subgraphs. */ ENABLE_ON_SUBGRAPH(2), // COREML_FLAG_ENABLE_ON_SUBGRAPH(0x002) /** Only enable usage of CoreML if the device has an Apple Neural Engine. */ - ONLY_ENABLE_DEVICE_WITH_ANE(4); // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004), + ONLY_ENABLE_DEVICE_WITH_ANE(4), // COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE(0x004) + /** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also + * allow inputs with dynamic shapes. However, the performance may be negatively impacted if inputs + * have dynamic shapes. + */ + ONLY_ALLOW_STATIC_INPUT_SHAPES(8), // COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES(0x008) + /** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or + * later. + */ + CREATE_MLPROGRAM(16); // COREML_FLAG_CREATE_MLPROGRAM(0x010) /** The native value of the enum. */ public final int value; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 7fef2dc784b7b..9925197e4507c 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -673,7 +673,7 @@ private void runProvider(OrtProvider provider) throws OrtException { // CoreML gives slightly different answers on a 2020 13" M1 MBP assertArrayEquals(expectedOutput, resultArray, 1e-2f); } else { - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); diff --git a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java index 1ed883ace36e5..0e3bc15ba9c70 100644 --- a/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java +++ b/java/src/test/java/ai/onnxruntime/providers/ProviderOptionsTest.java @@ -96,7 +96,7 @@ private static void runProvider(OrtProvider provider, OrtSession.SessionOptions OnnxValue resultTensor = result.get(0); float[] resultArray = TestHelpers.flattenFloat(resultTensor.getValue()); assertEquals(expectedOutput.length, resultArray.length); - assertArrayEquals(expectedOutput, resultArray, 1e-6f); + assertArrayEquals(expectedOutput, resultArray, 1e-5f); } catch (OrtException e) { throw new IllegalStateException("Failed to execute a scoring operation", e); } diff --git a/js/common/lib/tensor-impl-type-mapping.ts b/js/common/lib/tensor-impl-type-mapping.ts index c4a43ea27fea1..b29cb8cbd6d35 100644 --- a/js/common/lib/tensor-impl-type-mapping.ts +++ b/js/common/lib/tensor-impl-type-mapping.ts @@ -14,7 +14,6 @@ export const NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP = new Map { - if (!isBigIntChecked) { - isBigIntChecked = true; - const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && typeof BigInt64Array.from === 'function'; - const isBigUint64ArrayAvailable = - typeof BigUint64Array !== 'undefined' && typeof BigUint64Array.from === 'function'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + +// the following code allows delaying execution of BigInt/Float16Array checking. This allows lazy initialization for +// NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP and NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, which allows BigInt/Float16Array +// polyfill if available. +let isTypedArrayChecked = false; +export const checkTypedArray = () => { + if (!isTypedArrayChecked) { + isTypedArrayChecked = true; + const isBigInt64ArrayAvailable = typeof BigInt64Array !== 'undefined' && BigInt64Array.from; + const isBigUint64ArrayAvailable = typeof BigUint64Array !== 'undefined' && BigUint64Array.from; + const isFloat16ArrayAvailable = typeof Float16Array !== 'undefined' && Float16Array.from; if (isBigInt64ArrayAvailable) { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('int64', BigInt64Array); @@ -53,5 +58,12 @@ export const checkBigInt = () => { NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('uint64', BigUint64Array); NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(BigUint64Array, 'uint64'); } + if (isFloat16ArrayAvailable) { + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Float16Array); + NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP.set(Float16Array, 'float16'); + } else { + // if Float16Array is not available, use 'Uint16Array' to store the data. + NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP.set('float16', Uint16Array); + } } }; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e3e2b9c728556..56682ef98e117 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -5,7 +5,7 @@ import {tensorToDataURL, tensorToImageData} from './tensor-conversion-impl.js'; import {TensorToDataUrlOptions, TensorToImageDataOptions} from './tensor-conversion.js'; import {tensorFromGpuBuffer, tensorFromImage, tensorFromPinnedBuffer, tensorFromTexture} from './tensor-factory-impl.js'; import {CpuPinnedConstructorParameters, GpuBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters} from './tensor-factory.js'; -import {checkBigInt, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; +import {checkTypedArray, NUMERIC_TENSOR_TYPE_TO_TYPEDARRAY_MAP, NUMERIC_TENSOR_TYPEDARRAY_TO_TYPE_MAP, SupportedTypedArray, SupportedTypedArrayConstructors} from './tensor-impl-type-mapping.js'; import {calculateSize, tensorReshape} from './tensor-utils-impl.js'; import {Tensor as TensorInterface} from './tensor.js'; @@ -67,8 +67,8 @@ export class Tensor implements TensorInterface { arg0: TensorType|TensorDataType|readonly string[]|readonly boolean[]|CpuPinnedConstructorParameters| TextureConstructorParameters|GpuBufferConstructorParameters, arg1?: TensorDataType|readonly number[]|readonly string[]|readonly boolean[], arg2?: readonly number[]) { - // perform one-time check for BigInt support - checkBigInt(); + // perform one-time check for BigInt/Float16Array support + checkTypedArray(); let type: TensorType; let dims: readonly number[]; @@ -103,7 +103,7 @@ export class Tensor implements TensorInterface { } case 'gpu-buffer': { if ((type !== 'float32' && type !== 'float16' && type !== 'int32' && type !== 'int64' && type !== 'uint32' && - type !== 'bool')) { + type !== 'uint8' && type !== 'bool')) { throw new TypeError(`unsupported type "${type}" to create tensor from gpu buffer`); } this.gpuBufferData = arg0.gpuBuffer; @@ -142,7 +142,9 @@ export class Tensor implements TensorInterface { throw new TypeError(`Unsupported tensor type: ${arg0}.`); } if (Array.isArray(arg1)) { - if (arg0 === 'float16') { + if (arg0 === 'float16' && typedArrayConstructor === Uint16Array) { + // When no Float16Array polyfill is used, we cannot create 'float16' tensor from number array. + // // Throw error here because when user try to use number array as data, // e.g. new Tensor('float16', [1, 2, 3, 4], dims)), it will actually call // Uint16Array.from(arg1) which generates wrong data. diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 6c08d1fe8e057..d5da33640dc7d 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -135,7 +135,7 @@ export declare namespace Tensor { /** * supported data types for constructing a tensor from a WebGPU buffer */ - export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'bool'; + export type GpuBufferDataTypes = 'float32'|'float16'|'int32'|'int64'|'uint32'|'uint8'|'bool'; /** * represent where the tensor data is stored diff --git a/js/common/package-lock.json b/js/common/package-lock.json index a5ada877b916a..3988ac80707e0 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -9,13 +9,13 @@ "version": "1.18.0", "license": "MIT", "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" } }, "node_modules/ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "node_modules/balanced-match": { @@ -34,9 +34,9 @@ } }, "node_modules/jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "node_modules/lunr": { @@ -46,9 +46,9 @@ "dev": true }, "node_modules/marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true, "bin": { "marked": "bin/marked.js" @@ -58,24 +58,24 @@ } }, "node_modules/minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "dependencies": { "brace-expansion": "^2.0.1" }, "engines": { - "node": ">=10" + "node": ">=16 || 14 >=14.17" }, "funding": { "url": "https://github.com/sponsors/isaacs" } }, "node_modules/shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "dependencies": { "ansi-sequence-parser": "^1.1.0", @@ -85,30 +85,30 @@ } }, "node_modules/typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "dependencies": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" }, "bin": { "typedoc": "bin/typedoc" }, "engines": { - "node": ">= 14.14" + "node": ">= 16" }, "peerDependencies": { - "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x" + "typescript": "4.6.x || 4.7.x || 4.8.x || 4.9.x || 5.0.x || 5.1.x || 5.2.x || 5.3.x" } }, "node_modules/typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true, "bin": { @@ -116,7 +116,7 @@ "tsserver": "bin/tsserver" }, "engines": { - "node": ">=4.2.0" + "node": ">=14.17" } }, "node_modules/vscode-oniguruma": { @@ -134,9 +134,9 @@ }, "dependencies": { "ansi-sequence-parser": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.0.tgz", - "integrity": "sha512-lEm8mt52to2fT8GhciPCGeCXACSz2UwIN4X2e2LJSnZ5uAbn2/dsYdOmUXq0AtWS5cpAupysIneExOgH0Vd2TQ==", + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ansi-sequence-parser/-/ansi-sequence-parser-1.1.1.tgz", + "integrity": "sha512-vJXt3yiaUL4UU546s3rPXlsry/RnM730G1+HkpKE012AN0sx1eOrxSu95oKDIonskeLTijMgqWZ3uDEe3NFvyg==", "dev": true }, "balanced-match": { @@ -155,9 +155,9 @@ } }, "jsonc-parser": { - "version": "3.2.0", - "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.0.tgz", - "integrity": "sha512-gfFQZrcTc8CnKXp6Y4/CBT3fTc0OVuDofpre4aEeEpSBPV5X5v4+Vmx+8snU7RLPrNHPKSgLxGo9YuQzz20o+w==", + "version": "3.2.1", + "resolved": "https://registry.npmjs.org/jsonc-parser/-/jsonc-parser-3.2.1.tgz", + "integrity": "sha512-AilxAyFOAcK5wA1+LeaySVBrHsGQvUFCDWXKpZjzaL0PqW+xfBOttn8GNtWKFWqneyMZj41MWF9Kl6iPWLwgOA==", "dev": true }, "lunr": { @@ -167,24 +167,24 @@ "dev": true }, "marked": { - "version": "4.2.12", - "resolved": "https://registry.npmjs.org/marked/-/marked-4.2.12.tgz", - "integrity": "sha512-yr8hSKa3Fv4D3jdZmtMMPghgVt6TWbk86WQaWhDloQjRSQhMMYCAro7jP7VDJrjjdV8pxVxMssXS8B8Y5DZ5aw==", + "version": "4.3.0", + "resolved": "https://registry.npmjs.org/marked/-/marked-4.3.0.tgz", + "integrity": "sha512-PRsaiG84bK+AMvxziE/lCFss8juXjNaWzVbN5tXAm4XjeaS9NAHhop+PjQxz2A9h8Q4M/xGmzP8vqNwy6JeK0A==", "dev": true }, "minimatch": { - "version": "7.4.2", - "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-7.4.2.tgz", - "integrity": "sha512-xy4q7wou3vUoC9k1xGTXc+awNdGaGVHtFUaey8tiX4H1QRc04DZ/rmDFwNm2EBsuYEhAZ6SgMmYf3InGY6OauA==", + "version": "9.0.3", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.3.tgz", + "integrity": "sha512-RHiac9mvaRw0x3AYRgDC1CxAP7HTcNrrECeA8YYJeWnpo+2Q5CegtZjaotWTWxDG3UeGA1coE05iH1mPjT/2mg==", "dev": true, "requires": { "brace-expansion": "^2.0.1" } }, "shiki": { - "version": "0.14.1", - "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.1.tgz", - "integrity": "sha512-+Jz4nBkCBe0mEDqo1eKRcCdjRtrCjozmcbTUjbPTX7OOJfEbTZzlUWlZtGe3Gb5oV1/jnojhG//YZc3rs9zSEw==", + "version": "0.14.7", + "resolved": "https://registry.npmjs.org/shiki/-/shiki-0.14.7.tgz", + "integrity": "sha512-dNPAPrxSc87ua2sKJ3H5dQ/6ZaY8RNnaAqK+t0eG7p0Soi2ydiqbGOTaZCqaYvA/uZYfS1LJnemt3Q+mSfcPCg==", "dev": true, "requires": { "ansi-sequence-parser": "^1.1.0", @@ -194,21 +194,21 @@ } }, "typedoc": { - "version": "0.23.26", - "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.23.26.tgz", - "integrity": "sha512-5m4KwR5tOLnk0OtMaRn9IdbeRM32uPemN9kur7YK9wFqx8U0CYrvO9aVq6ysdZSV1c824BTm+BuQl2Ze/k1HtA==", + "version": "0.25.7", + "resolved": "https://registry.npmjs.org/typedoc/-/typedoc-0.25.7.tgz", + "integrity": "sha512-m6A6JjQRg39p2ZVRIN3NKXgrN8vzlHhOS+r9ymUYtcUP/TIQPvWSq7YgE5ZjASfv5Vd5BW5xrir6Gm2XNNcOow==", "dev": true, "requires": { "lunr": "^2.3.9", - "marked": "^4.2.12", - "minimatch": "^7.1.3", - "shiki": "^0.14.1" + "marked": "^4.3.0", + "minimatch": "^9.0.3", + "shiki": "^0.14.7" } }, "typescript": { - "version": "4.9.5", - "resolved": "https://registry.npmjs.org/typescript/-/typescript-4.9.5.tgz", - "integrity": "sha512-1FXk9E2Hm+QzZQ7z+McJiHL4NW1F2EzMu9Nq9i3zAaGqibafqYwCVU6WyWAuyQRRzOlxou8xZSyXLEN8oKj24g==", + "version": "5.2.2", + "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.2.2.tgz", + "integrity": "sha512-mI4WrpHsbCIcwT9cF4FZvr80QUeKvsUsUvKDoR+X/7XHQH98xYD8YHZg7ANtz2GtZt/CBq2QJ0thkGJMHfqc1w==", "dev": true, "peer": true }, diff --git a/js/common/package.json b/js/common/package.json index 64ab2736adbe3..cd2612aab4984 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -9,7 +9,7 @@ }, "author": "fs-eire", "scripts": { - "build:cjs": "tsc --module commonjs --outDir ./dist/cjs", + "build:cjs": "tsc --module commonjs --moduleResolution node10 --outDir ./dist/cjs", "build:esm": "tsc", "build:bundles": "webpack", "build": "node ./build.js", @@ -18,7 +18,7 @@ "test": "mocha ./test/**/*.js --timeout 30000" }, "devDependencies": { - "typedoc": "^0.23.22" + "typedoc": "^0.25.7" }, "main": "dist/cjs/index.js", "exports": { diff --git a/js/common/test/tsconfig.json b/js/common/test/tsconfig.json index 2e4927ac3b325..e9068ad837a81 100644 --- a/js/common/test/tsconfig.json +++ b/js/common/test/tsconfig.json @@ -2,7 +2,7 @@ "extends": "../../tsconfig.tools.json", "exclude": ["type-tests/**/*.ts"], "compilerOptions": { - "module": "ES2022", + "module": "Node16", "sourceMap": true } } diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index b21af8e715db3..4a8c92bb97bfd 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | LessOrEqual | ai.onnx(12-15,16+) | | | Log | ai.onnx(6-12,13+) | | | MatMul | ai.onnx(1-12,13+) | | +| MatMulNBits | com.microsoft(1+) | | | MaxPool | ai.onnx(1-7,8-9,10,11,12+); com.ms.internal.nhwc(1-7,8-9,10,11,12+) | need perf optimization; need implementing activation | | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 6922d7ff5df6e..c0517ce363644 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -92,6 +92,34 @@ export class ShapeUtil { return ShapeUtil.getSizeFromDimensionRange(dims, 0, dims.length); } + /** + * convert dims corresponding to type change to pack. ex. uint8 data to uint32 + */ + static convertShape(dims: readonly number[], size = 4): readonly number[] { + const rank = dims.length; + if (rank === 0) { + return []; + } + const newDims = new Array(rank); + let i = rank - 1; + while (i >= 0) { + if (dims[i] % size === 0) { + newDims[i] = dims[i] / size; + break; + } + if (size % dims[i] !== 0) { + throw new Error('cannot convert shape'); + } + newDims[i] = 1; + size /= dims[i]; + i--; + } + for (i--; i >= 0; i--) { + newDims[i] = dims[i]; + } + return newDims; + } + /** * calculate the size (number of elements) from the given axis (inclusive) */ diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index ac08c5fb1f7ab..ba874c8dd0f80 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -20,6 +20,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm'; import {instanceNorm} from './ops/instance-norm'; import {layerNorm} from './ops/layer-norm'; import {matMul} from './ops/matmul'; +import {matMulNBits, parseMatMulNBitsAttributes} from './ops/matmulnbits'; import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion'; import {pad} from './ops/pad'; import * as pool from './ops/pool'; @@ -92,6 +93,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['LessOrEqual', [binaryOps.lessOrEqual]], ['Log', [unaryOps.log]], ['MatMul', [matMul]], + ['MatMulNBits', [matMulNBits, parseMatMulNBitsAttributes]], // TODO: support new attributes for MaxPool-8 and MaxPool-10 ['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]], ['Mul', [binaryOps.mul]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 3f73d9cb7c5bc..d5f97213e49ce 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -85,28 +85,28 @@ const createLayerNormProgramInfo = ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.norm_count')} let offset = global_idx * uniforms.norm_size_vectorized; - var meanVector = ${fillVector('f32', components)}; - var meanSquareVector = ${fillVector('f32', components)}; + var mean_vector = ${fillVector('f32', components)}; + var mean_square_vector = ${fillVector('f32', components)}; for (var h: u32 = 0u; h < uniforms.norm_size_vectorized; h++) { let value = ${castToF32(dataType, components, 'x[h + offset]')}; - meanVector += value; - meanSquareVector += value * value; + mean_vector += value; + mean_square_vector += value * value; } - let mean = ${sumVector('meanVector', components)} / uniforms.norm_size; - let invStdDev = - inverseSqrt(${sumVector('meanSquareVector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); + let mean = ${sumVector('mean_vector', components)} / uniforms.norm_size; + let inv_std_dev = inverseSqrt(${ + sumVector('mean_square_vector', components)} / uniforms.norm_size - mean * mean + uniforms.epsilon); for (var j: u32 = 0; j < uniforms.norm_size_vectorized; j++) { let f32input = ${castToF32(dataType, components, 'x[j + offset]')}; let f32scale = ${castToF32(dataType, components, 'scale[j]')}; - output[j + offset] = ${variables[0].type.value}((f32input - mean) * invStdDev * f32scale + output[j + offset] = ${variables[0].type.value}((f32input - mean) * inv_std_dev * f32scale ${bias ? `+ ${castToF32(dataType, components, 'bias[j]')}` : ''} ); } ${hasMeanDataOutput ? 'mean_data_output[global_idx] = mean' : ''}; - ${hasInvStdOutput ? 'inv_std_output[global_idx] = invStdDev' : ''}; + ${hasInvStdOutput ? 'inv_std_output[global_idx] = inv_std_dev' : ''}; }`; }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts new file mode 100644 index 0000000000000..ead7635cf3ac4 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmulnbits.ts @@ -0,0 +1,184 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; + +import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; + +// TODO support quantization bits not equal to 4 +export interface MatMulNBitsAttributes extends AttributeWithCacheKey { + k: number; + n: number; + accuracyLevel: number; + bits: number; + blockSize: number; +} + +const validateInputs = (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): void => { + if (inputs.length < 3 || inputs.length > 4) { + throw new Error('MatMulNBits requires 3 or 4 inputs'); + } + const a = inputs[0]; + const aRank = a.dims.length; + if (a.dims[aRank - 1] !== attributes.k) { + throw new Error('The last dim of input shape does not match the k value'); + } + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const b = inputs[1]; + if (!ShapeUtil.areEqual(b.dims, [attributes.n, nBlocksPerCol, blobSize])) { + throw new Error('The second inputs must be 3D tensor with shape N X nBlocksPerCol X blobSize'); + } + const scales = inputs[2]; + const scalesShape = scales.dims; + if (ShapeUtil.size(scalesShape) !== attributes.n * nBlocksPerCol) { + throw new Error('scales input size error.'); + } + if (inputs.length === 4) { + const zeroPoints = inputs[3]; + const zeroPointsShape = zeroPoints.dims; + const expectedZeroPointsSize = + attributes.bits > 4 ? (attributes.n * nBlocksPerCol) : attributes.n * Math.floor((nBlocksPerCol + 1) / 2); + if (ShapeUtil.size(zeroPointsShape) !== expectedZeroPointsSize) { + throw new Error('zeroPoints input size error.'); + } + } +}; + +export const createMatMulNBitsProgramInfo = + (inputs: readonly TensorView[], attributes: MatMulNBitsAttributes): ProgramInfo => { + const a = inputs[0]; + const b = inputs[1]; + const scales = inputs[2]; + const aRank = a.dims.length; + const outputShape = a.dims.slice(0, aRank - 1).concat(attributes.n); + const outputSize = ShapeUtil.size(outputShape); + + + const programUniforms: ProgramUniform[] = [ + {type: DataType.uint32, data: outputSize}, {type: DataType.uint32, data: attributes.k}, + {type: DataType.uint32, data: attributes.n}, {type: DataType.uint32, data: attributes.accuracyLevel}, + {type: DataType.uint32, data: attributes.bits}, {type: DataType.uint32, data: attributes.blockSize} + ]; + programUniforms.push(...createTensorShapeVariables(a.dims)); + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(b.dims))); + programUniforms.push(...createTensorShapeVariables(scales.dims)); + if (inputs.length === 4) { + programUniforms.push(...createTensorShapeVariables(ShapeUtil.convertShape(inputs[3].dims))); + } + programUniforms.push(...createTensorShapeVariables(outputShape)); + const getShaderSource = (shaderHelper: ShaderHelper) => { + const a = inputVariable('a', inputs[0].dataType, inputs[0].dims.length); + const b = inputVariable('b', DataType.uint32, inputs[1].dims.length); + const scales = inputVariable('scales', inputs[2].dataType, inputs[2].dims.length); + const inputVariables = [a, b, scales]; + const zeroPoints = + inputs.length === 4 ? inputVariable('zero_points', DataType.uint32, inputs[3].dims.length) : undefined; + if (zeroPoints) { + inputVariables.push(zeroPoints); + } + const output = outputVariable('output', inputs[0].dataType, outputShape.length); + const uniforms: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, {name: 'k', type: 'u32'}, {name: 'n', type: 'u32'}, + {name: 'accuracy_level', type: 'u32'}, {name: 'bits', type: 'u32'}, {name: 'block_size', type: 'u32'} + ]; + const nBlocksPerCol = Math.floor((attributes.k + attributes.blockSize - 1) / attributes.blockSize); + const blobSize = attributes.blockSize / 8 * attributes.bits; + const wordPerBlob = blobSize / 4; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + fn ortUnpack8x4snorm(value: u32) -> array<${dataType}, 8>{ + var result = array<${dataType}, 8>(); + var offset: u32 = 0; + let count: u32 = 4; + for (var i: u32 = 0; i < 8u; i++) { + result[i] = ${dataType}(extractBits(value, offset, count)); + offset += count; + } + return result; + } + ${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')} + var value: ${dataType} = 0.0; + let output_indices = ${output.offsetToIndices('global_idx')}; + var a_indices: ${a.type.indices} = output_indices; + var n = ${output.indicesGet('output_indices', aRank - 1)}; + // Two zero points are packed into one byte because uniforms.bits <= 4. + // zero_point_offset is either 0 or 4. It is bit offset within one byte. + // TODO support zero_point_offset for bits > 4 + ${ + zeroPoints ? ` + var zero_point_index: u32 = n * ((${nBlocksPerCol} + 1) / 2) / 4; + var zero_point_word: u32 = ${zeroPoints.getByOffset('zero_point_index')}; + var zero_point_offset: u32 = 0;` : + ''} + var scale_idex = n * ${nBlocksPerCol}; + var b_indices: ${b.type.indices}; + ${b.indicesSet('b_indices', '0', 'n')}; + var block_offset: u32 = 0; + for (var block: u32 = 0; block < ${nBlocksPerCol}; block++) { + // The scale and zero points are computed per block. + let scale = ${scales.getByOffset('scale_idex')}; + // The default zero point is 8 for unsigned 4-bit quantization. + let zero_point: ${dataType} = ${ + zeroPoints ? `${dataType}(extractBits(zero_point_word, zero_point_offset, 4))` : 8.0}; + ${b.indicesSet('b_indices', '1', 'block')}; + var word_offset: u32 = block_offset; + for (var word: u32 = 0; word < ${wordPerBlob}; word++) { + ${b.indicesSet('b_indices', '2', 'word')}; + let b_value = ${b.getByIndices('b_indices')}; + let b_quantized_values: array<${dataType}, 8> = ortUnpack8x4snorm(b_value); + // Number of B elements per 32-bit word is 32/bits = 32/4 = 8 + var offset: u32 = word_offset; + for (var i: u32 = 0; i < 8; i++) { + ${a.indicesSet('a_indices', aRank - 1, 'offset')}; + let a_value = ${a.getByIndices('a_indices')}; + let b_quantized_value = b_quantized_values[i]; + let b_dequantized_value = (b_quantized_value - zero_point) * scale; + value += a_value * b_dequantized_value; + offset++; + } + word_offset += 8; + } + scale_idex++; + ${ + zeroPoints ? ` + if (zero_point_offset == 28) { + zero_point_offset = 0; + zero_point_index++; + zero_point_word = ${zeroPoints.getByOffset('zero_point_index')}; + } else { + zero_point_offset += 4; + }` : + ''} + block_offset += uniforms.block_size; + } + ${output.setByOffset('global_idx', 'value')}; + } + `; + }; + return { + name: 'MatMulNBits', + shaderCache: + {hint: `${attributes.cacheKey};${inputs.length}`, inputDependencies: Array(inputs.length).fill('rank')}, + getRunData: () => ({ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64)}, + programUniforms + }), + getShaderSource + }; + }; + +export const matMulNBits = (context: ComputeContext, attributes: MatMulNBitsAttributes): void => { + validateInputs(context.inputs, attributes); + context.compute(createMatMulNBitsProgramInfo(context.inputs, attributes)); +}; + +export const parseMatMulNBitsAttributes = (attributes: Record): MatMulNBitsAttributes => + createAttributeWithCacheKey(attributes as Omit); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 14d6f37927590..a09ac78b17006 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -68,7 +68,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const dataType = inputs[0].dataType; const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length); const outputs = new Array(attributes.numOutputs); - const input = inputVariable('input', dataType, inputShape); + const input = inputVariable('input', dataType, inputShape.length); const sizeInSplitAxis = new Array(attributes.numOutputs); const outputsTensorInfo: TensorInfo[] = []; const outputShapes: number[][] = []; @@ -80,7 +80,7 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split const outputShape = inputShape.slice(); outputShape[attributes.axis] = attributes.splitSizes[i]; outputShapes.push(outputShape); - outputs[i] = outputVariable(`output${i}`, dataType, outputShape); + outputs[i] = outputVariable(`output${i}`, dataType, outputShape.length); outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType}); } programUniforms.push( diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index b9eff45e890c4..54eaf5e0c43cc 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -3,6 +3,12 @@ import {Tensor} from 'onnxruntime-common'; +// a dummy type declaration for Float16Array in case any polyfill is available. +declare global { + // eslint-disable-next-line @typescript-eslint/naming-convention, @typescript-eslint/no-explicit-any + const Float16Array: any; +} + // This file includes common definitions. They do NOT have dependency on the WebAssembly instance. /** @@ -117,7 +123,8 @@ export const tensorTypeToTypedArrayConstructor = (type: Tensor.Type): Float32Arr Uint8ArrayConstructor|Float64ArrayConstructor|Uint32ArrayConstructor|BigUint64ArrayConstructor => { switch (type) { case 'float16': - return Uint16Array; + // allow Float16Array polyfill. + return typeof Float16Array !== 'undefined' && Float16Array.from ? Float16Array : Uint16Array; case 'float32': return Float32Array; case 'uint8': @@ -169,7 +176,8 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro * Check whether the given tensor type is supported by GPU buffer */ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || - type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + type === 'float16' || type === 'int32' || type === 'int64' || type === 'uint32' || type === 'uint8' || + type === 'bool'; /** * Map string data location to integer value diff --git a/js/web/test/data/ops/matmulnbits.jsonc b/js/web/test/data/ops/matmulnbits.jsonc new file mode 100644 index 0000000000000..c57c431afb3ce --- /dev/null +++ b/js/web/test/data/ops/matmulnbits.jsonc @@ -0,0 +1,1527 @@ +[ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, -385, -1120, -963, -1984, -1285, -2592, -1351, -2944, -1161, -3040, -715, -2880, -13, -2464, 945, 0, + -1073, -3808, -2643, -6848, -3445, -9120, -3479, -10624, -2745, -11360, -1243, -11328, 1027, -10528, 4065, + 0, -1761, -6496, -4323, -11712, -5605, -15648, -5607, -18304, -4329, -19680, -1771, -19776, 2067, -18592, + 7185, 0, -2449, -9184, -6003, -16576, -7765, -22176, -7735, -25984, -5913, -28000, -2299, -28224, 3107, + -26656, 10305, 0, -3137, -11872, -7683, -21440, -9925, -28704, -9863, -33664, -7497, -36320, -2827, + -36672, 4147, -34720, 13425, 0, -3825, -14560, -9363, -26304, -12085, -35232, -11991, -41344, -9081, + -44640, -3355, -45120, 5187, -42784, 16545, 0, -4513, -17248, -11043, -31168, -14245, -41760, -14119, + -49024, -10665, -52960, -3883, -53568, 6227, -50848, 19665, 0, -5201, -19936, -12723, -36032, -16405, + -48288, -16247, -56704, -12249, -61280, -4411, -62016, 7267, -58912, 22785, 0, -5889, -22624, -14403, + -40896, -18565, -54816, -18375, -64384, -13833, -69600, -4939, -70464, 8307, -66976, 25905, 0, -6577, + -25312, -16083, -45760, -20725, -61344, -20503, -72064, -15417, -77920, -5467, -78912, 9347, -75040, + 29025, 0, -7265, -28000, -17763, -50624, -22885, -67872, -22631, -79744, -17001, -86240, -5995, -87360, + 10387, -83104, 32145, 0, -7953, -30688, -19443, -55488, -25045, -74400, -24759, -87424, -18585, -94560, + -6523, -95808, 11427, -91168, 35265, 0, -8641, -33376, -21123, -60352, -27205, -80928, -26887, -95104, + -20169, -102880, -7051, -104256, 12467, -99232, 38385, 0, -9329, -36064, -22803, -65216, -29365, -87456, + -29015, -102784, -21753, -111200, -7579, -112704, 13507, -107296, 41505, 0, -10017, -38752, -24483, + -70080, -31525, -93984, -31143, -110464, -23337, -119520, -8107, -121152, 14547, -115360, 44625, 0, + -10705, -41440, -26163, -74944, -33685, -100512, -33271, -118144, -24921, -127840, -8635, -129600, 15587, + -123424, 47745 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, + 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, + 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, + 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, + 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, + 253, 254, 255 + ], + "dims": [16, 16], + "type": "float32" + }, + { + "dims": [16, 1, 8], + "type": "uint8", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, + 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, + 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, + 127 + ] + }, + { + "dims": [16], + "type": "float32", + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + 0, 728, 688, 2376, 1632, 4280, 2832, 6440, 4288, 8856, 6000, 11528, 7968, 14456, 10192, 17640, 0, 2200, + 1840, 7176, 4448, 12920, 7824, 19432, 11968, 26712, 16880, 34760, 22560, 43576, 29008, 53160, 0, 3672, + 2992, 11976, 7264, 21560, 12816, 32424, 19648, 44568, 27760, 57992, 37152, 72696, 47824, 88680, 0, 5144, + 4144, 16776, 10080, 30200, 17808, 45416, 27328, 62424, 38640, 81224, 51744, 101816, 66640, 124200, 0, + 6616, 5296, 21576, 12896, 38840, 22800, 58408, 35008, 80280, 49520, 104456, 66336, 130936, 85456, 159720, + 0, 8088, 6448, 26376, 15712, 47480, 27792, 71400, 42688, 98136, 60400, 127688, 80928, 160056, 104272, + 195240, 0, 9560, 7600, 31176, 18528, 56120, 32784, 84392, 50368, 115992, 71280, 150920, 95520, 189176, + 123088, 230760, 0, 11032, 8752, 35976, 21344, 64760, 37776, 97384, 58048, 133848, 82160, 174152, 110112, + 218296, 141904, 266280, 0, 12504, 9904, 40776, 24160, 73400, 42768, 110376, 65728, 151704, 93040, 197384, + 124704, 247416, 160720, 301800, 0, 13976, 11056, 45576, 26976, 82040, 47760, 123368, 73408, 169560, + 103920, 220616, 139296, 276536, 179536, 337320, 0, 15448, 12208, 50376, 29792, 90680, 52752, 136360, + 81088, 187416, 114800, 243848, 153888, 305656, 198352, 372840, 0, 16920, 13360, 55176, 32608, 99320, + 57744, 149352, 88768, 205272, 125680, 267080, 168480, 334776, 217168, 408360, 0, 18392, 14512, 59976, + 35424, 107960, 62736, 162344, 96448, 223128, 136560, 290312, 183072, 363896, 235984, 443880, 0, 19864, + 15664, 64776, 38240, 116600, 67728, 175336, 104128, 240984, 147440, 313544, 197664, 393016, 254800, + 479400, 0, 21336, 16816, 69576, 41056, 125240, 72720, 188328, 111808, 258840, 158320, 336776, 212256, + 422136, 273616, 514920, 0, 22808, 17968, 74376, 43872, 133880, 77712, 201320, 119488, 276696, 169200, + 360008, 226848, 451256, 292432, 550440 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -428, -1288, -1068, -2288, -1420, -3000, -1484, -3424, -1260, -3560, -748, -3408, 52, -2968, 1140, + -2272, 2516, -1224, 4180, 80, 6132, 1672, 8372, 3552, 10900, 5720, 13716, 8176, 16820, 10920, 12276, 0, + -1116, -3976, -2748, -7152, -3580, -9528, -3612, -11104, -2844, -11880, -1276, -11856, 1092, -11032, 4260, + -8160, 8228, -6984, 12996, -3760, 18564, 264, 24932, 5088, 32100, 10712, 40068, 17136, 48836, 24360, + 42532, 0, -1804, -6664, -4428, -12016, -5740, -16056, -5740, -18784, -4428, -20200, -1804, -20304, 2132, + -19096, 7380, -14048, 13940, -12744, 21812, -7600, 30996, -1144, 41492, 6624, 53300, 15704, 66420, 26096, + 80852, 37800, 72788, 0, -2492, -9352, -6108, -16880, -7900, -22584, -7868, -26464, -6012, -28520, -2332, + -28752, 3172, -27160, 10500, -19936, 19652, -18504, 30628, -11440, 43428, -2552, 58052, 8160, 74500, + 20696, 92772, 35056, 112868, 51240, 103044, 0, -3180, -12040, -7788, -21744, -10060, -29112, -9996, + -34144, -7596, -36840, -2860, -37200, 4212, -35224, 13620, -25824, 25364, -24264, 39444, -15280, 55860, + -3960, 74612, 9696, 95700, 25688, 119124, 44016, 144884, 64680, 133300, 0, -3868, -14728, -9468, -26608, + -12220, -35640, -12124, -41824, -9180, -45160, -3388, -45648, 5252, -43288, 16740, -31712, 31076, -30024, + 48260, -19120, 68292, -5368, 91172, 11232, 116900, 30680, 145476, 52976, 176900, 78120, 163556, 0, -4556, + -17416, -11148, -31472, -14380, -42168, -14252, -49504, -10764, -53480, -3916, -54096, 6292, -51352, + 19860, -37600, 36788, -35784, 57076, -22960, 80724, -6776, 107732, 12768, 138100, 35672, 171828, 61936, + 208916, 91560, 193812, 0, -5244, -20104, -12828, -36336, -16540, -48696, -16380, -57184, -12348, -61800, + -4444, -62544, 7332, -59416, 22980, -43488, 42500, -41544, 65892, -26800, 93156, -8184, 124292, 14304, + 159300, 40664, 198180, 70896, 240932, 105000, 224068, 0, -5932, -22792, -14508, -41200, -18700, -55224, + -18508, -64864, -13932, -70120, -4972, -70992, 8372, -67480, 26100, -49376, 48212, -47304, 74708, -30640, + 105588, -9592, 140852, 15840, 180500, 45656, 224532, 79856, 272948, 118440, 254324, 0, -6620, -25480, + -16188, -46064, -20860, -61752, -20636, -72544, -15516, -78440, -5500, -79440, 9412, -75544, 29220, + -55264, 53924, -53064, 83524, -34480, 118020, -11000, 157412, 17376, 201700, 50648, 250884, 88816, 304964, + 131880, 284580, 0, -7308, -28168, -17868, -50928, -23020, -68280, -22764, -80224, -17100, -86760, -6028, + -87888, 10452, -83608, 32340, -61152, 59636, -58824, 92340, -38320, 130452, -12408, 173972, 18912, 222900, + 55640, 277236, 97776, 336980, 145320, 314836, 0, -7996, -30856, -19548, -55792, -25180, -74808, -24892, + -87904, -18684, -95080, -6556, -96336, 11492, -91672, 35460, -67040, 65348, -64584, 101156, -42160, + 142884, -13816, 190532, 20448, 244100, 60632, 303588, 106736, 368996, 158760, 345092, 0, -8684, -33544, + -21228, -60656, -27340, -81336, -27020, -95584, -20268, -103400, -7084, -104784, 12532, -99736, 38580, + -72928, 71060, -70344, 109972, -46000, 155316, -15224, 207092, 21984, 265300, 65624, 329940, 115696, + 401012, 172200, 375348, 0, -9372, -36232, -22908, -65520, -29500, -87864, -29148, -103264, -21852, + -111720, -7612, -113232, 13572, -107800, 41700, -78816, 76772, -76104, 118788, -49840, 167748, -16632, + 223652, 23520, 286500, 70616, 356292, 124656, 433028, 185640, 405604, 0, -10060, -38920, -24588, -70384, + -31660, -94392, -31276, -110944, -23436, -120040, -8140, -121680, 14612, -115864, 44820, -84704, 82484, + -81864, 127604, -53680, 180180, -18040, 240212, 25056, 307700, 75608, 382644, 133616, 465044, 199080, + 435860, 0, -10748, -41608, -26268, -75248, -33820, -100920, -33404, -118624, -25020, -128360, -8668, + -130128, 15652, -123928, 47940, -90592, 88196, -87624, 136420, -57520, 192612, -19448, 256772, 26592, + 328900, 80600, 408996, 142576, 497060, 212520, 466116, 0, -11436, -44296, -27948, -80112, -35980, -107448, + -35532, -126304, -26604, -136680, -9196, -138576, 16692, -131992, 51060, -96480, 93908, -93384, 145236, + -61360, 205044, -20856, 273332, 28128, 350100, 85592, 435348, 151536, 529076, 225960, 496372, 0, -12124, + -46984, -29628, -84976, -38140, -113976, -37660, -133984, -28188, -145000, -9724, -147024, 17732, -140056, + 54180, -102368, 99620, -99144, 154052, -65200, 217476, -22264, 289892, 29664, 371300, 90584, 461700, + 160496, 561092, 239400, 526628, 0, -12812, -49672, -31308, -89840, -40300, -120504, -39788, -141664, + -29772, -153320, -10252, -155472, 18772, -148120, 57300, -108256, 105332, -104904, 162868, -69040, 229908, + -23672, 306452, 31200, 392500, 95576, 488052, 169456, 593108, 252840, 556884, 0, -13500, -52360, -32988, + -94704, -42460, -127032, -41916, -149344, -31356, -161640, -10780, -163920, 19812, -156184, 60420, + -114144, 111044, -110664, 171684, -72880, 242340, -25080, 323012, 32736, 413700, 100568, 514404, 178416, + 625124, 266280, 587140, 0, -14188, -55048, -34668, -99568, -44620, -133560, -44044, -157024, -32940, + -169960, -11308, -172368, 20852, -164248, 63540, -120032, 116756, -116424, 180500, -76720, 254772, -26488, + 339572, 34272, 434900, 105560, 540756, 187376, 657140, 279720, 617396, 0, -14876, -57736, -36348, -104432, + -46780, -140088, -46172, -164704, -34524, -178280, -11836, -180816, 21892, -172312, 66660, -125920, + 122468, -122184, 189316, -80560, 267204, -27896, 356132, 35808, 456100, 110552, 567108, 196336, 689156, + 293160, 647652, 0, -15564, -60424, -38028, -109296, -48940, -146616, -48300, -172384, -36108, -186600, + -12364, -189264, 22932, -180376, 69780, -131808, 128180, -127944, 198132, -84400, 279636, -29304, 372692, + 37344, 477300, 115544, 593460, 205296, 721172, 306600, 677908, 0, -16252, -63112, -39708, -114160, -51100, + -153144, -50428, -180064, -37692, -194920, -12892, -197712, 23972, -188440, 72900, -137696, 133892, + -133704, 206948, -88240, 292068, -30712, 389252, 38880, 498500, 120536, 619812, 214256, 753188, 320040, + 708164, 0, -16940, -65800, -41388, -119024, -53260, -159672, -52556, -187744, -39276, -203240, -13420, + -206160, 25012, -196504, 76020, -143584, 139604, -139464, 215764, -92080, 304500, -32120, 405812, 40416, + 519700, 125528, 646164, 223216, 785204, 333480, 738420, 0, -17628, -68488, -43068, -123888, -55420, + -166200, -54684, -195424, -40860, -211560, -13948, -214608, 26052, -204568, 79140, -149472, 145316, + -145224, 224580, -95920, 316932, -33528, 422372, 41952, 540900, 130520, 672516, 232176, 817220, 346920, + 768676, 0, -18316, -71176, -44748, -128752, -57580, -172728, -56812, -203104, -42444, -219880, -14476, + -223056, 27092, -212632, 82260, -155360, 151028, -150984, 233396, -99760, 329364, -34936, 438932, 43488, + 562100, 135512, 698868, 241136, 849236, 360360, 798932, 0, -19004, -73864, -46428, -133616, -59740, + -179256, -58940, -210784, -44028, -228200, -15004, -231504, 28132, -220696, 85380, -161248, 156740, + -156744, 242212, -103600, 341796, -36344, 455492, 45024, 583300, 140504, 725220, 250096, 881252, 373800, + 829188, 0, -19692, -76552, -48108, -138480, -61900, -185784, -61068, -218464, -45612, -236520, -15532, + -239952, 29172, -228760, 88500, -167136, 162452, -162504, 251028, -107440, 354228, -37752, 472052, 46560, + 604500, 145496, 751572, 259056, 913268, 387240, 859444, 0, -20380, -79240, -49788, -143344, -64060, + -192312, -63196, -226144, -47196, -244840, -16060, -248400, 30212, -236824, 91620, -173024, 168164, + -168264, 259844, -111280, 366660, -39160, 488612, 48096, 625700, 150488, 777924, 268016, 945284, 400680, + 889700, 0, -21068, -81928, -51468, -148208, -66220, -198840, -65324, -233824, -48780, -253160, -16588, + -256848, 31252, -244888, 94740, -178912, 173876, -174024, 268660, -115120, 379092, -40568, 505172, 49632, + 646900, 155480, 804276, 276976, 977300, 414120, 919956, 0, -21756, -84616, -53148, -153072, -68380, + -205368, -67452, -241504, -50364, -261480, -17116, -265296, 32292, -252952, 97860, -184800, 179588, + -179784, 277476, -118960, 391524, -41976, 521732, 51168, 668100, 160472, 830628, 285936, 1009316, 427560, + 950212 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 16, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=16, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [32, 16], + "type": "float32" + }, + { + "dims": [32, 1, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 660, 888, 2196, 2064, 4020, 3528, 6132, 5280, 8532, 7320, 11220, 9648, 14196, 12264, 17460, 15136, + 21012, 18360, 24852, 21840, 28980, 25608, 33396, 29664, 38100, 34008, 43092, 38640, 48372, 43560, 46004, + 0, 2020, 2296, 6660, 5392, 12100, 9288, 18340, 13984, 25380, 19480, 33220, 25776, 41860, 32872, 51300, + 42016, 61540, 49464, 72580, 58960, 84420, 69256, 97060, 80352, 110500, 92248, 124740, 104944, 139780, + 118440, 139748, 0, 3380, 3704, 11124, 8720, 20180, 15048, 30548, 22688, 42228, 31640, 55220, 41904, 69524, + 53480, 85140, 68896, 102068, 80568, 120308, 96080, 139860, 112904, 160724, 131040, 182900, 150488, 206388, + 171248, 231188, 193320, 233492, 0, 4740, 5112, 15588, 12048, 28260, 20808, 42756, 31392, 59076, 43800, + 77220, 58032, 97188, 74088, 118980, 95776, 142596, 111672, 168036, 133200, 195300, 156552, 224388, 181728, + 255300, 208728, 288036, 237552, 322596, 268200, 327236, 0, 6100, 6520, 20052, 15376, 36340, 26568, 54964, + 40096, 75924, 55960, 99220, 74160, 124852, 94696, 152820, 122656, 183124, 142776, 215764, 170320, 250740, + 200200, 288052, 232416, 327700, 266968, 369684, 303856, 414004, 343080, 420980, 0, 7460, 7928, 24516, + 18704, 44420, 32328, 67172, 48800, 92772, 68120, 121220, 90288, 152516, 115304, 186660, 149536, 223652, + 173880, 263492, 207440, 306180, 243848, 351716, 283104, 400100, 325208, 451332, 370160, 505412, 417960, + 514724, 0, 8820, 9336, 28980, 22032, 52500, 38088, 79380, 57504, 109620, 80280, 143220, 106416, 180180, + 135912, 220500, 176416, 264180, 204984, 311220, 244560, 361620, 287496, 415380, 333792, 472500, 383448, + 532980, 436464, 596820, 492840, 608468, 0, 10180, 10744, 33444, 25360, 60580, 43848, 91588, 66208, 126468, + 92440, 165220, 122544, 207844, 156520, 254340, 203296, 304708, 236088, 358948, 281680, 417060, 331144, + 479044, 384480, 544900, 441688, 614628, 502768, 688228, 567720, 702212, 0, 11540, 12152, 37908, 28688, + 68660, 49608, 103796, 74912, 143316, 104600, 187220, 138672, 235508, 177128, 288180, 230176, 345236, + 267192, 406676, 318800, 472500, 374792, 542708, 435168, 617300, 499928, 696276, 569072, 779636, 642600, + 795956, 0, 12900, 13560, 42372, 32016, 76740, 55368, 116004, 83616, 160164, 116760, 209220, 154800, + 263172, 197736, 322020, 257056, 385764, 298296, 454404, 355920, 527940, 418440, 606372, 485856, 689700, + 558168, 777924, 635376, 871044, 717480, 889700, 0, 14260, 14968, 46836, 35344, 84820, 61128, 128212, + 92320, 177012, 128920, 231220, 170928, 290836, 218344, 355860, 283936, 426292, 329400, 502132, 393040, + 583380, 462088, 670036, 536544, 762100, 616408, 859572, 701680, 962452, 792360, 983444, 0, 15620, 16376, + 51300, 38672, 92900, 66888, 140420, 101024, 193860, 141080, 253220, 187056, 318500, 238952, 389700, + 310816, 466820, 360504, 549860, 430160, 638820, 505736, 733700, 587232, 834500, 674648, 941220, 767984, + 1053860, 867240, 1077188, 0, 16980, 17784, 55764, 42000, 100980, 72648, 152628, 109728, 210708, 153240, + 275220, 203184, 346164, 259560, 423540, 337696, 507348, 391608, 597588, 467280, 694260, 549384, 797364, + 637920, 906900, 732888, 1022868, 834288, 1145268, 942120, 1170932, 0, 18340, 19192, 60228, 45328, 109060, + 78408, 164836, 118432, 227556, 165400, 297220, 219312, 373828, 280168, 457380, 364576, 547876, 422712, + 645316, 504400, 749700, 593032, 861028, 688608, 979300, 791128, 1104516, 900592, 1236676, 1017000, + 1264676, 0, 19700, 20600, 64692, 48656, 117140, 84168, 177044, 127136, 244404, 177560, 319220, 235440, + 401492, 300776, 491220, 391456, 588404, 453816, 693044, 541520, 805140, 636680, 924692, 739296, 1051700, + 849368, 1186164, 966896, 1328084, 1091880, 1358420, 0, 21060, 22008, 69156, 51984, 125220, 89928, 189252, + 135840, 261252, 189720, 341220, 251568, 429156, 321384, 525060, 418336, 628932, 484920, 740772, 578640, + 860580, 680328, 988356, 789984, 1124100, 907608, 1267812, 1033200, 1419492, 1166760, 1452164, 0, 22420, + 23416, 73620, 55312, 133300, 95688, 201460, 144544, 278100, 201880, 363220, 267696, 456820, 341992, + 558900, 445216, 669460, 516024, 788500, 615760, 916020, 723976, 1052020, 840672, 1196500, 965848, 1349460, + 1099504, 1510900, 1241640, 1545908, 0, 23780, 24824, 78084, 58640, 141380, 101448, 213668, 153248, 294948, + 214040, 385220, 283824, 484484, 362600, 592740, 472096, 709988, 547128, 836228, 652880, 971460, 767624, + 1115684, 891360, 1268900, 1024088, 1431108, 1165808, 1602308, 1316520, 1639652, 0, 25140, 26232, 82548, + 61968, 149460, 107208, 225876, 161952, 311796, 226200, 407220, 299952, 512148, 383208, 626580, 498976, + 750516, 578232, 883956, 690000, 1026900, 811272, 1179348, 942048, 1341300, 1082328, 1512756, 1232112, + 1693716, 1391400, 1733396, 0, 26500, 27640, 87012, 65296, 157540, 112968, 238084, 170656, 328644, 238360, + 429220, 316080, 539812, 403816, 660420, 525856, 791044, 609336, 931684, 727120, 1082340, 854920, 1243012, + 992736, 1413700, 1140568, 1594404, 1298416, 1785124, 1466280, 1827140, 0, 27860, 29048, 91476, 68624, + 165620, 118728, 250292, 179360, 345492, 250520, 451220, 332208, 567476, 424424, 694260, 552736, 831572, + 640440, 979412, 764240, 1137780, 898568, 1306676, 1043424, 1486100, 1198808, 1676052, 1364720, 1876532, + 1541160, 1920884, 0, 29220, 30456, 95940, 71952, 173700, 124488, 262500, 188064, 362340, 262680, 473220, + 348336, 595140, 445032, 728100, 579616, 872100, 671544, 1027140, 801360, 1193220, 942216, 1370340, + 1094112, 1558500, 1257048, 1757700, 1431024, 1967940, 1616040, 2014628, 0, 30580, 31864, 100404, 75280, + 181780, 130248, 274708, 196768, 379188, 274840, 495220, 364464, 622804, 465640, 761940, 606496, 912628, + 702648, 1074868, 838480, 1248660, 985864, 1434004, 1144800, 1630900, 1315288, 1839348, 1497328, 2059348, + 1690920, 2108372, 0, 31940, 33272, 104868, 78608, 189860, 136008, 286916, 205472, 396036, 287000, 517220, + 380592, 650468, 486248, 795780, 633376, 953156, 733752, 1122596, 875600, 1304100, 1029512, 1497668, + 1195488, 1703300, 1373528, 1920996, 1563632, 2150756, 1765800, 2202116, 0, 33300, 34680, 109332, 81936, + 197940, 141768, 299124, 214176, 412884, 299160, 539220, 396720, 678132, 506856, 829620, 660256, 993684, + 764856, 1170324, 912720, 1359540, 1073160, 1561332, 1246176, 1775700, 1431768, 2002644, 1629936, 2242164, + 1840680, 2295860, 0, 34660, 36088, 113796, 85264, 206020, 147528, 311332, 222880, 429732, 311320, 561220, + 412848, 705796, 527464, 863460, 687136, 1034212, 795960, 1218052, 949840, 1414980, 1116808, 1624996, + 1296864, 1848100, 1490008, 2084292, 1696240, 2333572, 1915560, 2389604, 0, 36020, 37496, 118260, 88592, + 214100, 153288, 323540, 231584, 446580, 323480, 583220, 428976, 733460, 548072, 897300, 714016, 1074740, + 827064, 1265780, 986960, 1470420, 1160456, 1688660, 1347552, 1920500, 1548248, 2165940, 1762544, 2424980, + 1990440, 2483348, 0, 37380, 38904, 122724, 91920, 222180, 159048, 335748, 240288, 463428, 335640, 605220, + 445104, 761124, 568680, 931140, 740896, 1115268, 858168, 1313508, 1024080, 1525860, 1204104, 1752324, + 1398240, 1992900, 1606488, 2247588, 1828848, 2516388, 2065320, 2577092, 0, 38740, 40312, 127188, 95248, + 230260, 164808, 347956, 248992, 480276, 347800, 627220, 461232, 788788, 589288, 964980, 767776, 1155796, + 889272, 1361236, 1061200, 1581300, 1247752, 1815988, 1448928, 2065300, 1664728, 2329236, 1895152, 2607796, + 2140200, 2670836, 0, 40100, 41720, 131652, 98576, 238340, 170568, 360164, 257696, 497124, 359960, 649220, + 477360, 816452, 609896, 998820, 794656, 1196324, 920376, 1408964, 1098320, 1636740, 1291400, 1879652, + 1499616, 2137700, 1722968, 2410884, 1961456, 2699204, 2215080, 2764580, 0, 41460, 43128, 136116, 101904, + 246420, 176328, 372372, 266400, 513972, 372120, 671220, 493488, 844116, 630504, 1032660, 821536, 1236852, + 951480, 1456692, 1135440, 1692180, 1335048, 1943316, 1550304, 2210100, 1781208, 2492532, 2027760, 2790612, + 2289960, 2858324, 0, 42820, 44536, 140580, 105232, 254500, 182088, 384580, 275104, 530820, 384280, 693220, + 509616, 871780, 651112, 1066500, 848416, 1277380, 982584, 1504420, 1172560, 1747620, 1378696, 2006980, + 1600992, 2282500, 1839448, 2574180, 2094064, 2882020, 2364840, 2952068 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, 56908, 81124, + 108476, 138964, 140844, -3868, -21508, -33964, -41236, -43324, -40228, -31948, -18484, 5252, 23996, 53012, + 87212, 126596, 171164, 220916, 228236, -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, + 4900, 30108, 70196, 117516, 172068, 233852, 302868, 315628, -6620, -38980, -62060, -75860, -80380, -75620, + -61580, -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -7996, -47716, -76108, -93172, + -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, 466772, 490412, -9372, + -56452, -90156, -110484, -117436, -111012, -91212, -58036, 3844, 48444, 121748, 208428, 308484, 421916, + 548724, 577804, -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, + 238732, 353956, 484604, 630676, 665196, -12124, -73924, -118252, -145108, -154492, -146404, -120844, + -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -13500, -82660, -132300, -162420, + -173020, -164100, -135660, -87700, 2788, 66780, 173300, 299340, 444900, 609980, 794580, 839980, -14876, + -91396, -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, + 876532, 927372, -16252, -100132, -160396, -197044, -210076, -199492, -165292, -107476, 2084, 79004, + 207668, 359948, 535844, 735356, 958484, 1014764, -17628, -108868, -174444, -214356, -228604, -217188, + -180108, -117364, 1732, 85116, 224852, 390252, 581316, 798044, 1040436, 1102156, -19004, -117604, -188492, + -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, 860732, 1122388, + 1189548, -20380, -126340, -202540, -248980, -265660, -252580, -209740, -137140, 1028, 97340, 259220, + 450860, 672260, 923420, 1204340, 1276940, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 16, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=16, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ], + "dims": [16, 32], + "type": "float32" + }, + { + "dims": [16, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [16], + "type": "uint8", + "data": [128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128] + } + ], + "outputs": [ + { + "dims": [16, 16], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, 170956, 205540, 243260, + 284116, 296364, -3868, -2948, 3156, 14444, 30916, 52572, 79412, 111436, 153732, 191036, 238612, 291372, + 349316, 412444, 480756, 506636, -5244, -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, + 337716, 411788, 493092, 581628, 677396, 716908, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, + 284100, 350716, 436820, 532204, 636868, 750812, 874036, 927180, -7996, -4580, 10164, 36236, 73636, 122364, + 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, -9372, -5124, 12500, + 43500, 87876, 145628, 216756, 301260, 414468, 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, + -10748, -5668, 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, + 1258364, 1463956, 1557996, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, 544836, 670076, + 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -13500, -6756, 19508, 65292, 130596, 215420, 319764, + 443628, 610020, 749916, 932340, 1134284, 1355748, 1596732, 1857236, 1978540, -14876, -7300, 21844, 72556, + 144836, 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, + -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, 1130548, 1375116, 1643300, + 1935100, 2250516, 2399084, -17628, -8388, 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, + 1229652, 1495532, 1787076, 2104284, 2447156, 2609356, -19004, -8932, 28852, 94348, 187556, 308476, 457108, + 633452, 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -20380, -9476, 31188, + 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, 2442652, 2840436, + 3029900, -21756, -10020, 33524, 108876, 216036, 355004, 525780, 728364, 1001124, 1228956, 1526964, + 1856780, 2218404, 2611836, 3037076, 3240172 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -4036, -5868, -6612, -6268, -4836, -2316, 1292, 5956, 11772, 18644, 26604, 35652, 45788, 57012, + 53452, -59740, -53956, -47084, -39124, -30076, -19940, -8716, 3596, 16996, 31484, 47060, 63724, 81476, + 100316, 120244, 109004, -2492, -12772, -19916, -23924, -24796, -22532, -17132, -8596, 5604, 17884, 35828, + 56908, 81124, 108476, 138964, 140844, -199356, -184548, -166604, -145524, -121308, -93956, -63468, -29844, + 6916, 46812, 89844, 136012, 185316, 237756, 293332, 287532, -3868, -21508, -33964, -41236, -43324, -40228, + -31948, -18484, 5252, 23996, 53012, 87212, 126596, 171164, 220916, 228236, -338972, -315140, -286124, + -251924, -212540, -167972, -118220, -63284, -3164, 62140, 132628, 208300, 289156, 375196, 466420, 466060, + -5244, -30244, -48012, -58548, -61852, -57924, -46764, -28372, 4900, 30108, 70196, 117516, 172068, 233852, + 302868, 315628, -478588, -445732, -405644, -358324, -303772, -241988, -172972, -96724, -13244, 77468, + 175412, 280588, 392996, 512636, 639508, 644588, -6620, -38980, -62060, -75860, -80380, -75620, -61580, + -38260, 4548, 36220, 87380, 147820, 217540, 296540, 384820, 403020, -618204, -576324, -525164, -464724, + -395004, -316004, -227724, -130164, -23324, 92796, 218196, 352876, 496836, 650076, 812596, 823116, -7996, + -47716, -76108, -93172, -98908, -93316, -76396, -48148, 4196, 42332, 104564, 178124, 263012, 359228, + 466772, 490412, -757820, -706916, -644684, -571124, -486236, -390020, -282476, -163604, -33404, 108124, + 260980, 425164, 600676, 787516, 985684, 1001644, -9372, -56452, -90156, -110484, -117436, -111012, -91212, + -58036, 3844, 48444, 121748, 208428, 308484, 421916, 548724, 577804, -897436, -837508, -764204, -677524, + -577468, -464036, -337228, -197044, -43484, 123452, 303764, 497452, 704516, 924956, 1158772, 1180172, + -10748, -65188, -104204, -127796, -135964, -128708, -106028, -67924, 3492, 54556, 138932, 238732, 353956, + 484604, 630676, 665196, -1037052, -968100, -883724, -783924, -668700, -538052, -391980, -230484, -53564, + 138780, 346548, 569740, 808356, 1062396, 1331860, 1358700, -12124, -73924, -118252, -145108, -154492, + -146404, -120844, -77812, 3140, 60668, 156116, 269036, 399428, 547292, 712628, 752588, -1176668, -1098692, + -1003244, -890324, -759932, -612068, -446732, -263924, -63644, 154108, 389332, 642028, 912196, 1199836, + 1504948, 1537228, -13500, -82660, -132300, -162420, -173020, -164100, -135660, -87700, 2788, 66780, + 173300, 299340, 444900, 609980, 794580, 839980, -1316284, -1229284, -1122764, -996724, -851164, -686084, + -501484, -297364, -73724, 169436, 432116, 714316, 1016036, 1337276, 1678036, 1715756, -14876, -91396, + -146348, -179732, -191548, -181796, -150476, -97588, 2436, 72892, 190484, 329644, 490372, 672668, 876532, + 927372, -1455900, -1359876, -1242284, -1103124, -942396, -760100, -556236, -330804, -83804, 184764, + 474900, 786604, 1119876, 1474716, 1851124, 1894284, -16252, -100132, -160396, -197044, -210076, -199492, + -165292, -107476, 2084, 79004, 207668, 359948, 535844, 735356, 958484, 1014764, -1595516, -1490468, + -1361804, -1209524, -1033628, -834116, -610988, -364244, -93884, 200092, 517684, 858892, 1223716, 1612156, + 2024212, 2072812, -17628, -108868, -174444, -214356, -228604, -217188, -180108, -117364, 1732, 85116, + 224852, 390252, 581316, 798044, 1040436, 1102156, -1735132, -1621060, -1481324, -1315924, -1124860, + -908132, -665740, -397684, -103964, 215420, 560468, 931180, 1327556, 1749596, 2197300, 2251340, -19004, + -117604, -188492, -231668, -247132, -234884, -194924, -127252, 1380, 91228, 242036, 420556, 626788, + 860732, 1122388, 1189548, -1874748, -1751652, -1600844, -1422324, -1216092, -982148, -720492, -431124, + -114044, 230748, 603252, 1003468, 1431396, 1887036, 2370388, 2429868, -20380, -126340, -202540, -248980, + -265660, -252580, -209740, -137140, 1028, 97340, 259220, 450860, 672260, 923420, 1204340, 1276940, + -2014364, -1882244, -1720364, -1528724, -1307324, -1056164, -775244, -464564, -124124, 246076, 646036, + 1075756, 1535236, 2024476, 2543476, 2608396, -21756, -135076, -216588, -266292, -284188, -270276, -224556, + -147028, 676, 103452, 276404, 481164, 717732, 986108, 1286292, 1364332, -2153980, -2012836, -1839884, + -1635124, -1398556, -1130180, -829996, -498004, -134204, 261404, 688820, 1148044, 1639076, 2161916, + 2716564, 2786924, -23132, -143812, -230636, -283604, -302716, -287972, -239372, -156916, 324, 109564, + 293588, 511468, 763204, 1048796, 1368244, 1451724, -2293596, -2143428, -1959404, -1741524, -1489788, + -1204196, -884748, -531444, -144284, 276732, 731604, 1220332, 1742916, 2299356, 2889652, 2965452, -24508, + -152548, -244684, -300916, -321244, -305668, -254188, -166804, -28, 115676, 310772, 541772, 808676, + 1111484, 1450196, 1539116, -2433212, -2274020, -2078924, -1847924, -1581020, -1278212, -939500, -564884, + -154364, 292060, 774388, 1292620, 1846756, 2436796, 3062740, 3143980, -25884, -161284, -258732, -318228, + -339772, -323364, -269004, -176692, -380, 121788, 327956, 572076, 854148, 1174172, 1532148, 1626508, + -2572828, -2404612, -2198444, -1954324, -1672252, -1352228, -994252, -598324, -164444, 307388, 817172, + 1364908, 1950596, 2574236, 3235828, 3322508, -27260, -170020, -272780, -335540, -358300, -341060, -283820, + -186580, -732, 127900, 345140, 602380, 899620, 1236860, 1614100, 1713900, -2712444, -2535204, -2317964, + -2060724, -1763484, -1426244, -1049004, -631764, -174524, 322716, 859956, 1437196, 2054436, 2711676, + 3408916, 3501036, -28636, -178756, -286828, -352852, -376828, -358756, -298636, -196468, -1084, 134012, + 362324, 632684, 945092, 1299548, 1696052, 1801292, -2852060, -2665796, -2437484, -2167124, -1854716, + -1500260, -1103756, -665204, -184604, 338044, 902740, 1509484, 2158276, 2849116, 3582004, 3679564, -30012, + -187492, -300876, -370164, -395356, -376452, -313452, -206356, -1436, 140124, 379508, 662988, 990564, + 1362236, 1778004, 1888684, -2991676, -2796388, -2557004, -2273524, -1945948, -1574276, -1158508, -698644, + -194684, 353372, 945524, 1581772, 2262116, 2986556, 3755092, 3858092, -31388, -196228, -314924, -387476, + -413884, -394148, -328268, -216244, -1788, 146236, 396692, 693292, 1036036, 1424924, 1859956, 1976076, + -3131292, -2926980, -2676524, -2379924, -2037180, -1648292, -1213260, -732084, -204764, 368700, 988308, + 1654060, 2365956, 3123996, 3928180, 4036620, -32764, -204964, -328972, -404788, -432412, -411844, -343084, + -226132, -2140, 152348, 413876, 723596, 1081508, 1487612, 1941908, 2063468, -3270908, -3057572, -2796044, + -2486324, -2128412, -1722308, -1268012, -765524, -214844, 384028, 1031092, 1726348, 2469796, 3261436, + 4101268, 4215148, -34140, -213700, -343020, -422100, -450940, -429540, -357900, -236020, -2492, 158460, + 431060, 753900, 1126980, 1550300, 2023860, 2150860, -3410524, -3188164, -2915564, -2592724, -2219644, + -1796324, -1322764, -798964, -224924, 399356, 1073876, 1798636, 2573636, 3398876, 4274356, 4393676, + -35516, -222436, -357068, -439412, -469468, -447236, -372716, -245908, -2844, 164572, 448244, 784204, + 1172452, 1612988, 2105812, 2238252, -3550140, -3318756, -3035084, -2699124, -2310876, -1870340, -1377516, + -832404, -235004, 414684, 1116660, 1870924, 2677476, 3536316, 4447444, 4572204, -36892, -231172, -371116, + -456724, -487996, -464932, -387532, -255796, -3196, 170684, 465428, 814508, 1217924, 1675676, 2187764, + 2325644, -3689756, -3449348, -3154604, -2805524, -2402108, -1944356, -1432268, -865844, -245084, 430012, + 1159444, 1943212, 2781316, 3673756, 4620532, 4750732, -38268, -239908, -385164, -474036, -506524, -482628, + -402348, -265684, -3548, 176796, 482612, 844812, 1263396, 1738364, 2269716, 2413036, -3829372, -3579940, + -3274124, -2911924, -2493340, -2018372, -1487020, -899284, -255164, 445340, 1202228, 2015500, 2885156, + 3811196, 4793620, 4929260, -39644, -248644, -399212, -491348, -525052, -500324, -417164, -275572, -3900, + 182908, 499796, 875116, 1308868, 1801052, 2351668, 2500428, -3968988, -3710532, -3393644, -3018324, + -2584572, -2092388, -1541772, -932724, -265244, 460668, 1245012, 2087788, 2988996, 3948636, 4966708, + 5107788, -41020, -257380, -413260, -508660, -543580, -518020, -431980, -285460, -4252, 189020, 516980, + 905420, 1354340, 1863740, 2433620, 2587820, -4108604, -3841124, -3513164, -3124724, -2675804, -2166404, + -1596524, -966164, -275324, 475996, 1287796, 2160076, 3092836, 4086076, 5139796, 5286316, -42396, -266116, + -427308, -525972, -562108, -535716, -446796, -295348, -4604, 195132, 534164, 935724, 1399812, 1926428, + 2515572, 2675212, -4248220, -3971716, -3632684, -3231124, -2767036, -2240420, -1651276, -999604, -285404, + 491324, 1330580, 2232364, 3196676, 4223516, 5312884, 5464844, -43772, -274852, -441356, -543284, -580636, + -553412, -461612, -305236, -4956, 201244, 551348, 966028, 1445284, 1989116, 2597524, 2762604, -4387836, + -4102308, -3752204, -3337524, -2858268, -2314436, -1706028, -1033044, -295484, 506652, 1373364, 2304652, + 3300516, 4360956, 5485972, 5643372 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 16, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=16, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 2, 8], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [64], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, + 55, 56, 57, 58, 59, 60, 61, 62, 63 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + -1116, -1860, -1516, -84, 2436, 6044, 10740, 16524, 23364, 31356, 40404, 50540, 61764, 74076, 87476, + 86092, -24924, -16964, -7916, 2220, 13444, 25756, 39156, 53644, 69220, 85884, 103636, 122476, 142404, + 163420, 185524, 176460, -2492, -2404, 820, 7180, 16676, 29308, 45076, 63980, 88548, 111196, 139508, + 170956, 205540, 243260, 284116, 296364, -33468, -8292, 20020, 51468, 86052, 123772, 164628, 208620, + 255748, 306012, 359412, 415948, 475620, 538428, 604372, 608940, -3868, -2948, 3156, 14444, 30916, 52572, + 79412, 111436, 153732, 191036, 238612, 291372, 349316, 412444, 480756, 506636, -42012, 380, 47956, 100716, + 158660, 221788, 290100, 363596, 442276, 526140, 615188, 709420, 808836, 913436, 1023220, 1041420, -5244, + -3492, 5492, 21708, 45156, 75836, 113748, 158892, 218916, 270876, 337716, 411788, 493092, 581628, 677396, + 716908, -50556, 9052, 75892, 149964, 231268, 319804, 415572, 518572, 628804, 746268, 870964, 1002892, + 1142052, 1288444, 1442068, 1473900, -6620, -4036, 7828, 28972, 59396, 99100, 148084, 206348, 284100, + 350716, 436820, 532204, 636868, 750812, 874036, 927180, -59100, 17724, 103828, 199212, 303876, 417820, + 541044, 673548, 815332, 966396, 1126740, 1296364, 1475268, 1663452, 1860916, 1906380, -7996, -4580, 10164, + 36236, 73636, 122364, 182420, 253804, 349284, 430556, 535924, 652620, 780644, 919996, 1070676, 1137452, + -67644, 26396, 131764, 248460, 376484, 515836, 666516, 828524, 1001860, 1186524, 1382516, 1589836, + 1808484, 2038460, 2279764, 2338860, -9372, -5124, 12500, 43500, 87876, 145628, 216756, 301260, 414468, + 510396, 635028, 773036, 924420, 1089180, 1267316, 1347724, -76188, 35068, 159700, 297708, 449092, 613852, + 791988, 983500, 1188388, 1406652, 1638292, 1883308, 2141700, 2413468, 2698612, 2771340, -10748, -5668, + 14836, 50764, 102116, 168892, 251092, 348716, 479652, 590236, 734132, 893452, 1068196, 1258364, 1463956, + 1557996, -84732, 43740, 187636, 346956, 521700, 711868, 917460, 1138476, 1374916, 1626780, 1894068, + 2176780, 2474916, 2788476, 3117460, 3203820, -12124, -6212, 17172, 58028, 116356, 192156, 285428, 396172, + 544836, 670076, 833236, 1013868, 1211972, 1427548, 1660596, 1768268, -93276, 52412, 215572, 396204, + 594308, 809884, 1042932, 1293452, 1561444, 1846908, 2149844, 2470252, 2808132, 3163484, 3536308, 3636300, + -13500, -6756, 19508, 65292, 130596, 215420, 319764, 443628, 610020, 749916, 932340, 1134284, 1355748, + 1596732, 1857236, 1978540, -101820, 61084, 243508, 445452, 666916, 907900, 1168404, 1448428, 1747972, + 2067036, 2405620, 2763724, 3141348, 3538492, 3955156, 4068780, -14876, -7300, 21844, 72556, 144836, + 238684, 354100, 491084, 675204, 829756, 1031444, 1254700, 1499524, 1765916, 2053876, 2188812, -110364, + 69756, 271444, 494700, 739524, 1005916, 1293876, 1603404, 1934500, 2287164, 2661396, 3057196, 3474564, + 3913500, 4374004, 4501260, -16252, -7844, 24180, 79820, 159076, 261948, 388436, 538540, 740388, 909596, + 1130548, 1375116, 1643300, 1935100, 2250516, 2399084, -118908, 78428, 299380, 543948, 812132, 1103932, + 1419348, 1758380, 2121028, 2507292, 2917172, 3350668, 3807780, 4288508, 4792852, 4933740, -17628, -8388, + 26516, 87084, 173316, 285212, 422772, 585996, 805572, 989436, 1229652, 1495532, 1787076, 2104284, 2447156, + 2609356, -127452, 87100, 327316, 593196, 884740, 1201948, 1544820, 1913356, 2307556, 2727420, 3172948, + 3644140, 4140996, 4663516, 5211700, 5366220, -19004, -8932, 28852, 94348, 187556, 308476, 457108, 633452, + 870756, 1069276, 1328756, 1615948, 1930852, 2273468, 2643796, 2819628, -135996, 95772, 355252, 642444, + 957348, 1299964, 1670292, 2068332, 2494084, 2947548, 3428724, 3937612, 4474212, 5038524, 5630548, 5798700, + -20380, -9476, 31188, 101612, 201796, 331740, 491444, 680908, 935940, 1149116, 1427860, 1736364, 2074628, + 2442652, 2840436, 3029900, -144540, 104444, 383188, 691692, 1029956, 1397980, 1795764, 2223308, 2680612, + 3167676, 3684500, 4231084, 4807428, 5413532, 6049396, 6231180, -21756, -10020, 33524, 108876, 216036, + 355004, 525780, 728364, 1001124, 1228956, 1526964, 1856780, 2218404, 2611836, 3037076, 3240172, -153084, + 113116, 411124, 740940, 1102564, 1495996, 1921236, 2378284, 2867140, 3387804, 3940276, 4524556, 5140644, + 5788540, 6468244, 6663660, -23132, -10564, 35860, 116140, 230276, 378268, 560116, 775820, 1066308, + 1308796, 1626068, 1977196, 2362180, 2781020, 3233716, 3450444, -161628, 121788, 439060, 790188, 1175172, + 1594012, 2046708, 2533260, 3053668, 3607932, 4196052, 4818028, 5473860, 6163548, 6887092, 7096140, -24508, + -11108, 38196, 123404, 244516, 401532, 594452, 823276, 1131492, 1388636, 1725172, 2097612, 2505956, + 2950204, 3430356, 3660716, -170172, 130460, 466996, 839436, 1247780, 1692028, 2172180, 2688236, 3240196, + 3828060, 4451828, 5111500, 5807076, 6538556, 7305940, 7528620, -25884, -11652, 40532, 130668, 258756, + 424796, 628788, 870732, 1196676, 1468476, 1824276, 2218028, 2649732, 3119388, 3626996, 3870988, -178716, + 139132, 494932, 888684, 1320388, 1790044, 2297652, 2843212, 3426724, 4048188, 4707604, 5404972, 6140292, + 6913564, 7724788, 7961100, -27260, -12196, 42868, 137932, 272996, 448060, 663124, 918188, 1261860, + 1548316, 1923380, 2338444, 2793508, 3288572, 3823636, 4081260, -187260, 147804, 522868, 937932, 1392996, + 1888060, 2423124, 2998188, 3613252, 4268316, 4963380, 5698444, 6473508, 7288572, 8143636, 8393580, -28636, + -12740, 45204, 145196, 287236, 471324, 697460, 965644, 1327044, 1628156, 2022484, 2458860, 2937284, + 3457756, 4020276, 4291532, -195804, 156476, 550804, 987180, 1465604, 1986076, 2548596, 3153164, 3799780, + 4488444, 5219156, 5991916, 6806724, 7663580, 8562484, 8826060, -30012, -13284, 47540, 152460, 301476, + 494588, 731796, 1013100, 1392228, 1707996, 2121588, 2579276, 3081060, 3626940, 4216916, 4501804, -204348, + 165148, 578740, 1036428, 1538212, 2084092, 2674068, 3308140, 3986308, 4708572, 5474932, 6285388, 7139940, + 8038588, 8981332, 9258540, -31388, -13828, 49876, 159724, 315716, 517852, 766132, 1060556, 1457412, + 1787836, 2220692, 2699692, 3224836, 3796124, 4413556, 4712076, -212892, 173820, 606676, 1085676, 1610820, + 2182108, 2799540, 3463116, 4172836, 4928700, 5730708, 6578860, 7473156, 8413596, 9400180, 9691020, -32764, + -14372, 52212, 166988, 329956, 541116, 800468, 1108012, 1522596, 1867676, 2319796, 2820108, 3368612, + 3965308, 4610196, 4922348, -221436, 182492, 634612, 1134924, 1683428, 2280124, 2925012, 3618092, 4359364, + 5148828, 5986484, 6872332, 7806372, 8788604, 9819028, 10123500, -34140, -14916, 54548, 174252, 344196, + 564380, 834804, 1155468, 1587780, 1947516, 2418900, 2940524, 3512388, 4134492, 4806836, 5132620, -229980, + 191164, 662548, 1184172, 1756036, 2378140, 3050484, 3773068, 4545892, 5368956, 6242260, 7165804, 8139588, + 9163612, 10237876, 10555980, -35516, -15460, 56884, 181516, 358436, 587644, 869140, 1202924, 1652964, + 2027356, 2518004, 3060940, 3656164, 4303676, 5003476, 5342892, -238524, 199836, 690484, 1233420, 1828644, + 2476156, 3175956, 3928044, 4732420, 5589084, 6498036, 7459276, 8472804, 9538620, 10656724, 10988460, + -36892, -16004, 59220, 188780, 372676, 610908, 903476, 1250380, 1718148, 2107196, 2617108, 3181356, + 3799940, 4472860, 5200116, 5553164, -247068, 208508, 718420, 1282668, 1901252, 2574172, 3301428, 4083020, + 4918948, 5809212, 6753812, 7752748, 8806020, 9913628, 11075572, 11420940, -38268, -16548, 61556, 196044, + 386916, 634172, 937812, 1297836, 1783332, 2187036, 2716212, 3301772, 3943716, 4642044, 5396756, 5763436, + -255612, 217180, 746356, 1331916, 1973860, 2672188, 3426900, 4237996, 5105476, 6029340, 7009588, 8046220, + 9139236, 10288636, 11494420, 11853420, -39644, -17092, 63892, 203308, 401156, 657436, 972148, 1345292, + 1848516, 2266876, 2815316, 3422188, 4087492, 4811228, 5593396, 5973708, -264156, 225852, 774292, 1381164, + 2046468, 2770204, 3552372, 4392972, 5292004, 6249468, 7265364, 8339692, 9472452, 10663644, 11913268, + 12285900, -41020, -17636, 66228, 210572, 415396, 680700, 1006484, 1392748, 1913700, 2346716, 2914420, + 3542604, 4231268, 4980412, 5790036, 6183980, -272700, 234524, 802228, 1430412, 2119076, 2868220, 3677844, + 4547948, 5478532, 6469596, 7521140, 8633164, 9805668, 11038652, 12332116, 12718380, -42396, -18180, 68564, + 217836, 429636, 703964, 1040820, 1440204, 1978884, 2426556, 3013524, 3663020, 4375044, 5149596, 5986676, + 6394252, -281244, 243196, 830164, 1479660, 2191684, 2966236, 3803316, 4702924, 5665060, 6689724, 7776916, + 8926636, 10138884, 11413660, 12750964, 13150860, -43772, -18724, 70900, 225100, 443876, 727228, 1075156, + 1487660, 2044068, 2506396, 3112628, 3783436, 4518820, 5318780, 6183316, 6604524, -289788, 251868, 858100, + 1528908, 2264292, 3064252, 3928788, 4857900, 5851588, 6909852, 8032692, 9220108, 10472100, 11788668, + 13169812, 13583340 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; symmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, -1560, -2576, -3048, -2976, -2360, -1200, 504, 2736, 5544, 8880, 12760, 17184, 22152, 27664, 26040, + -29312, -26520, -23184, -19304, -14880, -9912, -4400, 1656, 8256, 15400, 23088, 31320, 40096, 49416, + 59280, 53816, 0, -5368, -9168, -11400, -12064, -11160, -8688, -4648, 2224, 8136, 16880, 27192, 39072, + 52520, 67536, 68760, -98432, -91256, -82512, -72200, -60320, -46872, -31856, -15272, 2880, 22600, 43888, + 66744, 91168, 117160, 144720, 142104, 0, -9176, -15760, -19752, -21152, -19960, -16176, -9800, 1712, + 10728, 24880, 41624, 60960, 82888, 107408, 111480, -167552, -155992, -141840, -125096, -105760, -83832, + -59312, -32200, -2496, 29800, 64688, 102168, 142240, 184904, 230160, 230392, 0, -12984, -22352, -28104, + -30240, -28760, -23664, -14952, 1200, 13320, 32880, 56056, 82848, 113256, 147280, 154200, -236672, + -220728, -201168, -177992, -151200, -120792, -86768, -49128, -7872, 37000, 85488, 137592, 193312, 252648, + 315600, 318680, 0, -16792, -28944, -36456, -39328, -37560, -31152, -20104, 688, 15912, 40880, 70488, + 104736, 143624, 187152, 196920, -305792, -285464, -260496, -230888, -196640, -157752, -114224, -66056, + -13248, 44200, 106288, 173016, 244384, 320392, 401040, 406968, 0, -20600, -35536, -44808, -48416, -46360, + -38640, -25256, 176, 18504, 48880, 84920, 126624, 173992, 227024, 239640, -374912, -350200, -319824, + -283784, -242080, -194712, -141680, -82984, -18624, 51400, 127088, 208440, 295456, 388136, 486480, 495256, + 0, -24408, -42128, -53160, -57504, -55160, -46128, -30408, -336, 21096, 56880, 99352, 148512, 204360, + 266896, 282360, -444032, -414936, -379152, -336680, -287520, -231672, -169136, -99912, -24000, 58600, + 147888, 243864, 346528, 455880, 571920, 583544, 0, -28216, -48720, -61512, -66592, -63960, -53616, -35560, + -848, 23688, 64880, 113784, 170400, 234728, 306768, 325080, -513152, -479672, -438480, -389576, -332960, + -268632, -196592, -116840, -29376, 65800, 168688, 279288, 397600, 523624, 657360, 671832, 0, -32024, + -55312, -69864, -75680, -72760, -61104, -40712, -1360, 26280, 72880, 128216, 192288, 265096, 346640, + 367800, -582272, -544408, -497808, -442472, -378400, -305592, -224048, -133768, -34752, 73000, 189488, + 314712, 448672, 591368, 742800, 760120, 0, -35832, -61904, -78216, -84768, -81560, -68592, -45864, -1872, + 28872, 80880, 142648, 214176, 295464, 386512, 410520, -651392, -609144, -557136, -495368, -423840, + -342552, -251504, -150696, -40128, 80200, 210288, 350136, 499744, 659112, 828240, 848408, 0, -39640, + -68496, -86568, -93856, -90360, -76080, -51016, -2384, 31464, 88880, 157080, 236064, 325832, 426384, + 453240, -720512, -673880, -616464, -548264, -469280, -379512, -278960, -167624, -45504, 87400, 231088, + 385560, 550816, 726856, 913680, 936696, 0, -43448, -75088, -94920, -102944, -99160, -83568, -56168, -2896, + 34056, 96880, 171512, 257952, 356200, 466256, 495960, -789632, -738616, -675792, -601160, -514720, + -416472, -306416, -184552, -50880, 94600, 251888, 420984, 601888, 794600, 999120, 1024984, 0, -47256, + -81680, -103272, -112032, -107960, -91056, -61320, -3408, 36648, 104880, 185944, 279840, 386568, 506128, + 538680, -858752, -803352, -735120, -654056, -560160, -453432, -333872, -201480, -56256, 101800, 272688, + 456408, 652960, 862344, 1084560, 1113272, 0, -51064, -88272, -111624, -121120, -116760, -98544, -66472, + -3920, 39240, 112880, 200376, 301728, 416936, 546000, 581400, -927872, -868088, -794448, -706952, -605600, + -490392, -361328, -218408, -61632, 109000, 293488, 491832, 704032, 930088, 1170000, 1201560, 0, -54872, + -94864, -119976, -130208, -125560, -106032, -71624, -4432, 41832, 120880, 214808, 323616, 447304, 585872, + 624120, -996992, -932824, -853776, -759848, -651040, -527352, -388784, -235336, -67008, 116200, 314288, + 527256, 755104, 997832, 1255440, 1289848, 0, -58680, -101456, -128328, -139296, -134360, -113520, -76776, + -4944, 44424, 128880, 229240, 345504, 477672, 625744, 666840, -1066112, -997560, -913104, -812744, + -696480, -564312, -416240, -252264, -72384, 123400, 335088, 562680, 806176, 1065576, 1340880, 1378136, 0, + -62488, -108048, -136680, -148384, -143160, -121008, -81928, -5456, 47016, 136880, 243672, 367392, 508040, + 665616, 709560, -1135232, -1062296, -972432, -865640, -741920, -601272, -443696, -269192, -77760, 130600, + 355888, 598104, 857248, 1133320, 1426320, 1466424, 0, -66296, -114640, -145032, -157472, -151960, -128496, + -87080, -5968, 49608, 144880, 258104, 389280, 538408, 705488, 752280, -1204352, -1127032, -1031760, + -918536, -787360, -638232, -471152, -286120, -83136, 137800, 376688, 633528, 908320, 1201064, 1511760, + 1554712, 0, -70104, -121232, -153384, -166560, -160760, -135984, -92232, -6480, 52200, 152880, 272536, + 411168, 568776, 745360, 795000, -1273472, -1191768, -1091088, -971432, -832800, -675192, -498608, -303048, + -88512, 145000, 397488, 668952, 959392, 1268808, 1597200, 1643000, 0, -73912, -127824, -161736, -175648, + -169560, -143472, -97384, -6992, 54792, 160880, 286968, 433056, 599144, 785232, 837720, -1342592, + -1256504, -1150416, -1024328, -878240, -712152, -526064, -319976, -93888, 152200, 418288, 704376, 1010464, + 1336552, 1682640, 1731288, 0, -77720, -134416, -170088, -184736, -178360, -150960, -102536, -7504, 57384, + 168880, 301400, 454944, 629512, 825104, 880440, -1411712, -1321240, -1209744, -1077224, -923680, -749112, + -553520, -336904, -99264, 159400, 439088, 739800, 1061536, 1404296, 1768080, 1819576, 0, -81528, -141008, + -178440, -193824, -187160, -158448, -107688, -8016, 59976, 176880, 315832, 476832, 659880, 864976, 923160, + -1480832, -1385976, -1269072, -1130120, -969120, -786072, -580976, -353832, -104640, 166600, 459888, + 775224, 1112608, 1472040, 1853520, 1907864, 0, -85336, -147600, -186792, -202912, -195960, -165936, + -112840, -8528, 62568, 184880, 330264, 498720, 690248, 904848, 965880, -1549952, -1450712, -1328400, + -1183016, -1014560, -823032, -608432, -370760, -110016, 173800, 480688, 810648, 1163680, 1539784, 1938960, + 1996152, 0, -89144, -154192, -195144, -212000, -204760, -173424, -117992, -9040, 65160, 192880, 344696, + 520608, 720616, 944720, 1008600, -1619072, -1515448, -1387728, -1235912, -1060000, -859992, -635888, + -387688, -115392, 181000, 501488, 846072, 1214752, 1607528, 2024400, 2084440, 0, -92952, -160784, -203496, + -221088, -213560, -180912, -123144, -9552, 67752, 200880, 359128, 542496, 750984, 984592, 1051320, + -1688192, -1580184, -1447056, -1288808, -1105440, -896952, -663344, -404616, -120768, 188200, 522288, + 881496, 1265824, 1675272, 2109840, 2172728, 0, -96760, -167376, -211848, -230176, -222360, -188400, + -128296, -10064, 70344, 208880, 373560, 564384, 781352, 1024464, 1094040, -1757312, -1644920, -1506384, + -1341704, -1150880, -933912, -690800, -421544, -126144, 195400, 543088, 916920, 1316896, 1743016, 2195280, + 2261016, 0, -100568, -173968, -220200, -239264, -231160, -195888, -133448, -10576, 72936, 216880, 387992, + 586272, 811720, 1064336, 1136760, -1826432, -1709656, -1565712, -1394600, -1196320, -970872, -718256, + -438472, -131520, 202600, 563888, 952344, 1367968, 1810760, 2280720, 2349304, 0, -104376, -180560, + -228552, -248352, -239960, -203376, -138600, -11088, 75528, 224880, 402424, 608160, 842088, 1104208, + 1179480, -1895552, -1774392, -1625040, -1447496, -1241760, -1007832, -745712, -455400, -136896, 209800, + 584688, 987768, 1419040, 1878504, 2366160, 2437592, 0, -108184, -187152, -236904, -257440, -248760, + -210864, -143752, -11600, 78120, 232880, 416856, 630048, 872456, 1144080, 1222200, -1964672, -1839128, + -1684368, -1500392, -1287200, -1044792, -773168, -472328, -142272, 217000, 605488, 1023192, 1470112, + 1946248, 2451600, 2525880, 0, -111992, -193744, -245256, -266528, -257560, -218352, -148904, -12112, + 80712, 240880, 431288, 651936, 902824, 1183952, 1264920, -2033792, -1903864, -1743696, -1553288, -1332640, + -1081752, -800624, -489256, -147648, 224200, 626288, 1058616, 1521184, 2013992, 2537040, 2614168, 0, + -115800, -200336, -253608, -275616, -266360, -225840, -154056, -12624, 83304, 248880, 445720, 673824, + 933192, 1223824, 1307640, -2102912, -1968600, -1803024, -1606184, -1378080, -1118712, -828080, -506184, + -153024, 231400, 647088, 1094040, 1572256, 2081736, 2622480, 2702456, 0, -119608, -206928, -261960, + -284704, -275160, -233328, -159208, -13136, 85896, 256880, 460152, 695712, 963560, 1263696, 1350360, + -2172032, -2033336, -1862352, -1659080, -1423520, -1155672, -855536, -523112, -158400, 238600, 667888, + 1129464, 1623328, 2149480, 2707920, 2790744 + ] + } + ] + } + ] + }, + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4", + "operator": "MatMulNBits", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [ + { "name": "K", "data": 32, "type": "int" }, + { "name": "N", "data": 32, "type": "int" }, + { "name": "block_size", "data": 32, "type": "int" }, + { "name": "bits", "data": 4, "type": "int" } + ], + "cases": [ + { + "name": "MatMulNBits; K=32, N=32, block_size=32, bits=4; asymmetric", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512, 513, 514, 515, 516, 517, 518, 519, 520, 521, 522, 523, 524, 525, 526, + 527, 528, 529, 530, 531, 532, 533, 534, 535, 536, 537, 538, 539, 540, 541, 542, 543, 544, 545, 546, 547, + 548, 549, 550, 551, 552, 553, 554, 555, 556, 557, 558, 559, 560, 561, 562, 563, 564, 565, 566, 567, 568, + 569, 570, 571, 572, 573, 574, 575, 576, 577, 578, 579, 580, 581, 582, 583, 584, 585, 586, 587, 588, 589, + 590, 591, 592, 593, 594, 595, 596, 597, 598, 599, 600, 601, 602, 603, 604, 605, 606, 607, 608, 609, 610, + 611, 612, 613, 614, 615, 616, 617, 618, 619, 620, 621, 622, 623, 624, 625, 626, 627, 628, 629, 630, 631, + 632, 633, 634, 635, 636, 637, 638, 639, 640, 641, 642, 643, 644, 645, 646, 647, 648, 649, 650, 651, 652, + 653, 654, 655, 656, 657, 658, 659, 660, 661, 662, 663, 664, 665, 666, 667, 668, 669, 670, 671, 672, 673, + 674, 675, 676, 677, 678, 679, 680, 681, 682, 683, 684, 685, 686, 687, 688, 689, 690, 691, 692, 693, 694, + 695, 696, 697, 698, 699, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, + 716, 717, 718, 719, 720, 721, 722, 723, 724, 725, 726, 727, 728, 729, 730, 731, 732, 733, 734, 735, 736, + 737, 738, 739, 740, 741, 742, 743, 744, 745, 746, 747, 748, 749, 750, 751, 752, 753, 754, 755, 756, 757, + 758, 759, 760, 761, 762, 763, 764, 765, 766, 767, 768, 769, 770, 771, 772, 773, 774, 775, 776, 777, 778, + 779, 780, 781, 782, 783, 784, 785, 786, 787, 788, 789, 790, 791, 792, 793, 794, 795, 796, 797, 798, 799, + 800, 801, 802, 803, 804, 805, 806, 807, 808, 809, 810, 811, 812, 813, 814, 815, 816, 817, 818, 819, 820, + 821, 822, 823, 824, 825, 826, 827, 828, 829, 830, 831, 832, 833, 834, 835, 836, 837, 838, 839, 840, 841, + 842, 843, 844, 845, 846, 847, 848, 849, 850, 851, 852, 853, 854, 855, 856, 857, 858, 859, 860, 861, 862, + 863, 864, 865, 866, 867, 868, 869, 870, 871, 872, 873, 874, 875, 876, 877, 878, 879, 880, 881, 882, 883, + 884, 885, 886, 887, 888, 889, 890, 891, 892, 893, 894, 895, 896, 897, 898, 899, 900, 901, 902, 903, 904, + 905, 906, 907, 908, 909, 910, 911, 912, 913, 914, 915, 916, 917, 918, 919, 920, 921, 922, 923, 924, 925, + 926, 927, 928, 929, 930, 931, 932, 933, 934, 935, 936, 937, 938, 939, 940, 941, 942, 943, 944, 945, 946, + 947, 948, 949, 950, 951, 952, 953, 954, 955, 956, 957, 958, 959, 960, 961, 962, 963, 964, 965, 966, 967, + 968, 969, 970, 971, 972, 973, 974, 975, 976, 977, 978, 979, 980, 981, 982, 983, 984, 985, 986, 987, 988, + 989, 990, 991, 992, 993, 994, 995, 996, 997, 998, 999, 1000, 1001, 1002, 1003, 1004, 1005, 1006, 1007, + 1008, 1009, 1010, 1011, 1012, 1013, 1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023, 1024 + ], + "dims": [32, 32], + "type": "float32" + }, + { + "dims": [32, 1, 16], + "type": "uint8", + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, + 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, + 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + 128, 29, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, + 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, + 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, + 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, + 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, + 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, + 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273, 274, + 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287, 288, 289, 290, 291, 292, 293, 294, 295, + 296, 297, 298, 299, 300, 301, 302, 303, 304, 305, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 316, + 317, 318, 319, 320, 321, 322, 323, 324, 325, 326, 327, 328, 329, 330, 331, 332, 333, 334, 335, 336, 337, + 338, 339, 340, 341, 342, 343, 344, 345, 346, 347, 348, 349, 350, 351, 352, 353, 354, 355, 356, 357, 358, + 359, 360, 361, 362, 363, 364, 365, 366, 367, 368, 369, 370, 371, 372, 373, 374, 375, 376, 377, 378, 379, + 380, 381, 382, 383, 384, 385, 386, 387, 388, 389, 390, 391, 392, 393, 394, 395, 396, 397, 398, 399, 400, + 401, 402, 403, 404, 405, 406, 407, 408, 409, 410, 411, 412, 413, 414, 415, 416, 417, 418, 419, 420, 421, + 422, 423, 424, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 439, 440, 441, 442, + 443, 444, 445, 446, 447, 448, 449, 450, 451, 452, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462, 463, + 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 474, 475, 476, 477, 478, 479, 480, 481, 482, 483, 484, + 485, 486, 487, 488, 489, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 501, 502, 503, 504, 505, + 506, 507, 508, 509, 510, 511, 512 + ] + }, + { + "dims": [32], + "type": "float32", + "data": [ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, + 29, 30, 31 + ] + }, + { + "dims": [32], + "type": "uint8", + "data": [ + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128 + ] + } + ], + "outputs": [ + { + "dims": [32, 32], + "type": "float32", + "data": [ + 0, 2664, 5872, 9624, 13920, 18760, 24144, 30072, 36528, 43560, 51120, 59224, 67872, 77064, 86800, 89400, + 38272, 45288, 52848, 60952, 69600, 78792, 88528, 98808, 109632, 121000, 132912, 145368, 158368, 171912, + 186000, 184760, 0, 7048, 15664, 25848, 37600, 50920, 65808, 82264, 101552, 119880, 141040, 163768, 188064, + 213928, 241360, 255000, 100224, 119816, 140976, 163704, 188000, 213864, 241296, 270296, 300864, 333000, + 366704, 401976, 438816, 477224, 517200, 527000, 0, 11432, 25456, 42072, 61280, 83080, 107472, 134456, + 166576, 196200, 230960, 268312, 308256, 350792, 395920, 420600, 162176, 194344, 229104, 266456, 306400, + 348936, 394064, 441784, 492096, 545000, 600496, 658584, 719264, 782536, 848400, 869240, 0, 15816, 35248, + 58296, 84960, 115240, 149136, 186648, 231600, 272520, 320880, 372856, 428448, 487656, 550480, 586200, + 224128, 268872, 317232, 369208, 424800, 484008, 546832, 613272, 683328, 757000, 834288, 915192, 999712, + 1087848, 1179600, 1211480, 0, 20200, 45040, 74520, 108640, 147400, 190800, 238840, 296624, 348840, 410800, + 477400, 548640, 624520, 705040, 751800, 286080, 343400, 405360, 471960, 543200, 619080, 699600, 784760, + 874560, 969000, 1068080, 1171800, 1280160, 1393160, 1510800, 1553720, 0, 24584, 54832, 90744, 132320, + 179560, 232464, 291032, 361648, 425160, 500720, 581944, 668832, 761384, 859600, 917400, 348032, 417928, + 493488, 574712, 661600, 754152, 852368, 956248, 1065792, 1181000, 1301872, 1428408, 1560608, 1698472, + 1842000, 1895960, 0, 28968, 64624, 106968, 156000, 211720, 274128, 343224, 426672, 501480, 590640, 686488, + 789024, 898248, 1014160, 1083000, 409984, 492456, 581616, 677464, 780000, 889224, 1005136, 1127736, + 1257024, 1393000, 1535664, 1685016, 1841056, 2003784, 2173200, 2238200, 0, 33352, 74416, 123192, 179680, + 243880, 315792, 395416, 491696, 577800, 680560, 791032, 909216, 1035112, 1168720, 1248600, 471936, 566984, + 669744, 780216, 898400, 1024296, 1157904, 1299224, 1448256, 1605000, 1769456, 1941624, 2121504, 2309096, + 2504400, 2580440, 0, 37736, 84208, 139416, 203360, 276040, 357456, 447608, 556720, 654120, 770480, 895576, + 1029408, 1171976, 1323280, 1414200, 533888, 641512, 757872, 882968, 1016800, 1159368, 1310672, 1470712, + 1639488, 1817000, 2003248, 2198232, 2401952, 2614408, 2835600, 2922680, 0, 42120, 94000, 155640, 227040, + 308200, 399120, 499800, 621744, 730440, 860400, 1000120, 1149600, 1308840, 1477840, 1579800, 595840, + 716040, 846000, 985720, 1135200, 1294440, 1463440, 1642200, 1830720, 2029000, 2237040, 2454840, 2682400, + 2919720, 3166800, 3264920, 0, 46504, 103792, 171864, 250720, 340360, 440784, 551992, 686768, 806760, + 950320, 1104664, 1269792, 1445704, 1632400, 1745400, 657792, 790568, 934128, 1088472, 1253600, 1429512, + 1616208, 1813688, 2021952, 2241000, 2470832, 2711448, 2962848, 3225032, 3498000, 3607160, 0, 50888, + 113584, 188088, 274400, 372520, 482448, 604184, 751792, 883080, 1040240, 1209208, 1389984, 1582568, + 1786960, 1911000, 719744, 865096, 1022256, 1191224, 1372000, 1564584, 1768976, 1985176, 2213184, 2453000, + 2704624, 2968056, 3243296, 3530344, 3829200, 3949400, 0, 55272, 123376, 204312, 298080, 404680, 524112, + 656376, 816816, 959400, 1130160, 1313752, 1510176, 1719432, 1941520, 2076600, 781696, 939624, 1110384, + 1293976, 1490400, 1699656, 1921744, 2156664, 2404416, 2665000, 2938416, 3224664, 3523744, 3835656, + 4160400, 4291640, 0, 59656, 133168, 220536, 321760, 436840, 565776, 708568, 881840, 1035720, 1220080, + 1418296, 1630368, 1856296, 2096080, 2242200, 843648, 1014152, 1198512, 1396728, 1608800, 1834728, 2074512, + 2328152, 2595648, 2877000, 3172208, 3481272, 3804192, 4140968, 4491600, 4633880, 0, 64040, 142960, 236760, + 345440, 469000, 607440, 760760, 946864, 1112040, 1310000, 1522840, 1750560, 1993160, 2250640, 2407800, + 905600, 1088680, 1286640, 1499480, 1727200, 1969800, 2227280, 2499640, 2786880, 3089000, 3406000, 3737880, + 4084640, 4446280, 4822800, 4976120, 0, 68424, 152752, 252984, 369120, 501160, 649104, 812952, 1011888, + 1188360, 1399920, 1627384, 1870752, 2130024, 2405200, 2573400, 967552, 1163208, 1374768, 1602232, 1845600, + 2104872, 2380048, 2671128, 2978112, 3301000, 3639792, 3994488, 4365088, 4751592, 5154000, 5318360, 0, + 72808, 162544, 269208, 392800, 533320, 690768, 865144, 1076912, 1264680, 1489840, 1731928, 1990944, + 2266888, 2559760, 2739000, 1029504, 1237736, 1462896, 1704984, 1964000, 2239944, 2532816, 2842616, + 3169344, 3513000, 3873584, 4251096, 4645536, 5056904, 5485200, 5660600, 0, 77192, 172336, 285432, 416480, + 565480, 732432, 917336, 1141936, 1341000, 1579760, 1836472, 2111136, 2403752, 2714320, 2904600, 1091456, + 1312264, 1551024, 1807736, 2082400, 2375016, 2685584, 3014104, 3360576, 3725000, 4107376, 4507704, + 4925984, 5362216, 5816400, 6002840, 0, 81576, 182128, 301656, 440160, 597640, 774096, 969528, 1206960, + 1417320, 1669680, 1941016, 2231328, 2540616, 2868880, 3070200, 1153408, 1386792, 1639152, 1910488, + 2200800, 2510088, 2838352, 3185592, 3551808, 3937000, 4341168, 4764312, 5206432, 5667528, 6147600, + 6345080, 0, 85960, 191920, 317880, 463840, 629800, 815760, 1021720, 1271984, 1493640, 1759600, 2045560, + 2351520, 2677480, 3023440, 3235800, 1215360, 1461320, 1727280, 2013240, 2319200, 2645160, 2991120, + 3357080, 3743040, 4149000, 4574960, 5020920, 5486880, 5972840, 6478800, 6687320, 0, 90344, 201712, 334104, + 487520, 661960, 857424, 1073912, 1337008, 1569960, 1849520, 2150104, 2471712, 2814344, 3178000, 3401400, + 1277312, 1535848, 1815408, 2115992, 2437600, 2780232, 3143888, 3528568, 3934272, 4361000, 4808752, + 5277528, 5767328, 6278152, 6810000, 7029560, 0, 94728, 211504, 350328, 511200, 694120, 899088, 1126104, + 1402032, 1646280, 1939440, 2254648, 2591904, 2951208, 3332560, 3567000, 1339264, 1610376, 1903536, + 2218744, 2556000, 2915304, 3296656, 3700056, 4125504, 4573000, 5042544, 5534136, 6047776, 6583464, + 7141200, 7371800, 0, 99112, 221296, 366552, 534880, 726280, 940752, 1178296, 1467056, 1722600, 2029360, + 2359192, 2712096, 3088072, 3487120, 3732600, 1401216, 1684904, 1991664, 2321496, 2674400, 3050376, + 3449424, 3871544, 4316736, 4785000, 5276336, 5790744, 6328224, 6888776, 7472400, 7714040, 0, 103496, + 231088, 382776, 558560, 758440, 982416, 1230488, 1532080, 1798920, 2119280, 2463736, 2832288, 3224936, + 3641680, 3898200, 1463168, 1759432, 2079792, 2424248, 2792800, 3185448, 3602192, 4043032, 4507968, + 4997000, 5510128, 6047352, 6608672, 7194088, 7803600, 8056280, 0, 107880, 240880, 399000, 582240, 790600, + 1024080, 1282680, 1597104, 1875240, 2209200, 2568280, 2952480, 3361800, 3796240, 4063800, 1525120, + 1833960, 2167920, 2527000, 2911200, 3320520, 3754960, 4214520, 4699200, 5209000, 5743920, 6303960, + 6889120, 7499400, 8134800, 8398520, 0, 112264, 250672, 415224, 605920, 822760, 1065744, 1334872, 1662128, + 1951560, 2299120, 2672824, 3072672, 3498664, 3950800, 4229400, 1587072, 1908488, 2256048, 2629752, + 3029600, 3455592, 3907728, 4386008, 4890432, 5421000, 5977712, 6560568, 7169568, 7804712, 8466000, + 8740760, 0, 116648, 260464, 431448, 629600, 854920, 1107408, 1387064, 1727152, 2027880, 2389040, 2777368, + 3192864, 3635528, 4105360, 4395000, 1649024, 1983016, 2344176, 2732504, 3148000, 3590664, 4060496, + 4557496, 5081664, 5633000, 6211504, 6817176, 7450016, 8110024, 8797200, 9083000, 0, 121032, 270256, + 447672, 653280, 887080, 1149072, 1439256, 1792176, 2104200, 2478960, 2881912, 3313056, 3772392, 4259920, + 4560600, 1710976, 2057544, 2432304, 2835256, 3266400, 3725736, 4213264, 4728984, 5272896, 5845000, + 6445296, 7073784, 7730464, 8415336, 9128400, 9425240, 0, 125416, 280048, 463896, 676960, 919240, 1190736, + 1491448, 1857200, 2180520, 2568880, 2986456, 3433248, 3909256, 4414480, 4726200, 1772928, 2132072, + 2520432, 2938008, 3384800, 3860808, 4366032, 4900472, 5464128, 6057000, 6679088, 7330392, 8010912, + 8720648, 9459600, 9767480, 0, 129800, 289840, 480120, 700640, 951400, 1232400, 1543640, 1922224, 2256840, + 2658800, 3091000, 3553440, 4046120, 4569040, 4891800, 1834880, 2206600, 2608560, 3040760, 3503200, + 3995880, 4518800, 5071960, 5655360, 6269000, 6912880, 7587000, 8291360, 9025960, 9790800, 10109720, 0, + 134184, 299632, 496344, 724320, 983560, 1274064, 1595832, 1987248, 2333160, 2748720, 3195544, 3673632, + 4182984, 4723600, 5057400, 1896832, 2281128, 2696688, 3143512, 3621600, 4130952, 4671568, 5243448, + 5846592, 6481000, 7146672, 7843608, 8571808, 9331272, 10122000, 10451960, 0, 138568, 309424, 512568, + 748000, 1015720, 1315728, 1648024, 2052272, 2409480, 2838640, 3300088, 3793824, 4319848, 4878160, 5223000, + 1958784, 2355656, 2784816, 3246264, 3740000, 4266024, 4824336, 5414936, 6037824, 6693000, 7380464, + 8100216, 8852256, 9636584, 10453200, 10794200 + ] + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 55b21283025c2..b43b1ac37e37d 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -1354,6 +1354,7 @@ "expand.jsonc", "fast-gelu.jsonc", "floor.jsonc", + "fused-conv.jsonc", "gather-elements.jsonc", "gemm.jsonc", "global-average-pool.jsonc", @@ -1362,6 +1363,7 @@ "less.jsonc", "log.jsonc", "matmul.jsonc", + "matmulnbits.jsonc", "matmul-broadcast.jsonc", "mul.jsonc", "mul_int32.jsonc", diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index b01d474788f25..ecc7d4b4a09a5 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -39,10 +39,6 @@ const ONNXRUNTIME_THRESHOLD_RELATIVE_ERROR = 1.00001; */ const now = (typeof performance !== 'undefined' && performance.now) ? () => performance.now() : Date.now; -function toInternalTensor(tensor: ort.Tensor): Tensor { - return new Tensor( - tensor.dims, tensor.type as Tensor.DataType, undefined, undefined, tensor.data as Tensor.NumberType); -} function fromInternalTensor(tensor: Tensor): ort.Tensor { return new ort.Tensor(tensor.type, tensor.data as ort.Tensor.DataType, tensor.dims); } @@ -330,6 +326,10 @@ export class TensorResultValidator { } checkTensorResult(actual: Tensor[], expected: Tensor[]): void { + this.checkApiTensorResult(actual.map(fromInternalTensor), expected.map(fromInternalTensor)); + } + + checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { // check output size expect(actual.length, 'size of output tensors').to.equal(expected.length); @@ -347,10 +347,6 @@ export class TensorResultValidator { } } - checkApiTensorResult(actual: ort.Tensor[], expected: ort.Tensor[]): void { - this.checkTensorResult(actual.map(toInternalTensor), expected.map(toInternalTensor)); - } - checkNamedTensorResult(actual: Record, expected: Test.NamedTensor[]): void { // check output size expect(Object.getOwnPropertyNames(actual).length, 'size of output tensors').to.equal(expected.length); @@ -364,7 +360,7 @@ export class TensorResultValidator { } // This function check whether 2 tensors should be considered as 'match' or not - areEqual(actual: Tensor, expected: Tensor): boolean { + areEqual(actual: ort.Tensor, expected: ort.Tensor): boolean { if (!actual || !expected) { return false; } @@ -392,13 +388,13 @@ export class TensorResultValidator { switch (actualType) { case 'string': - return this.strictEqual(actual.stringData, expected.stringData); + return this.strictEqual(actual.data, expected.data); case 'float32': case 'float64': return this.floatEqual( - actual.numberData as number[] | Float32Array | Float64Array, - expected.numberData as number[] | Float32Array | Float64Array); + actual.data as number[] | Float32Array | Float64Array, + expected.data as number[] | Float32Array | Float64Array); case 'uint8': case 'int8': @@ -409,10 +405,8 @@ export class TensorResultValidator { case 'int64': case 'bool': return TensorResultValidator.integerEqual( - actual.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array, - expected.numberData as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | - Int32Array); + actual.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array, + expected.data as number[] | Uint8Array | Int8Array | Uint16Array | Int16Array | Uint32Array | Int32Array); default: throw new Error('type not implemented or not supported'); diff --git a/js/web/test/unittests/backends/webgl/test-conv-new.ts b/js/web/test/unittests/backends/webgl/test-conv-new.ts index 8c186b9b36451..014fc57f21558 100644 --- a/js/web/test/unittests/backends/webgl/test-conv-new.ts +++ b/js/web/test/unittests/backends/webgl/test-conv-new.ts @@ -893,7 +893,9 @@ describe('New Conv tests', () => { const expected = cpuConv( inputTensor, kernelTensor, biasTensor, testData.autoPad, testData.dilations, testData.pads, testData.strides); - if (!validator.areEqual(actual, expected)) { + try { + validator.checkTensorResult([actual], [expected]); + } catch { console.log(actual.dims, `[${actual.numberData.slice(0, 20).join(',')},...]`); console.log(expected.dims, `[${expected.numberData.slice(0, 20).join(',')},...]`); throw new Error('Expected and Actual did not match'); diff --git a/objectivec/include/ort_coreml_execution_provider.h b/objectivec/include/ort_coreml_execution_provider.h index a015b6fd60c8f..6ff18176ebeb2 100644 --- a/objectivec/include/ort_coreml_execution_provider.h +++ b/objectivec/include/ort_coreml_execution_provider.h @@ -41,6 +41,17 @@ NS_ASSUME_NONNULL_BEGIN */ @property BOOL onlyEnableForDevicesWithANE; +/** + * Only allow CoreML EP to take nodes with inputs with static shapes. By default it will also allow inputs with + * dynamic shapes. However, the performance may be negatively impacted if inputs have dynamic shapes. + */ +@property BOOL onlyAllowStaticInputShapes; + +/** + * Create an MLProgram. By default it will create a NeuralNetwork model. Requires Core ML 5 or later. + */ +@property BOOL createMLProgram; + @end @interface ORTSessionOptions (ORTSessionOptionsCoreMLEP) diff --git a/objectivec/ort_coreml_execution_provider.mm b/objectivec/ort_coreml_execution_provider.mm index 6340fdea1c3a7..58b47d68eea63 100644 --- a/objectivec/ort_coreml_execution_provider.mm +++ b/objectivec/ort_coreml_execution_provider.mm @@ -26,7 +26,10 @@ - (BOOL)appendCoreMLExecutionProviderWithOptions:(ORTCoreMLExecutionProviderOpti const uint32_t flags = (options.useCPUOnly ? COREML_FLAG_USE_CPU_ONLY : 0) | (options.enableOnSubgraphs ? COREML_FLAG_ENABLE_ON_SUBGRAPH : 0) | - (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0); + (options.onlyEnableForDevicesWithANE ? COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE : 0) | + (options.onlyAllowStaticInputShapes ? COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES : 0) | + (options.createMLProgram ? COREML_FLAG_CREATE_MLPROGRAM : 0); + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML( [self CXXAPIOrtSessionOptions], flags)); return YES; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h index 72e6d3930a548..af0904b7d6e4b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h @@ -134,8 +134,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe TensorShape no_speech_probs_shape{parameters->batch_size}; Tensor* no_speech_probs = this->context_.Output(parameters->no_speech_probs_output_id, no_speech_probs_shape); if (no_speech_probs && no_speech_probs->MutableData()) { - ORT_ENFORCE(parameters->no_speech_token >= 0 && parameters->no_speech_token < parameters->vocab_size, - "no_speech_token id out of range, it is ", parameters->no_speech_token, + ORT_ENFORCE(parameters->no_speech_token_id >= 0 && parameters->no_speech_token_id < parameters->vocab_size, + "no_speech_token_id is out of range, it is ", parameters->no_speech_token_id, ", vocab_size is ", parameters->vocab_size); this->parameters_->no_speech_probs = (void*)no_speech_probs->MutableData(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index bb6885c3216bc..93837e785b4a4 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -153,7 +153,13 @@ void WhisperBeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeWhisper)); ORT_ENFORCE(model_type == IGenerationParameters::kModelTypeWhisper); - no_speech_token = static_cast(info.GetAttrOrDefault("no_speech_token", -1LL)); + // Token ids are defined below in the order that they appear in the tokenizer + translate_token_id = static_cast(info.GetAttrOrDefault("translate_token_id", -1LL)); + transcribe_token_id = static_cast(info.GetAttrOrDefault("transcribe_token_id", -1LL)); + start_of_lm_token_id = static_cast(info.GetAttrOrDefault("start_of_lm_token_id", -1LL)); + no_speech_token_id = static_cast(info.GetAttrOrDefault("no_speech_token_id", -1LL)); + no_timestamps_token_id = static_cast(info.GetAttrOrDefault("no_timestamps_token_id", -1LL)); + beginning_timestamp_token_id = static_cast(info.GetAttrOrDefault("beginning_timestamp_token_id", -1LL)); cross_qk_layer_head_input_id = 12; extra_decoding_ids_input_id = 13; cross_qk_output_id = 3; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index cb62e2f7bf4da..b1dd55eb20f34 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -183,7 +183,14 @@ struct IGenerationParameters { // Parameters for whisper model bool decoder_output_cross_qk = false; gsl::span extra_decoding_ids; - int32_t no_speech_token = -1; + + // Token ids are defined below in the order that they appear in the tokenizer + int32_t translate_token_id = -1; + int32_t transcribe_token_id = -1; + int32_t start_of_lm_token_id = -1; + int32_t no_speech_token_id = -1; + int32_t no_timestamps_token_id = -1; + int32_t beginning_timestamp_token_id = -1; void* no_speech_probs = nullptr; int cross_qk_layer_head_input_id = -1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 03d4e89ac20fe..231eb17d1a947 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -10,6 +10,7 @@ #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" #include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" +#include namespace onnxruntime { namespace contrib { @@ -34,6 +35,14 @@ struct NextTokenScores { } }; +#ifdef DEBUG_GENERATION +template +void DumpScores(const char* name, const NextTokenScores& next_token_scores) { + std::cout << name << std::endl; + ORT_UNUSED_PARAMETER(next_token_scores); +} +#endif + // Interface for all scorers for beam search or beam sample. template class ILogitsProcessor { @@ -150,19 +159,25 @@ class PresencePenaltyLogitsProcessor : public ILogitsProcessor { template class TimestampLogitsProcessor : public ILogitsProcessor { public: - TimestampLogitsProcessor(int eos_token_id, int max_initial_timestamp_index) - : eos_token_id_(eos_token_id), max_initial_timestamp_index_(max_initial_timestamp_index) {} + TimestampLogitsProcessor(int end_of_text_token_id, // <|endoftext|> + int start_of_transcript_token_id, // <|startoftranscript|> + int translate_token_id, // <|translate|> + int transcribe_token_id, // <|transcribe|> + int start_of_lm_token_id, // <|startoflm|> + int no_timestamps_token_id, // <|notimestamps|> + int beginning_timestamp_token_id, // <|0.00|> + int max_initial_timestamp_index) + : end_of_text_token_id_(end_of_text_token_id), + start_of_transcript_token_id_(start_of_transcript_token_id), + translate_token_id_(translate_token_id), + transcribe_token_id_(transcribe_token_id), + start_of_lm_token_id_(start_of_lm_token_id), + no_timestamps_token_id_(no_timestamps_token_id), + beginning_timestamp_token_id_(beginning_timestamp_token_id), + max_initial_timestamp_index_(max_initial_timestamp_index) {} void Process(const ISequences* sequences, NextTokenScores& next_token_scores) override { - // TODO: translate_token_id_ and transcribe_token_id_ need to support both multilingual and English-only models. - const int beg_token_id_ = eos_token_id_ + 107; - const int not_token_id_ = eos_token_id_ + 106; - const int solm_token_id_ = eos_token_id_ + 105; - const int sot_token_id_ = eos_token_id_ + 1; - constexpr int translate_token_id_ = 50358; - constexpr int transcribe_token_id_ = 50359; - const int batch_beam_size = next_token_scores.batch_beam_size; const int vocab_size = next_token_scores.vocab_size; for (int i = 0; i < batch_beam_size; i++) { @@ -174,7 +189,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { size_t sample_begin = 0; for (size_t j = 0; j < seq_length; j++) { sample_begin++; - if (sequence[j] >= beg_token_id_) { + if (sequence[j] >= beginning_timestamp_token_id_) { break; } } @@ -182,30 +197,30 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Suppress tokens for (int j = 0; j < vocab_size; j++) { // Suppress notimestamps and solm tokens - if (j == not_token_id_ || j == solm_token_id_) { + if (j == no_timestamps_token_id_ || j == start_of_lm_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } // Suppress sot, translate and transcribe tokens if (seq_length > sample_begin) { - if (j == sot_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { + if (j == start_of_transcript_token_id_ || j == translate_token_id_ || j == transcribe_token_id_) { beam_token_scores[j] = std::numeric_limits::lowest(); } } } // Timestamps should be in pair except the first one - const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beg_token_id_; - const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beg_token_id_; + const bool last_was_timestamp = seq_length > 0 && sequence.back() >= beginning_timestamp_token_id_; + const bool penultimate_was_timestamp = seq_length <= sample_begin || sequence[seq_length - 2] >= beginning_timestamp_token_id_; if (last_was_timestamp) { if (penultimate_was_timestamp) { // If timestamps show up in pair, or it's the first timestamp, no more timestamp is generated - for (int j = beg_token_id_; j < vocab_size; j++) { + for (int j = beginning_timestamp_token_id_; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } else { // If timestamp doesn't show up in pair, generate timestamp - for (int j = 0; j < eos_token_id_; j++) { + for (int j = 0; j < end_of_text_token_id_; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -214,7 +229,7 @@ class TimestampLogitsProcessor : public ILogitsProcessor { // Find timestamp tokens std::vector timestamps; for (const auto& word_id : sequence) { - if (word_id >= beg_token_id_) { + if (word_id >= beginning_timestamp_token_id_) { timestamps.push_back(word_id); } } @@ -231,13 +246,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { timestamp_last = timestamps.back() + 1; } - for (int j = beg_token_id_; j < timestamp_last; j++) { + for (int j = beginning_timestamp_token_id_; j < timestamp_last; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } } if (seq_length == sample_begin) { - const int last_allowed = beg_token_id_ + max_initial_timestamp_index_; + const int last_allowed = beginning_timestamp_token_id_ + max_initial_timestamp_index_; for (int j = last_allowed + 1; j < vocab_size; j++) { beam_token_scores[j] = std::numeric_limits::lowest(); } @@ -247,8 +262,8 @@ class TimestampLogitsProcessor : public ILogitsProcessor { float timestamp_logprob = std::numeric_limits::lowest(); { float logsumexp = 0.0f; - const float logprob_max = *std::max_element(beam_token_scores.begin() + beg_token_id_, beam_token_scores.end()); - for (int j = beg_token_id_; j < vocab_size; ++j) { + const float logprob_max = *std::max_element(beam_token_scores.begin() + beginning_timestamp_token_id_, beam_token_scores.end()); + for (int j = beginning_timestamp_token_id_; j < vocab_size; ++j) { if (beam_token_scores[j] > std::numeric_limits::lowest()) { logsumexp += expf(beam_token_scores[j] - logprob_max); } @@ -258,9 +273,9 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } } - const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beg_token_id_); + const float max_text_token_logprob = *std::max_element(beam_token_scores.begin(), beam_token_scores.begin() + beginning_timestamp_token_id_); if (timestamp_logprob > max_text_token_logprob) { - for (int j = 0; j < beg_token_id_; ++j) { + for (int j = 0; j < beginning_timestamp_token_id_; ++j) { beam_token_scores[j] = std::numeric_limits::lowest(); } } @@ -268,7 +283,13 @@ class TimestampLogitsProcessor : public ILogitsProcessor { } private: - int eos_token_id_; + int end_of_text_token_id_; + int start_of_transcript_token_id_; + int translate_token_id_; + int transcribe_token_id_; + int start_of_lm_token_id_; + int no_timestamps_token_id_; + int beginning_timestamp_token_id_; int max_initial_timestamp_index_; }; @@ -330,7 +351,15 @@ class LogitsProcessorList : public ILogitsProcessorList { // Add timestamp processor for whisper model if (parameters.model_type == IGenerationParameters::kModelTypeWhisper && parameters.logits_processor == IGenerationParameters::kLogitsProcessorTypeWhisper) { constexpr int max_initial_timestamp_index = 50; - timestamp_processor_ = std::make_unique>(parameters.eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + timestamp_processor_ = std::make_unique>(parameters.eos_token_id, + parameters.decoder_start_token_id, + parameters.translate_token_id, + parameters.transcribe_token_id, + parameters.start_of_lm_token_id, + parameters.no_timestamps_token_id, + parameters.beginning_timestamp_token_id, + max_initial_timestamp_index); processor_list_.push_back(timestamp_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 8f368251f12c7..be8c0dc86c135 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -120,6 +120,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float_float_MLFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16_float_float, SimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); @@ -318,6 +319,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index bba30805ae1be..7adc2fe0a67ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -424,7 +424,7 @@ Status ProcessLogits(const OrtValue& logits, // const bool is_whisper_model = (parameters->model_type == onnxruntime::contrib::transformers::IGenerationParameters::kModelTypeWhisper); if (step == 1 && is_whisper_model && parameters->no_speech_probs) { cuda::LaunchSaveNoSpeechProbs( - (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token, cuda_stream); + (T*)parameters->no_speech_probs, Y_data, batch_size, num_beams, vocab_size, parameters->no_speech_token_id, cuda_stream); } // NOTE: currently we treat extra decoding ids are same @@ -469,7 +469,15 @@ Status ProcessLogits(const OrtValue& logits, // cudaMemcpyDeviceToHost, cuda_stream)); constexpr int max_initial_timestamp_index = 50; - onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, max_initial_timestamp_index); + // Token ids are passed below in the order that they appear in the tokenizer + onnxruntime::contrib::transformers::TimestampLogitsProcessor time_logit_processor(parameters->eos_token_id, + parameters->decoder_start_token_id, + parameters->translate_token_id, + parameters->transcribe_token_id, + parameters->start_of_lm_token_id, + parameters->no_timestamps_token_id, + parameters->beginning_timestamp_token_id, + max_initial_timestamp_index); onnxruntime::contrib::transformers::NextTokenScores next_token_scores_timestamp({cpu_next_token_scores_span, batch_beam_size, vocab_size}); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index bd58dded026a6..25e7567a2e9fc 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,13 +8,14 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Attention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FastGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, FusedConv); template <> KernelCreateInfo BuildKernelCreateInfo() { @@ -25,14 +26,15 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + SkipLayerNormalization)>}; for (auto& function_table_entry : function_table) { KernelCreateInfo info = function_table_entry(); diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc new file mode 100644 index 0000000000000..888db0fd161f2 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.cc @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/js/quantization/matmul_nbits.h" +#include "core/providers/js/js_data_types.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + MatMulNBits, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", JsepSupportedFloatTypes()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulNBits); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h new file mode 100644 index 0000000000000..cca2c4757765b --- /dev/null +++ b/onnxruntime/contrib_ops/js/quantization/matmul_nbits.h @@ -0,0 +1,48 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class MatMulNBits final : public JsKernel { + public: + MatMulNBits(const OpKernelInfo& info) : JsKernel(info), + K_{narrow(info.GetAttr("K"))}, + N_{narrow(info.GetAttr("N"))}, + accuracy_level_{info.GetAttrOrDefault("accuracy_level", 0)}, + nbits_{narrow(info.GetAttr("bits"))}, + block_size_{narrow(info.GetAttr("block_size"))} { + ORT_ENFORCE(nbits_ == 4, + "Only 4b quantization is supported for MatMulNBits op, additional bits support is planned."); + ORT_ENFORCE(block_size_ >= 16 && !(block_size_ & (block_size_ - 1)), + "Block size must be a power of 2 and greater than or equal to 16."); + JSEP_INIT_KERNEL_ATTRIBUTE(MatMulNBits, ({ + "k" : $1, + "n" : $2, + "accuracyLevel" : $3, + "bits" : $4, + "blockSize" : $5 + }), + static_cast(K_), + static_cast(N_), + static_cast(accuracy_level_), + static_cast(nbits_), + static_cast(block_size_)); + } + + private: + const size_t K_; + const size_t N_; + const int64_t accuracy_level_; + const size_t nbits_; + const size_t block_size_; +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc deleted file mode 100644 index e82e15a304f4c..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ /dev/null @@ -1,152 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/rocm/rocm_common.h" -#include "contrib_ops/rocm/diffusion/group_norm.h" -#include "contrib_ops/rocm/diffusion/group_norm_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -#define GROUP_NORM_TYPES float, MLFloat16 - -ONNX_OPERATOR_KERNEL_EX( - GroupNorm, kMSDomain, 1, kRocmExecutionProvider, - (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); - -using namespace ONNX_NAMESPACE; - -namespace { -template -struct DispatchGroupNorm { - Status operator()(RocmTuningContext* tuning_ctx, - Stream* stream, - Tensor* output, - const Tensor* input, - const Tensor* gamma, - const Tensor* beta, - void* workspace, - float epsilon, - int batch_size, - int num_channels, - int height, - int width, - int num_groups, - bool use_swish_activation) { - typedef typename ToHipType::MappedType HipT; - return LaunchGroupNormKernel( - tuning_ctx, - stream, - reinterpret_cast(output->MutableData()), - reinterpret_cast(input->Data()), - gamma->Data(), - beta->Data(), - workspace, - epsilon, - batch_size, - num_channels, - height, - width, - num_groups, - use_swish_activation); - } -}; - -} // namespace - -GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { - epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); - ORT_ENFORCE(epsilon_ >= 0); - - int64_t num_groups; - ORT_ENFORCE(op_info.GetAttr("groups", &num_groups).IsOK()); - ORT_ENFORCE(num_groups >= 0); - num_groups_ = static_cast(num_groups); - - int64_t activation; - ORT_ENFORCE(op_info.GetAttr("activation", &activation).IsOK()); - ORT_ENFORCE(activation == 0 || activation == 1); // 0 is None, 1 is Swish - use_swish_activation_ = (activation == 1); - - channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); -} - -Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, - bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { - is_packed = false; - return Status::OK(); -} - -Status GroupNorm::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* gamma = context->Input(1); - const Tensor* beta = context->Input(2); - Tensor* output = context->Output(0, input->Shape()); - - if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "only the channels_last layout is supported"); - } - - const auto& input_dims = input->Shape().GetDims(); - if (input_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "input is expected to have 4 dimensions, got ", input_dims.size()); - } - - const auto& gamma_dims = gamma->Shape().GetDims(); - if (gamma_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "gamma is expected to have 1 dimension, got ", gamma_dims.size()); - } - if (gamma_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in gamma and input does not match"); - } - - const auto& beta_dims = beta->Shape().GetDims(); - if (beta_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "beta is expected to have 1 dimension, got ", beta_dims.size()); - } - if (beta_dims[0] != input_dims[3]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels in beta and input does not match"); - } - - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisible by num_groups"); - } - - if (context->GetUseDeterministicCompute()) { - static std::once_flag log_warning; - std::call_once(log_warning, []() { - LOGS_DEFAULT(WARNING) << "GroupNorm has no deterministic GPU kernel, its outputs may still be nondeterministic."; - }); - } - - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); - - utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(GetTuningContext(), context->GetComputeStream(), - output, input, gamma, beta, workspace.get(), - epsilon_, - batch_size, - num_channels, - height, - width, - num_groups_, - use_swish_activation_); -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh index fb7091592c16e..d0a0d09fcbae3 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh @@ -26,13 +26,18 @@ namespace rocm { using onnxruntime::rocm::CKDataTypeAdaptor; -using Swish = ck::tensor_operation::element_wise::Swish; +// The SiLU function is a special case of Swish function, +// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as: +// SiLU(x) = x * sigmoid(x) +// Swish(x) = x * sigmoid(bx) +// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish function. We treat them as the same function here. +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; constexpr int Rank = 5; constexpr int NumReduceDim = 3; -template +template auto GetCKGroupNormNHWCTypeStringAndOps() { using XDataType = typename CKDataTypeAdaptor::type; using YDataType = typename CKDataTypeAdaptor::type; @@ -40,26 +45,30 @@ auto GetCKGroupNormNHWCTypeStringAndOps() { using GammaDataType = float; using BetaDataType = float; - using Activation = std::conditional_t; + using Activation = std::conditional_t; - std::vector>>> ret; + std::vector>>> ret; for (auto&& impl : internal::GetDeviceGroupNormInstances()) { - std::string swish_suffix = WithSwish ? "_Swish" : "_Pass"; - auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix; + std::string silu_suffix = WithSilu ? "_Silu" : "_Pass"; + auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix; auto invoker = impl->MakeInvokerPointer(); - auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams* params) -> Status { - if constexpr (WithSwish) { + auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by composable kernel."); + if constexpr (WithSilu) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !params->withSwish, "Swish version only support groupnorm with swish"); + !params->use_silu, "Silu version only support groupnorm with silu"); } else { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->withSwish, "Pass version only support groupnorm without swish"); + params->use_silu, "Pass version only support groupnorm without silu"); } - std::vector in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup}; - std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1}; - std::vector gamma_beta_strides{0, 0, 0, params->cPerGroup, 1}; + std::vector in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group}; + std::vector in_out_strides{params->h * params->w * params->c, params->w * params->c, + params->c, params->channels_per_group, 1}; + std::vector gamma_beta_strides{0, 0, 0, params->channels_per_group, 1}; std::vector reduce_dims{1, 2, 4}; auto activation = Activation{}; diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh index 19b081881dcec..4cb371fdcf960 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl.cuh @@ -18,7 +18,7 @@ namespace internal { using F16 = ck::half_t; using F32 = float; -using Swish = ck::tensor_operation::element_wise::Swish; +using Silu = ck::tensor_operation::element_wise::Swish; using Pass = ck::tensor_operation::element_wise::PassThrough; using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() { template <> std::vector>> + F16, F32, F32, F16, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F16, F32, F32, F16, F32, Swish, 5, 3>(); + F16, F32, F32, F16, F32, Silu, 5, 3>(); template <> std::vector std::vector>> + F32, F32, F32, F32, F32, Silu, 5, 3>>> GetDeviceGroupNormInstances< - F32, F32, F32, F32, F32, Swish, 5, 3>(); + F32, F32, F32, F32, F32, Silu, 5, 3>(); template <> std::vector -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f16_instances{}); + device_normalization_f16_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu index 9b0ccab17b4c1..ceb53ed442abc 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck_impl/impl_fp32.cu @@ -11,12 +11,12 @@ namespace rocm { namespace internal { template <> -std::vector>> -GetDeviceGroupNormInstances() { - std::vector>> instances; +std::vector>> +GetDeviceGroupNormInstances() { + std::vector>> instances; ck::tensor_operation::device::instance::add_device_operation_instances( instances, - device_normalization_f32_instances{}); + device_normalization_f32_instances{}); return instances; } diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h index 008ae20b0561f..7cff640db2f34 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h @@ -8,110 +8,47 @@ #include "core/providers/rocm/cu_inc/common.cuh" #include "core/providers/rocm/rocm_common.h" #include "core/providers/rocm/tunable/rocm_tunable.h" +#include "contrib_ops/rocm/diffusion/group_norm_common_base.h" namespace onnxruntime { namespace contrib { namespace rocm { -using onnxruntime::rocm::CeilDiv; - -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; - for (int32_t i = 1; i <= std::sqrt(n); i++) { - if (n % i == 0) { - int32_t divisor1 = n / i; - int32_t divisor2 = i; - - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; - } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; - } - } - } - return maxDivisor; -} - template -struct GroupNormNHWCParams : OpParams { - GroupNormNHWCParams(RocmTuningContext* tuning_ctx, onnxruntime::Stream* stream, T* dst, float* redBuffer, const T* src, const float* gamma, - const float* beta, int32_t n, int32_t h, int32_t w, int32_t c, int32_t groups, float epsilon, bool withSwish) - : OpParams(tuning_ctx, stream), dst(dst), src(src), gamma(gamma), beta(beta), redBuffer(redBuffer), epsilon(epsilon), n(n), h(h), w(w), c(c), groups(groups), withSwish(withSwish) { - int32_t maxBlocksPerHW = 1024; - switch (c) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; - } - - hw = h * w; - const int32_t blocksPerHW = findMaxDivisor(hw, maxBlocksPerHW); - hwPerBlock = CeilDiv(hw, blocksPerHW); - cPerGroup = c / groups; - hwc = hw * c; - invHWC = 1.F / (float)(hw * cPerGroup); - groupsPerBlock = cPerBlock / cPerGroup; - } +struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams { + GroupNormNHWCTunableParams(RocmTuningContext* tuning_ctx, + onnxruntime::Stream* ort_stream, + T* output, + T* add_out, + const T* input, + const T* skip, + const T* bias, + const float* gamma, + const float* beta, + float* workspace, + float epsilon, + int batch_size, + int num_channels, + int height, + int width, + int num_groups, + bool use_silu, + bool broadcast_skip, + int channels_per_block) + : OpParams(tuning_ctx, ort_stream), + GroupNormNHWCParams(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, + num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {} std::string Signature() const override { - std::string swish_suffix = withSwish ? "_Swish" : "_Pass"; - std::string sig = std::to_string(n) + "_" + std::to_string(h * w) + "_" + std::to_string(c) + "_" + std::to_string(groups) + swish_suffix; + std::string silu_suffix = this->use_silu ? "_silu" : "_pass"; + std::string skip_suffix = this->skip != nullptr ? "_skip" : "_noskip"; + std::string broadcast_suffix = this->broadcast_skip ? "_broadcast" : "_nobroadcast"; + std::string bias_suffix = this->bias != nullptr ? "_bias" : "_nobias"; + std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + + std::to_string(this->c) + "_" + std::to_string(this->groups) + silu_suffix + + skip_suffix + broadcast_suffix + bias_suffix; return sig; } - - // The output buffer. Layout NHWC. - T* dst; - // The input buffer. Layout NHWC. - T const* src; - // The gamma scaling factor. - float const* gamma; - // The beta term to add in GN. - float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; - float epsilon; - - // The number of instances in the batch. - int32_t n; - // The height and width of each activation map. - int32_t h; - int32_t w; - // The number of channels. - int32_t c; - // The number of groups. - int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; - - // Precomputed values and parameters to control the execution of the kernels. - - // The number of activations per instance (h * w) and the number of - // activations per block. - int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; - - // The precomputed stride between instances. - int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; - // The precomputed number of groups per block. - int32_t groupsPerBlock; }; } // namespace rocm diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu index dbd5009e63676..142aaf14e8d2d 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu @@ -15,9 +15,12 @@ namespace rocm { template Status LaunchGroupNormKernel( RocmTuningContext* tuning_ctx, - Stream* stream, + Stream* ort_stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -27,19 +30,26 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCTunableParams params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta, + reinterpret_cast(workspace), epsilon, batch_size, num_channels, + height, width, num_groups, use_silu, broadcast_skip, channels_per_block); - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + if (params.channels_per_block % params.channels_per_group != 0 || + params.channels_per_block > kMaxSize || + (params.channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in ROCM does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - GroupNormNHWCParams params(tuning_ctx, stream, output, reinterpret_cast(workspace), input, gamma, beta, - batch_size, height, width, num_channels, num_groups, epsilon, use_swish_activation); + HIP_RETURN_IF_ERROR(hipMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), params.StreamHandle())); if (tuning_ctx->IsTunableOpEnabled()) { static GroupNormNHWCTunableOp op; @@ -50,14 +60,17 @@ Status LaunchGroupNormKernel( } template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + half* add_out, const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); template Status LaunchGroupNormKernel(RocmTuningContext* tuning_ctx, Stream* stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, - float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + float* add_out, const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block); + } // namespace rocm } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h deleted file mode 100644 index a0f7e0aca5def..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.h +++ /dev/null @@ -1,47 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -#include "core/common/common.h" -#include "core/common/status.h" -#include "core/providers/rocm/tunable/rocm_tunable.h" - -using onnxruntime::rocm::tunable::RocmTuningContext; - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { - // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; -} - -template -Status LaunchGroupNormKernel( - RocmTuningContext* tuning_ctx, - Stream* stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization -); - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh deleted file mode 100644 index d6322a12a9363..0000000000000 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl_kernel.cuh +++ /dev/null @@ -1,213 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// The ROCm kernel is modified from TensorRT 8.5. -/* - * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include "core/providers/rocm/cu_inc/common.cuh" -#include "core/providers/rocm/rocm_common.h" - -namespace onnxruntime { -namespace contrib { -namespace rocm { - -static inline __device__ __host__ float sigmoid(float x) { - return 1.F / (1.F + expf(-x)); -} - -struct GroupSums { - // Is it the 1st element of the group? - int32_t flag; - // The sum. - float sum; - // The sum of squares. - float sumSq; -}; - -struct GroupSumsOp { - inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { - GroupSums dst; - dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); - dst.flag = a.flag + b.flag; - return dst; - } -}; - -template -inline __device__ void UpdateSum(const T* src, int64_t offset, U& sum, U& sumSq) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - -#pragma unroll - for (int i = 0; i < ILP; i++) { - const U val = static_cast(input_v.val[i]); - sum += val; - sumSq += val * val; - } -} - -template -__global__ void groupNormNHWCSumKernel(const T* src, float* redBuffer, int32_t cPerBlock, int32_t hwPerBlock, int32_t hw, - int32_t hwc, int32_t c, int32_t cPerGroup, int32_t groups, int32_t groupsPerBlock) { - // The object in charge of doing the sums for the different blocks. - typedef hipcub::BlockScan BlockScan; - - // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[ThreadsPerBlock]; - - // The instance in the batch. - int32_t ni = blockIdx.z; - // The channel loaded by that thread (ILP channels per thread). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // The sums. - float sum = 0.F; - float sumSq = 0.F; - - // Iterate over the activations to compute the sums. - if (ci < c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * hwc + static_cast(hwi) * c + ci; - UpdateSum(src, offset, sum, sumSq); - } - } - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * ILP / cPerGroup; - int32_t cj = threadIdx.x * ILP - cPerGroup * gi; - - // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; - - // Do the segmented scan. - GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); - - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == cPerGroup - ILP) { // ILP channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The global group index. - int32_t gj = blockIdx.x * groupsPerBlock + threadIdx.x; - - // Threads that have nothing left to do, exit. - if (threadIdx.x >= groupsPerBlock || gj >= groups) { - return; - } - - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(&redBuffer[(2 * ni + 0) * groups + gj], sums.x); - atomicAdd(&redBuffer[(2 * ni + 1) * groups + gj], sums.y); -} - -template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, U mean, U invStdDev, - const U* gamma_v, const U* beta_v, bool swish) { - using VecT = onnxruntime::rocm::aligned_vector; - const VecT input_v = *reinterpret_cast(src + offset); - VecT output_v; - -#pragma unroll - for (int i = 0; i < ILP; i++) { - U val = static_cast(input_v.val[i]); - val = (val - mean) * invStdDev; - val = gamma_v[i] * val + beta_v[i]; - - if (swish) { - val = val * sigmoid(val); - } - output_v.val[i] = static_cast(val); - } - *(reinterpret_cast(dst + offset)) = output_v; -} - -template -__global__ void groupNormNHWCScaleKernel(T* dst, const T* src, const float* gamma, const float* beta, const float* redBuffer, float epsilon, int32_t c, int32_t cPerBlock, - int32_t cPerGroup, int32_t groups, int32_t hwc, float invHWC, int32_t hw, int32_t hwPerBlock, bool withSwish) { - // The channel loaded by that thread (ILP channels per thread for F16x2). - int32_t ci = blockIdx.x * cPerBlock + threadIdx.x * ILP; - if (ci >= c) { - return; - } - - // The instance in the batch. - int32_t ni = blockIdx.z; - - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / cPerGroup; - - // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; - if (gi < groups) { - sum = redBuffer[(2 * ni + 0) * groups + gi]; - sumSq = redBuffer[(2 * ni + 1) * groups + gi]; - } - - using VecF = onnxruntime::rocm::aligned_vector; - - const VecF gamma_v = *reinterpret_cast(gamma + ci); - const VecF beta_v = *reinterpret_cast(beta + ci); - - // Compute the mean. - float mean = sum * invHWC; - // Compute the variance. - float var = sumSq * invHWC - (mean * mean); - // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + epsilon); - - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + hwPerBlock, hw); - - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * hwc + hwi * c + ci; - - // Fetch ILP channels per thread. - computeGroupNorm(src, dst, offset, mean, invStdDev, gamma_v.val, beta_v.val, withSwish); - } -} - -} // namespace rocm -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh index b7b9441ac997d..b3d3e92209b39 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh @@ -20,21 +20,21 @@ namespace rocm { namespace { -template +template std::string GetGroupNormTritonGroupName() { std::string ret = "GroupNormTriton_"; - std::string swish_suffix = WithSwish ? "Swish_" : "Pass_"; - ret += swish_suffix; + std::string silu_suffix = WithSilu ? "Silu_" : "Pass_"; + ret += silu_suffix; ret += GetDataTypeName(); return ret; } } // namespace -template +template auto GetTritonGroupNormNHWCTypeStringAndOps() { - std::vector>>> ret; - auto group_name = GetGroupNormTritonGroupName(); + std::vector>>> ret; + auto group_name = GetGroupNormTritonGroupName(); auto* kernel_list = GetOrtTritonKernelByGroup(group_name); if (kernel_list == nullptr) { return ret; @@ -45,16 +45,19 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { auto* metadata = GetOrtTritonKernelMetadata(i); auto block_size = metadata->constants.at("BLOCK_SIZE"); auto hw_size = metadata->constants.at("HW_SIZE"); - auto impl = [i, block_size, hw_size](const GroupNormNHWCParams* params) -> Status { + auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), + "Input skip or bias is not supported by triton kernel."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->cPerGroup > block_size || params->cPerGroup * 2 <= block_size, - "Arg block_size (", block_size, ") is not the next power of 2 of cPerGroup (", params->cPerGroup, ")."); + params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size, + "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", + params->channels_per_group, ")."); TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ")."); - if constexpr (WithSwish) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->withSwish, "Swish version does not support GN w/o swish."); + if constexpr (WithSilu) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!params->use_silu, "Silu version does not support GN w/o silu."); } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->withSwish, "Pass version does not support GN w/ swish."); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(params->use_silu, "Pass version does not support GN w/ silu."); } // Construct args for launch kernel struct { @@ -73,7 +76,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() { (const void*)params->beta, params->hw, params->c, - params->cPerGroup, + params->channels_per_group, params->epsilon}; // Grid dim is (batch_count, groups, 1) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py index 56b3a030b289e..5368cb1cf635b 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.py @@ -21,7 +21,7 @@ def group_norm_kernel( eps, BLOCK_SIZE: tl.constexpr, HW_SIZE: tl.constexpr, - ACTIVATION_SWISH: tl.constexpr, + ACTIVATION_SILU: tl.constexpr, ): row_x = tl.program_id(0) row_y = tl.program_id(1) @@ -62,7 +62,7 @@ def group_norm_kernel( x = tl.load(x_ptr + offsets, mask=mask, other=0.0).to(tl.float32) x_hat = (x - group_mean) * rstd y = x_hat * gamma + beta - if ACTIVATION_SWISH: + if ACTIVATION_SILU: y *= tl.sigmoid(y) tl.store(y_ptr + offsets, y, mask=mask) @@ -71,7 +71,7 @@ def group_norm_kernel( # blocks = [16, 32, 64, 128, 256, 512] # hw_sizes = [8, 16, 32, 64, 128, 256, 512] # but this will result in too many functions and slow down the compilation. -with_swish = [True, False] +with_silu = [True, False] dtypes = ["fp32", "fp16"] blocks = [16, 32, 64, 128] hw_sizes = [8, 16, 32, 64, 128, 256] @@ -84,14 +84,14 @@ def group_norm_kernel( def get_function_table(): func_table = [] - for swish, dtype, hw_size, warp, b in product(with_swish, dtypes, hw_sizes, warps, blocks): - swish_suffix = "Swish" if swish else "Pass" - name = name_pattern.format(swish_suffix, dtype, b, hw_size, warp) - group = group_pattern.format(swish_suffix, dtype) + for silu, dtype, hw_size, warp, b in product(with_silu, dtypes, hw_sizes, warps, blocks): + silu_suffix = "Silu" if silu else "Pass" + name = name_pattern.format(silu_suffix, dtype, b, hw_size, warp) + group = group_pattern.format(silu_suffix, dtype) sig = sig_pattern.format(dtype, dtype) kwargs = { "num_warps": warp, - "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SWISH": int(swish)}, + "constants": {"BLOCK_SIZE": b, "HW_SIZE": hw_size, "ACTIVATION_SILU": int(silu)}, } func_desc = {"name": name, "group": group, "func": group_norm_kernel, "sig": sig, "kwargs": kwargs} func_table.append(func_desc) diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h index 25d820f7ed326..e6831f764b418 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h @@ -20,115 +20,117 @@ namespace rocm { using onnxruntime::rocm::GPU_WARP_SIZE; template -void groupNormNHWCSum(const GroupNormNHWCParams* params) { - // Make sure the values are as we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCSum(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ - groupNormNHWCSumKernel \ - <<StreamHandle()>>>( \ - params->src, params->redBuffer, params->cPerBlock, \ - params->hwPerBlock, params->hw, params->hwc, params->c, \ - params->cPerGroup, params->groups, params->groupsPerBlock); \ +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<StreamHandle()>>>( \ + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, \ + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, \ + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SUM(256, 2) - case 480: - LAUNCH_GROUPNORM_SUM(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SUM(128, 2) + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SUM(64, 2) + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCSumOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCSumOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCSumKernel + GroupNormNHWCSumKernel <<StreamHandle()>>>( - params->src, params->redBuffer, params->cPerBlock, params->hwPerBlock, - params->hw, params->hwc, params->c, params->cPerGroup, params->groups, params->groupsPerBlock); + params->skip_workspace, params->group_sum_buffer, params->src, params->skip, params->bias, + params->channels_per_block, params->hw_per_block, params->hw, params->hwc, params->c, + params->channels_per_group, params->groups, params->groups_per_block, params->broadcast_skip); return HIP_CALL(hipGetLastError()); } template -void groupNormNHWCScale(const GroupNormNHWCParams* params) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params->c % params->cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params->cPerBlock % params->cPerGroup == 0); - +void GroupNormNHWCScale(const GroupNormNHWCTunableParams* params) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params->c / params->cPerBlock; + grid.x = DivUp(params->c, params->channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.y = DivUp(params->hw, params->hw_per_block); // The number of instances. grid.z = params->n; -#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ - groupNormNHWCScaleKernel \ - <<StreamHandle()>>>( \ - params->dst, params->src, params->gamma, params->beta, \ - params->redBuffer, params->epsilon, params->c, params->cPerBlock, \ - params->cPerGroup, params->groups, params->hwc, params->invHWC, \ - params->hw, params->hwPerBlock, params->withSwish); \ +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<StreamHandle()>>>( \ + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, \ + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, \ + params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group, \ + params->hw, params->hw_per_block, params->use_silu); \ break; - switch (params->cPerBlock) { - case 320: - LAUNCH_GROUPNORM_SCALE(256, 2) - case 480: - LAUNCH_GROUPNORM_SCALE(256, 2) + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params->threads_per_block) { case 256: - LAUNCH_GROUPNORM_SCALE(128, 2) + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) + case 192: + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) + case 160: + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - LAUNCH_GROUPNORM_SCALE(64, 2) + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) + case 64: + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) default: ORT_NOT_IMPLEMENTED("Not implemented"); } } template -Status GroupNormNHWCScaleOp(const GroupNormNHWCParams* params) { +Status GroupNormNHWCScaleOp(const GroupNormNHWCTunableParams* params) { dim3 grid; - grid.x = params->c / params->cPerBlock; - grid.y = CeilDiv(params->hw, params->hwPerBlock); + grid.x = DivUp(params->c, params->channels_per_block); + grid.y = DivUp(params->hw, params->hw_per_block); grid.z = params->n; - groupNormNHWCScaleKernel + GroupNormNHWCScaleKernel <<StreamHandle()>>>( - params->dst, params->src, params->gamma, params->beta, params->redBuffer, params->epsilon, params->c, params->cPerBlock, - params->cPerGroup, params->groups, params->hwc, params->invHWC, params->hw, params->hwPerBlock, params->withSwish); + params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace, + params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, + params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block, + params->use_silu); return HIP_CALL(hipGetLastError()); } template class GroupNormNHWCOp { public: - Status operator()(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); + Status operator()(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); auto status = GroupNormNHWCSumOp(params); ORT_RETURN_IF_ERROR(status); HIP_RETURN_IF_ERROR(hipGetLastError()); @@ -138,29 +140,30 @@ class GroupNormNHWCOp { return Status::OK(); } - Status IsSupported(const GroupNormNHWCParams* params) { + Status IsSupported(const GroupNormNHWCTunableParams* params) { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !(params->c % VecSize == 0 && params->cPerGroup % VecSize == 0), - "The number of channels (", params->c, ") or the number of channels per group (", params->cPerGroup, + !(params->c % VecSize == 0 && params->channels_per_group % VecSize == 0), + "The number of channels (", params->c, ") or the number of channels per group (", params->channels_per_group, ") isn't divisible by the number of vector size: ", VecSize); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock % params->cPerGroup == 0 && - params->c % params->cPerBlock == 0 && params->hw % params->hwPerBlock == 0), - "The value of attributes don't meet the requirements."); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->cPerBlock <= ThreadsPerBlock * VecSize && - params->cPerBlock > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!(params->channels_per_block <= ThreadsPerBlock * VecSize && + params->channels_per_block > (ThreadsPerBlock - GPU_WARP_SIZE) * VecSize), "Configuration: Threads (", ThreadsPerBlock, "), vector size (", - VecSize, ") is redundant for the number of channels per group: ", params->cPerBlock); + VecSize, ") is redundant for the number of channels per group: ", + params->channels_per_block); return Status::OK(); } }; template -Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { - HIP_RETURN_IF_ERROR(hipMemsetAsync(params->redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), params->StreamHandle())); - groupNormNHWCSum(params); +Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams* params) { + HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer, + 0, + GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), + params->StreamHandle())); + GroupNormNHWCSum(params); HIP_RETURN_IF_ERROR(hipGetLastError()); - groupNormNHWCScale(params); + GroupNormNHWCScale(params); HIP_RETURN_IF_ERROR(hipGetLastError()); return Status::OK(); } @@ -178,30 +181,30 @@ Status GroupNormNHWCStaticSelection(const GroupNormNHWCParams* params) { ADD_OP_FOR_ALL_VEC_SIZE(name, 320) template -class GroupNormNHWCTunableOp : public TunableOp> { +class GroupNormNHWCTunableOp : public TunableOp> { public: GroupNormNHWCTunableOp() { this->RegisterOp(GroupNormNHWCStaticSelection); ADD_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWCOp) #ifdef USE_COMPOSABLE_KERNEL - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetCKGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } - for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { + for (auto&& [_, op] : GetTritonGroupNormNHWCTypeStringAndOps()) { ORT_UNUSED_PARAMETER(_); this->RegisterOp(std::move(op)); } diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 55cd6a1d112f5..382a3951f3a83 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -93,6 +93,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -246,6 +247,7 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/execution_frame.cc b/onnxruntime/core/framework/execution_frame.cc index 8c08152986cf6..32a5f749af084 100644 --- a/onnxruntime/core/framework/execution_frame.cc +++ b/onnxruntime/core/framework/execution_frame.cc @@ -204,6 +204,14 @@ AllocatorPtr IExecutionFrame::GetAllocator(const OrtDevice& info) const { Status IExecutionFrame::ReleaseMLValue(int ort_value_idx) { return ReleaseMLValueImpl(ort_value_idx); } +#ifdef ENABLE_TRAINING +void IExecutionFrame::ReleaseAllMLValues() { + for (size_t ort_value_idx = 0; ort_value_idx < all_values_.size(); ort_value_idx++) { + all_values_[ort_value_idx] = OrtValue(); + } +} +#endif + Status IExecutionFrame::ReleaseMLValueImpl(int ort_value_idx) { if (ort_value_idx == NodeIndexInfo::kInvalidEntry || static_cast(ort_value_idx) >= all_values_size_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid index ", ort_value_idx); @@ -831,7 +839,20 @@ AllocatorPtr ExecutionFrame::GetAllocatorImpl(const OrtDevice& info) const { // This method is not thread safe! // Return S_OK and nullptr if index map to a value that is an unused optional input/output Status ExecutionFrame::CreateNodeOutputMLValueImpl(OrtValue& ort_value, int ort_value_idx, const TensorShape* shape) { +#ifdef ENABLE_TRAINING + try { + auto status = AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); + return status; + } catch (const std::exception& e) { + LOGS(session_state_.Logger(), WARNING) + << "Exception caught when allocating memory for ort_value with index: " << ort_value_idx + << "so clean up all OrtValues"; + ReleaseAllMLValues(); + return Status(ONNXRUNTIME, FAIL, e.what()); + } +#else return AllocateAsPerAllocationPlan(ort_value, ort_value_idx, shape); +#endif } void ExecutionFrame::VerifyOutputSizes(int output_index, const Node& node, const TensorShape& output_shape) { diff --git a/onnxruntime/core/framework/execution_frame.h b/onnxruntime/core/framework/execution_frame.h index 1576c16684faa..18d210ffd48f7 100644 --- a/onnxruntime/core/framework/execution_frame.h +++ b/onnxruntime/core/framework/execution_frame.h @@ -67,6 +67,8 @@ class IExecutionFrame { const std::unordered_map& initializers); Status GetOutputs(gsl::span fetch_mlvalue_idxs, std::vector& fetches); + // if OOM happens, then release all values, so session can run next batch. + void ReleaseAllMLValues(); #endif // TO DO: make it thread safe diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 27c968a59eb91..e33ce20737f80 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1163,7 +1163,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) @@ -1188,7 +1188,15 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .SetDoc("Beam Search for whisper model, especiall with cross_qk features etc.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) - .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts (i.e. the start of transcription token id)", AttributeProto::INT, static_cast(-1)) + .Attr("translate_token_id", "The id of the translate task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("transcribe_token_id", "The id of the transcribe task", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("start_of_lm_token_id", "The id of the token that indicates LM starts", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_speech_token_id", + "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", + AttributeProto::INT, OPTIONAL_VALUE) + .Attr("no_timestamps_token_id", "The id of the token that indicates no timestamps", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("beginning_timestamp_token_id", "The id of the first timestamp", AttributeProto::INT, OPTIONAL_VALUE) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "Must be 2 for whisper", AttributeProto::INT, static_cast(2)) @@ -1203,27 +1211,24 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "If not provided, it will be inferred from the decoder subgraph's output shape", AttributeProto::INT, static_cast(-1)) .Attr("decoder_output_cross_qk", "If nozero, decoder subgraph contains output Q*K from cross attentions. Default 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("no_speech_token", - "The token in whisper model that marks all sequence empty. With this model, whisper could output no_speech_prob after. Default -1.", - AttributeProto::INT, OPTIONAL_VALUE) .Input(0, "input_ids", "The sequence used as a prompt for the generation in the encoder subgraph. Shape is (batch_size, sequence_length)", "F") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I") .Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I") .Input(5, "length_penalty", - "Exponential penalty to the length. Default value 1.0 means no penalty." - "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences." + "Exponential penalty to the length. Default value 1.0 means no penalty. " + "Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences. " "Shape is (1,)", "T", OpSchema::Optional) .Input(6, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional) + .Input(7, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "M", OpSchema::Optional) .Input(8, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "M", OpSchema::Optional) .Input(9, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(10, "decoder_input_ids", "The forced input id sequence for the decoder subgraph. Shape is (batch_size, initial_sequence_length)", "I", OpSchema::Optional) .Input(11, "logits_processor", "Specific logits processor for different types of beamsearch models. Default value 0 means no specific logit processor. Accepts value >= 0. Shape is (1)", "I", OpSchema::Optional) .Input(12, "cross_qk_layer_head", - "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all" + "Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect all " "its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]", "I", OpSchema::Optional) .Input(13, "extra_decoding_ids", @@ -1235,20 +1240,19 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", - "Processed beam scores for each vocabulary token at each generation step." - "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam." + "Processed beam scores for each vocabulary token at each generation step. " + "Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam. " "Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)", "T", OpSchema::Optional) .Output(3, "cross_qk", "Output the accumulated stacked Q*K in cross attentions. Let H = number of Head of cross attention, " - "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers," - "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]." + "F = the frames or kv-seq-len of the cross attention input, T = real decoded token length, L = number of layers, " + "B = batch size, R = num_return_sequences. It then should return tensor of shape [B, R, L*H, T, F]. " "If cross_qk_layer_head is given, shape is [B, R, cross_qk_layer_head.shape[0], T, F]", "V", OpSchema::Optional) .Output(4, "non_speech_probs", - "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token." - "Currently we treat the last token's logits is what we need, in future extra graph logic may be add to the encoder/context-decoder subgraph." - "The prob is save before logits may be updated by extra-decoding-ids. The shape of non_speech_probs is [B]", + "For whisper model, output the probabilities from logits after encoder and context decoding for the no_speech_token_id. " + "The shape of non_speech_probs is [B]", "T", OpSchema::Optional) .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain to float tensors.") .TypeConstraint("F", {"tensor(float)", "tensor(int32)", "tensor(float16)"}, "Constrain input type to float or int tensors.") @@ -1322,7 +1326,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") @@ -1363,7 +1367,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) - .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vocab_size)", "I", OpSchema::Optional) .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 902839bee04ba..305122c56b865 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,16 +1818,36 @@ void Graph::ReverseDFSFrom(gsl::span from, } } +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + #if !defined(ORT_MINIMAL_BUILD) void Graph::KahnsTopologicalSort(const std::function& enter, const std::function& comp) const { - std::unordered_map in_degree; - std::priority_queue, decltype(comp)> to_visit(comp); - std::vector topo_order; + InlinedVector in_degree(MaxNodeIndex(), 0); + InlinedVector topo_order; + VisitorPriorityQueue to_visit(comp); + + auto number_of_nodes = NumberOfNodes(); + topo_order.reserve(number_of_nodes); for (auto& node : Nodes()) { size_t input_edge_count = node.GetInputEdgesCount(); - in_degree.insert({node.Index(), input_edge_count}); + in_degree[node.Index()] = input_edge_count; if (input_edge_count == 0) { to_visit.push(&node); } @@ -1844,16 +1864,17 @@ void Graph::KahnsTopologicalSort(const std::function& enter, } for (auto node_it = current->OutputNodesBegin(); node_it != current->OutputNodesEnd(); ++node_it) { - in_degree[node_it->Index()]--; + auto& node_in_degree = in_degree[node_it->Index()]; + node_in_degree--; - if (in_degree[node_it->Index()] == 0) { + if (node_in_degree == 0) { to_visit.push(&*node_it); } } topo_order.push_back(current->Index()); } - if (NumberOfNodes() != static_cast(topo_order.size())) { + if (number_of_nodes != static_cast(topo_order.size())) { ORT_THROW("Some nodes are not included in the topological sort, graph have a cycle."); } } @@ -2843,7 +2864,7 @@ void Graph::AddInitializedTensor(const TensorProto& tensor) { const gsl::not_null tensor_added{graph_proto_->add_initializer()}; *(tensor_added) = tensor; - name_to_initial_tensor_[tensor.name()] = tensor_added; + name_to_initial_tensor_.emplace(tensor.name(), tensor_added); SetGraphResolveNeeded(); if (!is_loaded_from_model_file_ && GetNodeArg(tensor.name()) == nullptr) { // make sure there is a NodeArg for the initializer as SetGraphInputsOutputs may add it to the graph inputs. diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index acf7b3a16541f..119d420066a84 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -14,8 +14,8 @@ bool NodeCompare::operator()(const Node* n1, const Node* n2) const { struct PriorityNodeCompare { inline bool IsHighPri(const Node* n) const { // local statics so we can compare std::strings in the checks - static const std::string shape_op("Shape"); - static const std::string size_op("Size"); + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); const auto& op_type = n->OpType(); return op_type == shape_op || op_type == size_op; @@ -26,15 +26,20 @@ struct PriorityNodeCompare { // If return true, n2 will be output first bool operator()(const Node* n1, const Node* n2) const { // nodes in global high priority list will be output first - if (IsHighPri(n1) != IsHighPri(n2)) { - return IsHighPri(n2); + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; } // nodes with lower priority value will be output first - if (n1->Priority() != n2->Priority()) { - return n1->Priority() > n2->Priority(); + const auto n1_priority = n1->Priority(); + const auto n2_priority = n2->Priority(); + if (n1_priority != n2_priority) { + return n1_priority > n2_priority; } +#ifdef ENABLE_TRAINING // nodes of forward pass will be output first auto n1_attrs = n1->GetAttributes(); auto n2_attrs = n2->GetAttributes(); @@ -45,6 +50,7 @@ struct PriorityNodeCompare { if (n1_is_forward != n2_is_forward) { return n2_is_forward > n1_is_forward; } +#endif // otherwise, nodes with lower index will be output first return n1->Index() > n2->Index(); diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.cc b/onnxruntime/core/optimizer/gather_slice_fusion.cc new file mode 100644 index 0000000000000..21266d356a020 --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.cc @@ -0,0 +1,344 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/gather_slice_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +bool GatherSliceToSplitFusion::IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, + int64_t& axis, int64_t& indices_n_dims) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Gather", {1, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + const NodeArg& input_arg = *(node.InputDefs()[1]); + + if (!optimizer_utils::IsScalar(input_arg)) return false; + + const ONNX_NAMESPACE::TensorProto* indices_init = graph_utils::GetConstantInitializer(graph, input_arg.Name()); + + if (!indices_init) return false; + + if (indices_init->data_type() != ONNX_NAMESPACE::TensorProto::INT64) return false; + + // get the index value + Initializer init_const(*indices_init, graph.ModelPath()); + index = *(init_const.data()); + + // get attributes value + axis = 0; + auto& attrs = node.GetAttributes(); + if (attrs.find("axis") != attrs.end()) { + auto& axis_attr = attrs.at("axis"); + if (utils::HasInt(axis_attr)) axis = axis_attr.i(); + } + + indices_n_dims = indices_init->dims_size(); + return true; +} + +bool GatherSliceToSplitFusion::IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const { + // check the version of Slice ops + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Slice", {1, 10, 11, 13}) || + !graph_utils::IsSupportedProvider(node, GetCompatibleExecutionProviders())) { + return false; + } + + // get the opset version + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + // If Slice op of opset version 1 + if (onnx_opset_version == 1) { + if (!graph_utils::GetRepeatedNodeAttributeValues(node, "starts", starts) || + !graph_utils::GetRepeatedNodeAttributeValues(node, "ends", ends) || + starts.size() != ends.size()) { + return false; + } + + if (graph_utils::GetRepeatedNodeAttributeValues(node, "axes", axes) && (axes.size() != starts.size())) { + return false; + } + } + + // If Slice op of opset version >= 10 + if (onnx_opset_version >= 10) { + // node inputs include: starts - ends - axes - steps + + // return a pointer to the corresponding NodeArg if input of the node at the index exists + auto get_input_if_exists = [&node](size_t input_index) -> const NodeArg* { + const auto& input_defs = node.InputDefs(); + const NodeArg* input = (input_defs.size() > input_index) ? input_defs[input_index] : nullptr; + return (input == nullptr || !input->Exists()) ? nullptr : input; + }; + + // return a pointer to the initializer if it is constant; otherwise, a nullptr + auto get_initializer_if_constant = + [&graph, get_input_if_exists](size_t input_index) -> const ONNX_NAMESPACE::TensorProto* { + const NodeArg* input = get_input_if_exists(input_index); + return input ? graph_utils::GetConstantInitializer(graph, input->Name()) : nullptr; + }; + + // return the initialization data if it is constant + auto get_initializer_data = + [&graph](const ONNX_NAMESPACE::TensorProto* slice_initializer) -> InlinedVector { + Initializer init(*slice_initializer, graph.ModelPath()); + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT32) { + int32_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + + if (slice_initializer->data_type() == ONNX_NAMESPACE::TensorProto::INT64) { + int64_t* init_data = init.data(); + return InlinedVector(init_data, init_data + init.size()); + } + return {}; + }; + + // starts and ends inputs have to exist, be constants and be of the same size. + const ONNX_NAMESPACE::TensorProto* starts_init = get_initializer_if_constant(1); + const ONNX_NAMESPACE::TensorProto* ends_init = get_initializer_if_constant(2); + const ONNX_NAMESPACE::TensorProto* axes_init = get_initializer_if_constant(3); + const ONNX_NAMESPACE::TensorProto* steps_init = get_initializer_if_constant(4); + + if (!starts_init || !ends_init || !axes_init || !steps_init) { + return false; + } + + starts = get_initializer_data(starts_init); + ends = get_initializer_data(ends_init); + axes = get_initializer_data(axes_init); + steps = get_initializer_data(steps_init); + + if (starts.size() == 0 || ends.size() == 0 || starts.size() != ends.size()) { + return false; + } + + if (axes_init->dims_size() != 1 || static_cast(axes_init->dims().Get(0)) != starts.size()) { + return false; + } + + // if steps exists, it should be constant and all value should be 1 + if (steps.size() != starts.size()) { + return false; + } + + for (int64_t step : steps) { + if (step != 1) { + return false; + } + } + } + + return true; +} + +/* +GatherToSplitFusion is to fuse: + Node + |-> Gather(index=0, axis=axis) + |-> Gather(index=1, axis=axis) + |-> Slice(index=2, axis=axis) +To + Node + |-> Split(index=0) +So that we can use one kernel to finish the job. +*/ + +Status GatherSliceToSplitFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { + GraphViewer graph_viewer(graph); + + const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder(); + + InlinedVector output_args; + + // Iterate the topological order and get Reshape ops + for (auto node_index : node_topology_list) { + auto* p_node = graph.GetNode(node_index); + + if (p_node == nullptr) continue; + + Node& node = *p_node; + + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + // Currently only catch after Reshape ops, optimize in the future + if (node.OpType() != "Reshape") continue; + + size_t output_count = node.GetOutputEdgesCount(); + + // We only catch 1 scenario for Multi Query Attention for now. + // |---> Gather + // Reshape |---> Gather + // |---> Slice + // |... or (other ops) + + // Get the output into node args + if (output_count < 3) continue; + + output_args.push_back(node.OutputDefs()[0]); + } + + // iterate the children of Reshape node + for (const NodeArg* node_arg : output_args) { + auto shape = node_arg->Shape(); + if (!shape) continue; + + auto consumers = graph.GetConsumerNodes(node_arg->Name()); + size_t consumer_count = consumers.size(); + + // get the tensor rank + int64_t rank = static_cast(shape->dim_size()); + + bool can_fuse = true; + bool first_edge = true; + int64_t split_axis = 0; + int64_t indices_n_dims = -1; + + // Fuse 2 Gathers and 1 slice to Split + // Get those outputs as Split outputs + InlinedVector split_outputs(3); + + InlinedVector> nodes_to_fuse; + size_t gather_node_count = 2, slice_node_count = 0; + + // find the nodes to be merged + for (auto consumer : consumers) { + int64_t index, axis, dims; + InlinedVector starts, ends, axes, steps; + + bool IsSupportedGatherOps = IsSupportedGather(graph, *consumer, index, axis, dims); + bool IsSupportedSliceOps = IsSupportedSlice(graph, *consumer, starts, ends, axes, steps); + + if ((!consumer || consumer->InputDefs()[0] != node_arg) || + (!IsSupportedGatherOps && !IsSupportedSliceOps)) { + break; + } + + if (IsSupportedGatherOps) { + if (indices_n_dims == -1) { + indices_n_dims = dims; + } else if (indices_n_dims != dims) { + // Not the same number of dimensions (0 or 1) for all scalar indices. + can_fuse = false; + break; + } + + if (axis < 0) axis += rank; + + if (first_edge) { + auto dim = shape->dim(static_cast(axis)); + // dim.dim_value() = 73 + if (!utils::HasDimValue(dim)) { + can_fuse = false; + break; + } + split_axis = axis; + first_edge = false; + } else if (axis != split_axis) { + can_fuse = false; + break; + } + + if (index < 0) index += static_cast(consumer_count); + if (index < 0 || index >= static_cast(consumer_count)) { + can_fuse = false; + break; + } + + Node& gather_node = *graph.GetNode(consumer->Index()); + nodes_to_fuse.push_back(gather_node); + NodeArg* gather_output_args = gather_node.MutableOutputDefs()[0]; + split_outputs[gather_node_count--] = gather_output_args; + } + + // check the Slice Ops + if (IsSupportedSliceOps) { + if (axes[0] != axis && !first_edge) { + can_fuse = false; + break; + } + + Node& slice_node = *graph.GetNode(consumer->Index()); + NodeArg* slice_output_args = slice_node.MutableOutputDefs()[0]; + nodes_to_fuse.push_back(slice_node); + split_outputs[slice_node_count++] = slice_output_args; + } + } + + // condition check + if (!can_fuse || gather_node_count != 0 || slice_node_count != 1) continue; + + // generate the split node and merge the kernel + ONNX_NAMESPACE::TypeProto split_output_type; + const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast( + node_arg->TypeAsProto()->tensor_type().elem_type()); + + split_output_type.mutable_tensor_type()->set_elem_type(element_type); + + for (int64_t i = 0; i < rank; i++) { + if (i == split_axis) + split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1LL); + else + *(split_output_type.mutable_tensor_type()->mutable_shape()->add_dim()) = shape->dim(static_cast(i)); + } + + InlinedVector split_output_types; + + for (size_t i = 0; i < consumer_count; ++i) { + split_output_types.push_back( + &graph.GetOrCreateNodeArg( + graph.GenerateNodeArgName("fused_split_" + std::to_string(i)), &split_output_type)); + } + + // Generate the Split Node + ONNX_NAMESPACE::TensorProto split_initializer_proto; + split_initializer_proto.set_name(graph.GenerateNodeName("fused_Split")); + split_initializer_proto.add_dims(static_cast(3)); + split_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + + auto dim_value = shape->dim(static_cast(split_axis)).dim_value(); + // Optimize 2 Gather Nodes, so Slice_dim = dim_value - 2 + int64_t slice_dim = static_cast(dim_value - 2); + InlinedVector split_value{{slice_dim, 1, 1}}; + split_initializer_proto.set_raw_data(split_value.data(), split_value.size() * sizeof(int64_t)); + NodeArg* split_arg = &graph_utils::AddInitializer(graph, split_initializer_proto); + + Node& split_node = + graph.AddNode(graph.GenerateNodeName("Split"), "Split", "Split for fused Gather-Slice fusion", + {graph.GetNodeArg(node_arg->Name()), split_arg}, split_outputs); + + split_node.AddAttribute("axis", split_axis); + + split_node.SetExecutionProviderType(nodes_to_fuse[0].get().GetExecutionProviderType()); + + int onnx_opset_version = -1; + if (graph.DomainToVersionMap().find(kOnnxDomain) != graph.DomainToVersionMap().end()) { + onnx_opset_version = graph.DomainToVersionMap().at(kOnnxDomain); + } + + if (onnx_opset_version >= 18) { + split_node.AddAttribute("num_outputs", static_cast(consumer_count)); + } + + for (Node& node_to_fuse : nodes_to_fuse) { + graph_utils::RemoveNodeOutputEdges(graph, node_to_fuse); + graph.RemoveNode(node_to_fuse.Index()); + } + modified = true; + } + + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/gather_slice_fusion.h b/onnxruntime/core/optimizer/gather_slice_fusion.h new file mode 100644 index 0000000000000..1c5c307efed7f --- /dev/null +++ b/onnxruntime/core/optimizer/gather_slice_fusion.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** +@class GatherSliceToSplitFusion +Fuse (2 Gather nodes + 1 Slice) to 1 split node. +*/ + +class GatherSliceToSplitFusion : public GraphTransformer { + private: + bool IsSupportedGather(const Graph& graph, const Node& node, int64_t& index, int64_t& axis, + int64_t& indices_n_dims) const; + + bool IsSupportedSlice(const Graph& graph, const Node& node, + InlinedVector& starts, + InlinedVector& ends, + InlinedVector& axes, + InlinedVector& steps) const; + + public: + GatherSliceToSplitFusion(const InlinedHashSet& compatible_execution_providers = {}) noexcept + : GraphTransformer("GatherSliceToSplitFusion", compatible_execution_providers) {} + + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index cd3c49be15aa4..4e939fe3c7b6b 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -37,6 +37,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -308,6 +309,7 @@ InlinedVector> GenerateTransformers( transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); + transformers.emplace_back(std::make_unique(cpu_cuda_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); transformers.emplace_back(std::make_unique(cpu_cuda_dml_rocm_eps)); diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 159e3b23d1ab0..b6ad4fde6c1f7 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -13,7 +13,7 @@ using namespace onnxruntime::common; namespace onnxruntime { // LayerNorm supports limited data types. -static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)"}; +static constexpr std::array supported_data_types{"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}; // Default epsilon static constexpr float DEFAULT_LAYERNORM_EPSILON = 1e-5f; diff --git a/onnxruntime/core/optimizer/noop_elimination.cc b/onnxruntime/core/optimizer/noop_elimination.cc index b3c2991d54b28..bba39b698a27a 100644 --- a/onnxruntime/core/optimizer/noop_elimination.cc +++ b/onnxruntime/core/optimizer/noop_elimination.cc @@ -42,49 +42,62 @@ bool NoopElimination::SatisfyCondition(const Graph& graph, const Node& node, con // if initializer_rank is bigger, the output is expected to be initializer_rank per broadcasting rule, // but it won't happen if the case is accepted, thus reject it - auto initializer_rank = initializer->dims().size(); + const auto& dims = initializer->dims(); + auto initializer_rank = dims.size(); const auto* other_input_shape = node.InputDefs()[input0_is_initializer ? 1 : 0]->Shape(); if (other_input_shape == nullptr || initializer_rank > other_input_shape->dim_size()) { return false; } - int32_t data_type = initializer->data_type(); - Initializer add_init(*initializer, graph.ModelPath()); - if (add_init.size() > 1) { + int64_t tensor_size = 1; + for (auto i : dims) { + tensor_size *= i; + } + + if (tensor_size > 1) { return false; } + // handle edge case where the total size of the initializer is 0 - if (add_init.size() == 0) { + if (tensor_size == 0) { return true; } - float value = 0.0f; - switch (data_type) { - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - value = *add_init.data(); - break; - case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: - value = math::halfToFloat(add_init.data()->val); - break; - case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT32: - value = static_cast(*add_init.data()); - break; - case ONNX_NAMESPACE::TensorProto_DataType_INT64: - value = static_cast(*add_init.data()); - break; - default: + if (op_type == "Add" || + op_type == "Sub" || + op_type == "Mul" || + op_type == "Div") { + int32_t data_type = initializer->data_type(); + Initializer add_init(*initializer, graph.ModelPath()); + + float value = 0.0f; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = *add_init.data(); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = math::halfToFloat(add_init.data()->val); + break; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(*add_init.data()); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(*add_init.data()); + break; + default: + return false; + } + + if (value != 0.0f && (op_type == "Add" || op_type == "Sub")) { return false; - } + } - if ((op_type == "Add" || op_type == "Sub") && value != 0.0f) { - return false; - } - - if ((op_type == "Mul" || op_type == "Div") && value != 1.0f) { - return false; + if (value != 1.0f && (op_type == "Mul" || op_type == "Div")) { + return false; + } } // reject node output is graph output for now diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc index b1ab641a23256..4e3dff705bd41 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.cc @@ -76,6 +76,49 @@ bool IsQDQPairSupported( } } +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path) { + ConstPointerContainer> dq_input_defs = dq_node.InputDefs(); + ConstPointerContainer> q_input_defs = q_node.InputDefs(); + + // Q/DQ contains optional input is not supported + // non-scalar Q/DQ scale and zero point needs are not supported + if (dq_input_defs.size() != InputIndex::TOTAL_COUNT || + q_input_defs.size() != InputIndex::TOTAL_COUNT || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*q_input_defs[InputIndex::ZERO_POINT_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::SCALE_ID]) || + !optimizer_utils::IsScalar(*dq_input_defs[InputIndex::ZERO_POINT_ID])) { + return false; + } + + // if Q/DQ scale and zero point are not constant, return false + const ONNX_NAMESPACE::TensorProto* dq_scale_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_scale_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::SCALE_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* dq_zp_tensor_proto = + get_const_initializer(dq_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + const ONNX_NAMESPACE::TensorProto* q_zp_tensor_proto = + get_const_initializer(q_input_defs[InputIndex::ZERO_POINT_ID]->Name()); + if (nullptr == q_zp_tensor_proto || + nullptr == dq_zp_tensor_proto || + nullptr == q_scale_tensor_proto || + nullptr == dq_scale_tensor_proto) { + return false; + } + + // check Q/DQ have same scale type and different zero point type + Initializer q_zp(*q_zp_tensor_proto, model_path); + Initializer q_scale(*q_scale_tensor_proto, model_path); + Initializer dq_zp(*dq_zp_tensor_proto, model_path); + Initializer dq_scale(*dq_scale_tensor_proto, model_path); + + return (dq_zp.data_type() != q_zp.data_type()) && (dq_scale.data_type() == q_scale.data_type()); +} + bool IsDQSupported(const Node& dq_node, const GetConstantInitializerFn& get_const_initializer) { bool zero_point_exists = false; if (!QOrDQNodeHasConstantScalarScaleAndZeroPoint(dq_node, get_const_initializer, zero_point_exists)) { diff --git a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h index bb0bf9438cfcb..8333168b0093f 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h +++ b/onnxruntime/core/optimizer/qdq_transformer/qdq_util.h @@ -38,6 +38,18 @@ bool IsQDQPairSupported( const GetConstantInitializerFn& get_const_initializer, const Path& model_path); +// Check if a DQ -> Q sequence represents a conversion in quantization data type. +// Example of uint8 to uint16: +// Dequantize (uint8 to float) -> Quantize (float to uint16) +// Requires: +// 1. Q/DQ doesn't have optional input. +// 2. scale and zero-point are constant scalars. +// 3. Q and DQ have the same scale *type* and different zero-point *types*. +bool IsDQQConversion( + const Node& dq_node, const Node& q_node, + const GetConstantInitializerFn& get_const_initializer, + const Path& model_path); + // Check if DQ is supported in extended level QDQ transformers. It requires: // 1. DQ doesn't have optional input. // 2. scale and zero point is constant scalar diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index d9f08ffe1171e..c532f56b3d3d9 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -115,7 +115,7 @@ class ApiGraph final : public api::GraphRef { const auto& graph_outputs = graph_.GetOutputs(); graph_outputs_.reserve(graph_outputs.size()); for (const auto* output : graph_outputs) { - graph_outputs_.insert(output->Name()); + graph_outputs_.emplace(output->Name()); } } diff --git a/onnxruntime/core/platform/windows/env.cc b/onnxruntime/core/platform/windows/env.cc index 1a0713db43db8..0eb34cbfbc9eb 100644 --- a/onnxruntime/core/platform/windows/env.cc +++ b/onnxruntime/core/platform/windows/env.cc @@ -32,6 +32,9 @@ limitations under the License. #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" +#if defined(_M_X64) && !defined(_M_ARM64EC) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +#include "core/platform/windows/hardware_core_enumerator.h" +#endif #include #include @@ -248,12 +251,53 @@ void WindowsEnv::SleepForMicroseconds(int64_t micros) const { Sleep(static_cast(micros) / 1000); } +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) +static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" +#endif int WindowsEnv::DefaultNumCores() { return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); } int WindowsEnv::GetNumPhysicalCpuCores() const { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) && defined(ONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH) + // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has + // a hybrid architecture that some CPU cores runs significant slower than the others. If we distribute our compute work + // evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number + // of threads to exclude the slowest cores out. + // The following code is based on assumptions that: + // 1. All Intel hybrid CPUs should have 3 levels of cache. + // 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should + // not be used. + // Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code. + // However, no matter what the code should not cause any crash. The worst is it might return 1 that + // thread pools will not be created, which is just a perf issue and does not impact usability. + // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability + int regs[4]; + __cpuid(regs, 0); + bool bIsIntel = + (kVendorID_Intel[0] == regs[1]) && + (kVendorID_Intel[1] == regs[2]) && + (kVendorID_Intel[2] == regs[3]); + if (bIsIntel && regs[0] >= 7) { + // Query Structured Extended Feature Flags Enumeration Leaf + __cpuid(regs, 0x7); + // The bit 15 of EDX indicates if the processor is identified as a hybrid part. + bool ishybrid = regs[3] & (1 << 15); + if (ishybrid) { + // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores. + // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail. + // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines. + return std::max(static_cast(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads()); + } else { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } + } else +#endif + { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } } std::vector WindowsEnv::GetDefaultThreadAffinities() const { diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.cc b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc new file mode 100644 index 0000000000000..121c59808ae59 --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.cc @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "hardware_core_enumerator.h" +#include +#include +#include + +namespace onnxruntime { + +struct LogicalProcessorInformation { + std::unique_ptr Buffer; + size_t Length; +}; + +struct CoreCounter { + uint32_t PhysicalCores = 0; + uint32_t SocDieCores = 0; +}; + +static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { + DWORD length = 0; + DWORD rc = GetLogicalProcessorInformationEx(relationship, nullptr, &length); + + assert(rc == FALSE); + + auto processorInformationBytes = std::make_unique(length); + + rc = GetLogicalProcessorInformationEx( + relationship, reinterpret_cast(processorInformationBytes.get()), &length); + + assert(rc == TRUE); + + return {std::move(processorInformationBytes), length}; +} + +uint32_t CountSetBits(DWORD input) { + uint32_t c; + for (c = 0; input; c++) { + input &= input - 1; + } + return c; +} + +static CoreCounter GetNumberOPhysicalAndEngineeringCores() { + auto logicalProcessorInformation = GetLogicalProcessorInfos(RelationAll); + + CoreCounter cores; + DWORD dwLevel2GroupMask = 0; + DWORD dwLevel3GroupMask = 0; + size_t read = 0; + PSYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX currentProcessorInfo = NULL; + + while ((read + FIELD_OFFSET(SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX, Processor)) < logicalProcessorInformation.Length) { + currentProcessorInfo = + reinterpret_cast(logicalProcessorInformation.Buffer.get() + read); + if ((read + currentProcessorInfo->Size) > logicalProcessorInformation.Length) { + break; + } + + switch (currentProcessorInfo->Relationship) { + case RelationProcessorCore: + cores.PhysicalCores++; + break; + case RelationCache: + if (currentProcessorInfo->Cache.Level == 2) { + dwLevel2GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } else if (currentProcessorInfo->Cache.Level == 3) { + dwLevel3GroupMask |= currentProcessorInfo->Cache.GroupMask.Mask; + } + break; + } + + read += currentProcessorInfo->Size; + } + + cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + return cores; +} + +uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { + // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. + // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. + auto cores = GetNumberOPhysicalAndEngineeringCores(); + // We want to use the number of physical cores, but exclude soc cores + return cores.PhysicalCores - cores.SocDieCores; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/platform/windows/hardware_core_enumerator.h b/onnxruntime/core/platform/windows/hardware_core_enumerator.h new file mode 100644 index 0000000000000..93b50f452afcd --- /dev/null +++ b/onnxruntime/core/platform/windows/hardware_core_enumerator.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +namespace onnxruntime { +struct HardwareCoreEnumerator { + HardwareCoreEnumerator() = delete; + static uint32_t DefaultIntraOpNumThreads(); +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/coreml/builders/coreml_spec.h b/onnxruntime/core/providers/coreml/builders/coreml_spec.h index e9cd4af94e5fd..c9adba9e579d0 100644 --- a/onnxruntime/core/providers/coreml/builders/coreml_spec.h +++ b/onnxruntime/core/providers/coreml/builders/coreml_spec.h @@ -3,12 +3,28 @@ #pragma once -// TODO come up with a more intuitive way of limiting this to Apple platform builds -// E.g., putting CoreML EP files that should be enabled iff `defined(__APPLE__)` in a separate directory. -#if !defined(__APPLE__) -#error "This file should only be included when building on Apple platforms." +#include "onnxruntime_config.h" + +#if defined(__GNUC__) +#pragma GCC diagnostic push + +// Disable warning from protobuf code. +// +// In file included from coreml_proto/Model.pb.h:30: +// In file included from _deps/protobuf-src/src/google/protobuf/extension_set.h:53: +// _deps/protobuf-src/src/google/protobuf/parse_context.h:328:47: +// error: implicit conversion loses integer precision: 'long' to 'int' [-Werror,-Wshorten-64-to-32] +#ifdef HAS_SHORTEN_64_TO_32 +#pragma GCC diagnostic ignored "-Wshorten-64-to-32" +#endif #endif +// Model.pb.h is generated in the build output directory from the CoreML protobuf files in +// onnxruntime/core/providers/coreml/coremltools/mlmodel/format #include "coreml_proto/Model.pb.h" +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + namespace COREML_SPEC = CoreML::Specification; diff --git a/onnxruntime/core/providers/coreml/builders/helper.cc b/onnxruntime/core/providers/coreml/builders/helper.cc index 897856256cc79..bc3ba4432e66d 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.cc +++ b/onnxruntime/core/providers/coreml/builders/helper.cc @@ -22,22 +22,35 @@ namespace onnxruntime { namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags) { +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags) { return OpBuilderInputParams{graph_viewer, - (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0}; + coreml_version, + (coreml_flags & COREML_FLAG_ONLY_ALLOW_STATIC_INPUT_SHAPES) != 0, + (coreml_flags & COREML_FLAG_CREATE_MLPROGRAM) != 0}; } -bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { +const IOpBuilder* GetOpBuilder(const Node& node) { const auto& op_builders = GetOpBuilders(); - if (Contains(op_builders, node.OpType())) { - const auto* op_builder = op_builders.at(node.OpType()); + const auto it = op_builders.find(node.OpType()); + if (it != op_builders.cend()) { + return it->second; + } + + return nullptr; +} + +bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { + const auto* op_builder = GetOpBuilder(node); + if (op_builder) { return op_builder->IsOpSupported(node, input_params, logger); } else { return false; } } -bool IsInputSupported(const NodeArg& input, const std::string& parent_name, +bool IsInputSupported(const Node& node, const NodeArg& input, const OpBuilderInputParams& input_params, const logging::Logger& logger) { if (!input.Exists()) { // optional input that is not provided @@ -48,8 +61,8 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, std::vector shape; // We do not support input with no shape if (!GetShape(input, shape, logger)) { - LOGS(logger, VERBOSE) << "Input [" << input_name << "] of [" << parent_name - << "] has no shape"; + LOGS(logger, VERBOSE) << MakeString("Input [", input_name, "] of Node [", node.Name(), "] type [", node.OpType(), + "] has no shape"); return false; } @@ -63,11 +76,19 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, // For some undocumented reason, Apple CoreML framework will fail loading the model if the model // input has dimension > 16384 // See this issue, https://github.com/apple/coremltools/issues/1003 + // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf has maximum texture widths which may be the + // root cause. if (dim > 16384) { LOGS(logger, WARNING) << "CoreML does not support input dim > 16384. Input:" << input_name << ", shape: " << Shape2String(shape); return false; } + + if (dim == 0) { + LOGS(logger, WARNING) << "CoreML does not support shapes with dimension values of 0. Input:" << input_name + << ", shape: " << Shape2String(shape); + return false; + } } // Limit input shape rank to 5. @@ -87,13 +108,6 @@ std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewe const logging::Logger& logger) { std::unordered_set supported_nodes{}; -#ifdef __APPLE__ - if (!util::HasRequiredBaseOS()) { - LOGS(logger, WARNING) << "All ops will fallback to CPU EP, because we do not have supported OS"; - return supported_nodes; - } -#endif - for (const auto& node : graph_viewer.Nodes()) { const bool supported = IsNodeSupported(node, input_params, logger); LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType() @@ -149,7 +163,9 @@ bool HasNeuralEngine(const logging::Logger& logger) { #else // In this case, we are running the EP on non-apple platform, which means we are running the model // conversion with CoreML EP enabled, for this we always assume the target system has Neural Engine - LOGS(logger, VERBOSE) << "HasNeuralEngine running on non-Apple hardware for model conversion only"; + LOGS(logger, INFO) << "HasNeuralEngine running on non-Apple hardware. " + "Returning true to enable model conversion and local testing of CoreML EP implementation. " + "No CoreML model will be compiled or run."; has_neural_engine = true; #endif // #ifdef __APPLE__ diff --git a/onnxruntime/core/providers/coreml/builders/helper.h b/onnxruntime/core/providers/coreml/builders/helper.h index d8b27ac76ae73..300de2dedd122 100644 --- a/onnxruntime/core/providers/coreml/builders/helper.h +++ b/onnxruntime/core/providers/coreml/builders/helper.h @@ -23,10 +23,14 @@ class Logger; namespace coreml { -OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, uint32_t coreml_flags); +OpBuilderInputParams MakeOpBuilderParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + uint32_t coreml_flags); -bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, - const OpBuilderInputParams& input_params, const logging::Logger& logger); +const IOpBuilder* GetOpBuilder(const Node& node); + +bool IsInputSupported(const Node& node, const NodeArg& node_arg, const OpBuilderInputParams& input_params, + const logging::Logger& logger); bool IsNodeSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger); diff --git a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc index 53f18b205880c..e9e520156576e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/LRN_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class LRNOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_lrn = layer->mutable_lrn(); @@ -56,9 +43,6 @@ Status LRNOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool LRNOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc index 88d6616b4e097..dee87ce3632a8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/activation_op_builder.cc @@ -2,44 +2,32 @@ // Licensed under the MIT License. #include "core/common/narrow.h" +#include "core/framework/tensorprotoutils.h" #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ActivationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; + int GetMinSupportedOpSet(const Node& node) const override; }; -// Add operator related - -#ifdef __APPLE__ void ActivationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -86,7 +74,7 @@ Status AddPReluWeight(ModelBuilder& model_builder, const Node& node, Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type(node.OpType()); if (op_type == "Sigmoid") { @@ -115,14 +103,10 @@ Status ActivationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related namespace { // assumes that node.OpType() == "PRelu" -bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) { +bool IsPReluOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); // X input rank must be 3 or 4 diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index 7a5d4a5af673b..e9a8176c8349b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -1,37 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ArgMaxOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& graph_viewer = model_builder.GetGraphViewer(); NodeAttrHelper helper(node); @@ -67,9 +56,6 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ArgMaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc index 25d5bad14ceb6..2570e6d88ae0d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.cc @@ -1,21 +1,18 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { -// Shared functions - +namespace { // TODO, move this to shared_library bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node& node, const logging::Logger& logger) { @@ -37,93 +34,78 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node return false; } +} // namespace -// Add operator related -#ifdef __APPLE__ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - ORT_RETURN_IF_NOT( - IsOpSupported(node, input_params, logger), - "Unsupported operator ", - node.OpType()); - - ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); - LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() - << "] type: [" << node.OpType() << "] was added"; - return Status::OK(); -} + Status status = AddToModelBuilderImpl(model_builder, node, logger); -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(ModelBuilder& model_builder, const Node& node) { - auto layer_name = node.Name(); - if (layer_name.empty()) { - // CoreML requires layer has a name, while the node name is optional in ONNX - // In this case, create a unique name for the layer - layer_name = model_builder.GetUniqueName(MakeString("Node_", node.Index(), "_type_", node.OpType())); + if (status.IsOK()) { + LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; } - return CreateNNLayer(layer_name); -} -/* static */ std::unique_ptr -BaseOpBuilder::CreateNNLayer(const std::string& layer_name) { - std::unique_ptr layer = std::make_unique(); - layer->set_name(layer_name); - return layer; + return status; } -#endif - -// Operator support related bool BaseOpBuilder::IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, input_params, logger)) + if (input_params.create_mlprogram && !SupportsMLProgram()) { + LOGS(logger, VERBOSE) << "Operator [" << node.OpType() << "] does not support MLProgram"; return false; + } - // We do not support external initializers for now - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (HasExternalInitializer(initializers, node, logger)) + if (!HasSupportedOpSet(node, logger)) { + return false; + } + + if (!HasSupportedInputs(node, input_params, logger)) { return false; + } - if (!HasSupportedOpSet(node, logger)) + // We do not support external initializers for now + const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); + if (HasExternalInitializer(initializers, node, logger)) { return false; + } return IsOpSupportedImpl(node, input_params, logger); } bool BaseOpBuilder::HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { - const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { - if (!IsInputSupported(*input, node_name, input_params, logger)) { + if (!IsInputSupported(node, *input, input_params, logger)) { return false; } } - return HasSupportedInputsImpl(node, logger); + return HasSupportedInputsImpl(node, input_params, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - // We only check the type of input 0 by default - // specific op builder can override this +/* static */ +bool BaseOpBuilder::IsInput0Supported(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) { const auto& input = *node.InputDefs()[0]; - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; + int32_t input_type = ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED; - if (input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + // currently only float is supported + if (!GetType(input, input_type, logger) || input_type != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { + LOGS(logger, VERBOSE) << "[" << node.OpType() << "] Input type: [" << input_type << "] is not currently supported"; return false; } return true; } -bool BaseOpBuilder::HasSupportedOpSet(const Node& node, - const logging::Logger& logger) const { +bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + // We only check the type of input 0 by default + // specific op builder can override this + return IsInput0Supported(node, input_params, logger); +} + +bool BaseOpBuilder::HasSupportedOpSet(const Node& node, const logging::Logger& logger) const { auto since_version = node.SinceVersion(); if (since_version < GetMinSupportedOpSet(node) || since_version > GetMaxSupportedOpSet(node)) { LOGS(logger, VERBOSE) << node.OpType() << "is only supported for opset [" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b4132d3b770ec..06c4dd94ea30d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -3,11 +3,9 @@ #pragma once -#include "core/providers/coreml/builders/op_builder.h" - -#ifdef __APPLE__ +#include "core/common/span_utils.h" #include "core/providers/coreml/builders/coreml_spec.h" -#endif +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { @@ -18,45 +16,40 @@ class BaseOpBuilder : public IOpBuilder { public: virtual ~BaseOpBuilder() = default; - // Add operator related + // does the operator implementation support creating an ML Program + bool SupportsMLProgram() const override { return false; } + + bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override final; -#ifdef __APPLE__ - public: - virtual void AddInitializersToSkip(ModelBuilder& /* model_builder */, const Node& /* node */) const override {} Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const override final; - protected: - virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const = 0; - - static std::unique_ptr - CreateNNLayer(ModelBuilder& model_builder, const Node& node); - - static std::unique_ptr CreateNNLayer(const std::string& layer_name); -#endif - - // Operator support related - public: - bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, - const logging::Logger& logger) const override final; + void AddInitializersToSkip(ModelBuilder& /*model_builder*/, const Node& /*node*/) const override {} protected: - virtual bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, - const logging::Logger& /* logger */) const { + // check if the first input's data type is supported. + static bool IsInput0Supported(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger); + + private: + virtual bool IsOpSupportedImpl(const Node& /*node*/, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& /*logger*/) const { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; + virtual bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const; - virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } + virtual int GetMinSupportedOpSet(const Node& /*node*/) const { return 1; } + virtual int GetMaxSupportedOpSet(const Node& /*node*/) const { return 20; } - private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; bool HasSupportedInputs(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const; + + virtual Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, + const logging::Logger& logger) const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc index 391b02eaec497..8da58f659acf1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/batch_norm_op_builder.cc @@ -5,30 +5,20 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class BatchNormalizationOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -36,9 +26,6 @@ class BatchNormalizationOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 7; } }; -// Add operator related - -#ifdef __APPLE__ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // skip everything except input0 for BatchNormalization const auto& input_defs = node.InputDefs(); @@ -48,10 +35,9 @@ void BatchNormalizationOpBuilder::AddInitializersToSkip(ModelBuilder& model_buil model_builder.AddInitializerToSkip(input_defs[4]->Name()); // var } -Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, - const Node& node, +Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -81,9 +67,6 @@ Status BatchNormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_bu model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool BatchNormalizationOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc index 10c9b32d03f37..6074fba1433d9 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/binary_op_builder.cc @@ -1,35 +1,28 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/framework/tensorprotoutils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - -#include "base_op_builder.h" namespace onnxruntime { namespace coreml { - class BinaryOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related + int GetMinSupportedOpSet(const Node& node) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; }; -#ifdef __APPLE__ -static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { +namespace { +bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& logger) { const auto& input_defs = node.InputDefs(); const auto* x_shape_proto = input_defs[0]->Shape(); @@ -57,15 +50,14 @@ static bool CheckIfBothInputShapesMatch(const Node& node, const logging::Logger& y_shape_proto->dim().begin(), y_shape_proto->dim().end(), dim_eq); } - -// Add operator related +} // namespace Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Add") { // original mutable_add() has limited broadcasting support @@ -99,31 +91,28 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related int BinaryOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { // Add/Sub/Mul/Div opset 6- has broadcast attributes we do not support now return 7; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { - bool is_pow = node.OpType() == "Pow"; - if (!is_pow) { - return BaseOpBuilder::HasSupportedInputsImpl(node, logger); +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const { + if (node.OpType() != "Pow") { + return IsInput0Supported(node, input_params, logger); } const auto& input_1 = *node.InputDefs()[0]; const auto& input_2 = *node.InputDefs()[1]; + // Pow we only support both inputs as fp32 for now int32_t input_type_1; - if (!GetType(input_1, input_type_1, logger)) - return false; - int32_t input_type_2; - if (!GetType(input_2, input_type_2, logger)) + if (!GetType(input_1, input_type_1, logger) || + !GetType(input_2, input_type_2, logger)) { return false; + } if (input_type_1 != ONNX_NAMESPACE::TensorProto_DataType_FLOAT || input_type_1 != input_type_2) { LOGS(logger, VERBOSE) << "Pow only supports fp32 inputs, actual input type" diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc index ef66e6b877a1f..710f596b2a562 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.cc @@ -1,17 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __APPLE__ - #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/common/narrow.h" #include "core/framework/tensorprotoutils.h" +#include "core/providers/coreml/builders/coreml_spec.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/shared/utils/utils.h" #include "core/optimizer/initializer.h" -#include "coreml_proto/NeuralNetwork.pb.h" +using namespace COREML_SPEC; namespace onnxruntime { namespace coreml { @@ -133,7 +132,182 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span> shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->size()); + for (const auto& dim : *shape) { + if (dim >= 0) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim)); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +void SetTensorTypeInfo(MILSpec::TensorType& tensor_type, MILSpec::DataType data_type, + const ONNX_NAMESPACE::TensorShapeProto* shape) { + tensor_type.set_datatype(data_type); + if (shape) { + tensor_type.set_rank(shape->dim_size()); + for (const auto& dim : shape->dim()) { + if (dim.has_dim_value()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(narrow(dim.dim_value())); + } else { + tensor_type.add_dimensions()->mutable_unknown()->set_variadic(false); + } + } + } +} + +template +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + // need a 'false' that is dependent on the template types to make gcc happy and give a meaningful error message. + static_assert(false_for_T && false_for_T, "Unsupported data type"); // add specializations below as needed +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_floats()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_ints()->mutable_values()->Add(data.begin(), data.end()); +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_strings()->mutable_values()->Add(data.begin(), data.end()); +} + +// copy int64_t (used by ONNX for strides/indexes/etc.) to int32_t (used by CoreML) +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + auto& int32_out = *tensor_value.mutable_ints()->mutable_values(); + int32_out.Reserve(narrow(data.size())); + for (const int64_t v : data) { + int32_out.AddAlreadyReserved(narrow(v)); + } +} + +template <> +void CopyDataToTensorValue(MILSpec::TensorValue& tensor_value, gsl::span data) { + tensor_value.mutable_bools()->mutable_values()->Add(data.begin(), data.end()); +} + +} // namespace + +MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type) { + switch (static_cast(onnx_type)) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return MILSpec::DataType::FLOAT32; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + return MILSpec::DataType::FLOAT64; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return MILSpec::DataType::BFLOAT16; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return MILSpec::DataType::FLOAT16; + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + return MILSpec::DataType::INT8; + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + return MILSpec::DataType::INT16; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + return MILSpec::DataType::INT32; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + return MILSpec::DataType::INT64; + + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + return MILSpec::DataType::UINT8; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + return MILSpec::DataType::UINT16; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + return MILSpec::DataType::UINT32; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + return MILSpec::DataType::UINT64; + + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + return MILSpec::DataType::BOOL; + case ONNX_NAMESPACE::TensorProto_DataType_STRING: + return MILSpec::DataType::STRING; + default: + ORT_THROW("Unsupported data type: ", onnx_type); + } +} + +template +MILSpec::Value CreateTensorValue(const gsl::span data, + std::optional> shape) { + MILSpec::Value value; + MILSpec::TensorType& tensor_type = *value.mutable_type()->mutable_tensortype(); + + if (shape) { + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), *shape); + } else { + // infer as 1D shape + std::vector coreml_shape{narrow(data.size())}; + SetTensorTypeInfo(tensor_type, DataTypeToMILSpec(), coreml_shape); + } + + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyDataToTensorValue(tensor_value, data); + + return value; +} + +template +MILSpec::Value CreateScalarTensorValue(const T& data) { + gsl::span data_span{&data, 1}; + std::vector shape = {}; // empty for scalar + return CreateTensorValue(data_span, shape); +} + +// explicit specializations for types we handle so the implementation can be in the .cc file +template MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape); + +template MILSpec::Value CreateScalarTensorValue(const float& data); +template MILSpec::Value CreateScalarTensorValue(const int32_t& data); +template MILSpec::Value CreateScalarTensorValue(const std::string& data); +template MILSpec::Value CreateScalarTensorValue(const bool& data); + +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg) { + MILSpec::NamedValueType nvt; + nvt.set_name(node_arg.Name()); + MILSpec::TensorType& tensor_type = *nvt.mutable_type()->mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(node_arg.TypeAsProto()->tensor_type().elem_type()), + node_arg.Shape()); + + return nvt; +} + +void AddOperationInput(MILSpec::Operation& op, std::string_view input_name, std::string_view value_name) { + MILSpec::Argument arg; + arg.mutable_arguments()->Add()->set_name(std::string(value_name)); + + (*op.mutable_inputs())[input_name] = std::move(arg); +} + +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output) { + auto& outputs = *op.mutable_outputs(); + auto& output_arg = *outputs.Add(); + output_arg.set_name(output.Name()); + + MILSpec::ValueType& value = *output_arg.mutable_type(); + MILSpec::TensorType& tensor_type = *value.mutable_tensortype(); + + SetTensorTypeInfo(tensor_type, OnnxDataTypeToMILSpec(output.TypeAsProto()->tensor_type().elem_type()), + output.Shape()); +} + } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h index 23b11928f7dc2..8126f0c126914 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/coreml/builders/impl/builder_utils.h @@ -5,22 +5,19 @@ #pragma once -#ifdef __APPLE__ +#include #include "core/common/gsl.h" #include "core/common/status.h" #include "core/graph/basic_types.h" #include "core/providers/common.h" -namespace CoreML { -namespace Specification { -class WeightParams; -} -} // namespace CoreML +#include "core/providers/coreml/builders/coreml_spec.h" namespace onnxruntime { -namespace coreml { +class NodeArg; +namespace coreml { // Try to see if we can map explicit padding to auto padding for Conv/Pool // Since usually use auto padding is more efficient Status HandleAutoPad(const std::vector input_shape, @@ -32,6 +29,10 @@ Status HandleAutoPad(const std::vector input_shape, AutoPadType auto_pad_type, AutoPadType& auto_pad_type_out); +// +// NeuralNetwork utils +// + // Copy an onnx initializer data to a coreml weight Status CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, const ONNX_NAMESPACE::TensorProto& tensor); @@ -44,7 +45,90 @@ void CreateCoreMLWeight(CoreML::Specification::WeightParams& weight, gsl::span data); +// +// MLProgram utils +// + +// helper for static_assert where the value needs to be dependent on a template parameter +template +constexpr bool false_for_T = false; + +template +COREML_SPEC::MILSpec::DataType DataTypeToMILSpec() { + if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT64; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BFLOAT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::FLOAT16; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::INT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT8; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT16; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT32; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::UINT64; + + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::BOOL; + } else if constexpr (std::is_same_v) { + return COREML_SPEC::MILSpec::DataType::STRING; + } else { + static_assert(false_for_T, "Unsupported type."); + } +} + +// The TensorProto.data_type field is an int, but must be a valid TensorProto_DataType value. +// Use int for the arg so the caller can pass TensorProto.data_type() value and do the cast to enum internally +COREML_SPEC::MILSpec::DataType OnnxDataTypeToMILSpec(int onnx_type); + +/// +/// Create a CoreML MILSpec::TensorValue for the given input data. +/// +/// Original C++ data type +/// CoreML C++ data type +/// ONNX data +/// ONNX data shape. Inferred to be a 1D shape of `{data.size()}` if not specified. +/// TensorValue containing data. +template +COREML_SPEC::MILSpec::Value CreateTensorValue(gsl::span data, + std::optional> shape = std::nullopt); + +template +COREML_SPEC::MILSpec::Value CreateScalarTensorValue(const T& data); + +/// Create a NamedValueType from an ONNX tensor NodeArg. +/// Used to create inputs for the 'main' function in an ML Program. +COREML_SPEC::MILSpec::NamedValueType CreateNamedTensorValueType(const NodeArg& node_arg); + +/// +/// Add an input argument to a MILSpec::Operation +/// +/// Operation to update. +/// The input name defined by the spec for the operation. +/// The name of the value that is providing the input. +/// "https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html" +void AddOperationInput(COREML_SPEC::MILSpec::Operation& op, + std::string_view input_name, std::string_view value_name); + +/// +/// Add an output to a MILSpec::Operation. Name, data type and shape are used from the NodeArg. +/// +/// Operation to update. +/// NodeArg with details of output to add. +void AddOperationOutput(COREML_SPEC::MILSpec::Operation& op, const NodeArg& output); } // namespace coreml } // namespace onnxruntime - -#endif diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 15ee1f0fc7284..70053c2c606a0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -1,34 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/shared/utils/utils.h" #include "core/providers/coreml/builders/helper.h" -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class CastOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; -}; -// Add operator related + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; +}; -#ifdef __APPLE__ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, const Node& /* node */, const logging::Logger& /* logger */) const { @@ -37,9 +28,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& /* model_builder */, // Cast node is not provided in CoreML model, so we're skipping adding the Cast node here. return Status::OK(); } -#endif - -// Operator support related bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -84,7 +72,8 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return true; } -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { // We only check the type of input 0 const auto& input = *node.InputDefs()[0]; diff --git a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc index a298a8d12c741..9aca172abec98 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/clip_op_builder.cc @@ -1,37 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class ClipOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ void ClipOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // Both min and max values will be injected into the layer, no need to add to the model if (node.SinceVersion() >= 11) { @@ -58,7 +45,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (!has_min && !has_max) { // Clip without min/max is an identity node // In CoreML we don't have identity, use ActivationLinear instead - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_activation()->mutable_linear()->set_alpha(1.0f); *layer->mutable_input()->Add() = input_name; *layer->mutable_output()->Add() = output_name; @@ -83,8 +70,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Handle clipping at min first if (has_min) { - const auto clip_min_layer_name = model_builder.GetUniqueName(MakeString(node_name, "_Clip_min")); - std::unique_ptr min_layer = CreateNNLayer(clip_min_layer_name); + std::unique_ptr min_layer = model_builder.CreateNNLayer(node, "_Clip_min"); if (min == 0.0f) { // If min is 0. then this min will be handled by relu min_layer->mutable_activation()->mutable_relu(); } else { // otherwise, min will be handled by unary->threshold @@ -101,9 +87,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (has_max) { const auto threshold_output_name = model_builder.GetUniqueName(MakeString(node_name, "threshold_output")); { // Add threshold layer, which is actually max( -1 * min_output, -max) - const auto clip_max_threshold_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_threshold")); - auto threshold_layer = CreateNNLayer(clip_max_threshold_layer_name); + auto threshold_layer = model_builder.CreateNNLayer(node, "_Clip_max_threshold"); threshold_layer->mutable_unary()->set_alpha(-max); threshold_layer->mutable_unary()->set_scale(-1.0f); threshold_layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::THRESHOLD); @@ -112,9 +96,7 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(threshold_layer)); } { // Add linear activation layer -1 * threshold_output - const auto clip_max_linear_layer_name = - model_builder.GetUniqueName(MakeString(node_name, "_Clip_max_linear")); - auto linear_layer = CreateNNLayer(clip_max_linear_layer_name); + auto linear_layer = model_builder.CreateNNLayer(node, "_Clip_max_linear"); linear_layer->mutable_activation()->mutable_linear()->set_alpha(-1.0f); *linear_layer->mutable_input()->Add() = threshold_output_name; *linear_layer->mutable_output()->Add() = output_name; @@ -125,9 +107,6 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool ClipOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc index b1e761024f5c9..34193318a0264 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/concat_op_builder.cc @@ -4,37 +4,26 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ConcatOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); layer->mutable_concat()->set_sequenceconcat(false); @@ -48,9 +37,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool ConcatOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc index ff9dcbd9f8874..05e43dbbd16af 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/conv_op_builder.cc @@ -4,39 +4,35 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" -#include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ #include "core/providers/coreml/builders/impl/builder_utils.h" #include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" -#endif +#include "core/providers/shared/utils/utils.h" + +using namespace CoreML::Specification; namespace onnxruntime { namespace coreml { class ConvOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; -}; -// Add operator related + bool SupportsMLProgram() const override { return true; } +}; -#ifdef __APPLE__ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (model_builder.CreateMLProgram()) { + // we add the initializers as 'const' operations via ModelBuilder::RegisterInitializers + return; + } + const auto& input_defs = node.InputDefs(); // skip the weight and bias (if has it) for conv as we will directly set those as part of the NN layer @@ -49,136 +45,251 @@ void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Nod Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); - const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); const auto& input_name = input_defs[0]->Name(); const auto& output_name = output_defs[0]->Name(); - const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); - std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (model_builder.CreateMLProgram()) { + using namespace CoreML::Specification::MILSpec; - const bool is_1d_conv = (weight_shape.size() == 3); + // https://github.com/apple/coremltools/blob/7.1/coremltools/converters/mil/mil/ops/defs/iOS15/conv.py - if (is_1d_conv) { - // weight_shape needs to be expanded from MXCXH->MXCXHx1 - weight_shape.push_back(1); - } + std::unique_ptr conv_op = model_builder.CreateOperation(node, "conv"); - NodeAttrHelper helper(node); - auto strides = helper.Get("strides", std::vector{1, 1}); - auto dilations = helper.Get("dilations", std::vector{1, 1}); - auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); - // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 - // to meet the required length 2 (for 2d conv it's normally 2) - // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. - if (is_1d_conv) { - if (strides.size() < 2) { - ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); - strides.push_back(1); + AddOperationInput(*conv_op, "x", input_name); + AddOperationInput(*conv_op, "weight", input_defs[1]->Name()); + + if (input_defs.size() > 2) { + AddOperationInput(*conv_op, "bias", input_defs[2]->Name()); } - if (dilations.size() < 2) { - ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); - dilations.push_back(1); + + // ONNX attributes. Add as inputs if specified/required + auto strides = helper.GetInt64s("strides"); + auto dilations = helper.GetInt64s("dilations"); + auto groups = helper.GetInt64("group"); + + // we know this input has a valid shape due to the check in IsOpSupportedImpl. ignore N and C dims. + const auto num_spatial_dims = input_defs[1]->Shape()->dim_size() - 2; + const auto& op_type = conv_op->type(); + + if (strides) { + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", *strides)); + } else { + // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) + static const auto default_value = std::vector(num_spatial_dims, 1); + AddOperationInput(*conv_op, "strides", model_builder.AddConstant(op_type, "strides", default_value)); } - if (onnx_pads.size() < 4) { - ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); - onnx_pads.insert(onnx_pads.begin() + 1, 0); - onnx_pads.push_back(0); + + if (dilations) { + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", *dilations)); + } else { + // spec says optional. testing suggests otherwise for at least the iOS15 target (CoreML5) + static const auto default_value = std::vector(num_spatial_dims, 1); + AddOperationInput(*conv_op, "dilations", model_builder.AddConstant(op_type, "dilations", default_value)); } - } - const auto group = helper.Get("group", static_cast(1)); - - auto* coreml_conv = layer->mutable_convolution(); - - std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); - - if (is_1d_conv) { - const auto expand_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_expand")); - std::unique_ptr expand_layer = CreateNNLayer(expand_layer_name); - // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case - // we need to add an additional dimension here to the input to make it "2d Conv" like. - // NxCxH -> NxCxHx1 - expand_layer->mutable_expanddims()->add_axes(-1); - *expand_layer->mutable_input()->Add() = input_name; - *expand_layer->mutable_output()->Add() = expand_output_name; - model_builder.AddLayer(std::move(expand_layer)); - } - coreml_conv->set_outputchannels(weight_shape[0]); // M - coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group - coreml_conv->add_kernelsize(weight_shape[2]); // H - coreml_conv->add_kernelsize(weight_shape[3]); // W - coreml_conv->set_ngroups(group); - *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; - *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; - - coreml_conv->set_isdeconvolution(false); - - // Add Padding - // Usually using autopadding is more efficient than using explicit padding - // Try to see if we can map explicit padding to auto padding - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - AutoPadType auto_pad_type; - ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - onnx_pads, strides, dilations, - StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), - auto_pad_type)); - - if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { - auto* padding_type = coreml_conv->mutable_same(); - if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER - padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + + if (groups) { + AddOperationInput(*conv_op, "groups", model_builder.AddScalarConstant(op_type, "groups", *groups)); } - } else { - auto* padding_type = coreml_conv->mutable_valid(); - if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { - // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts - auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - height_border->set_startedgesize(onnx_pads[0]); - height_border->set_endedgesize(onnx_pads[2]); - auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); - width_border->set_startedgesize(onnx_pads[1]); - width_border->set_endedgesize(onnx_pads[3]); + + AutoPadType auto_pad_type = StringToAutoPadType(helper.Get("auto_pad", "NOTSET")); + + // pad type (string) + // valid - no pads (ONNX auto_pad VALID) + // custom - pads input (ONNX NOTSET) + // same - inferred to be `d_out[i] = ceil(d_in[i] / strides[i])` (assuming == ONNX SAME_UPPER) + // same_lower - as per same but any extra rows/cols are added at top/left if padding is odd (ONNX SAME_LOWER) + // + // TODO: See if we want to update HandleAutoPad to support 1D (and 3D) so we can infer if an autopad value + // can be used. TBD if that provides any performance benefit with ML Program though as CoreML could + // potentially do that for us. + switch (auto_pad_type) { + case AutoPadType::NOTSET: { + // use `pads` attribute. + auto onnx_pads = helper.GetInt64s("pads"); // 'pads' must be provided if auto_pad is NOTSET + if (onnx_pads) { + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("custom"))); + + // need to re-order from x1_start, x2_start..., x1_end, x2_end... to + // x1_start, x1_end, x2_start, x2_end,... + size_t num_pads = onnx_pads->size(); + size_t num_dims = num_pads / 2; + std::vector reordered_pads(num_pads, 0); + for (size_t i = 0; i < num_pads; ++i) { + auto cur_dim = i % num_dims; + if (i < num_dims) { // start values + reordered_pads[cur_dim * 2] = (*onnx_pads)[i]; + } else { // end values + reordered_pads[cur_dim * 2 + 1] = (*onnx_pads)[i]; + } + } + + AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", reordered_pads)); + + break; + } + + // in theory the pads may not be provided and in that case the default is no padding. + // as that is the same as 'valid', fall through + [[fallthrough]]; + } + case AutoPadType::VALID: + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string("valid"))); + + break; + case AutoPadType::SAME_UPPER: + case AutoPadType::SAME_LOWER: { + const auto pad_type = (auto_pad_type == AutoPadType::SAME_UPPER ? "same" : "same_lower"); + AddOperationInput(*conv_op, "pad_type", + model_builder.AddScalarConstant(op_type, "pad_type", std::string(pad_type))); + + // despite what the spec says, a 'pad' input seems to be required. + // https://github.com/apple/coremltools/issues/2127 + // provide the default value. passing in an empty vector also works. TBD what's better. + std::vector ignored_pads(num_spatial_dims * 2, 0); + AddOperationInput(*conv_op, "pad", model_builder.AddConstant(op_type, "pad", ignored_pads)); + + break; + } } - } - // Add weight - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + // set output + AddOperationOutput(*conv_op, *node.OutputDefs()[0]); + + model_builder.AddOperation(std::move(conv_op)); + } else +#endif // defined(COREML_ENABLE_MLPROGRAM) + { + std::unique_ptr layer = model_builder.CreateNNLayer(node); + + auto strides = helper.Get("strides", std::vector{1, 1}); + auto dilations = helper.Get("dilations", std::vector{1, 1}); + auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const auto group = helper.Get("group", static_cast(1)); + + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + + const auto& weight_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); + std::vector weight_shape = {weight_tensor.dims().cbegin(), weight_tensor.dims().cend()}; + + const bool is_1d_conv = (weight_shape.size() == 3); + + // add dummy 'W' dim with value of 1 so we can use 2D conv. + if (is_1d_conv) { + input_shape.push_back(1); + weight_shape.push_back(1); + + // Strides/dilations for 1d conv is normally of length 1. Expand them by 1 + // to meet the required length 2 (for 2d conv it's normally 2) + if (strides.size() < 2) { + ORT_RETURN_IF_NOT(strides.size() == 1, "strides size does not equal 1 for Conv 1d"); + strides.push_back(1); + } + + if (dilations.size() < 2) { + ORT_RETURN_IF_NOT(dilations.size() == 1, "dilations size does not equal 1 for Conv 1d"); + dilations.push_back(1); + } + + // Similarly 1d conv normally has a length 2 padding. Expand it to length 4 by adding additional zeros. + if (onnx_pads.size() < 4) { + ORT_RETURN_IF_NOT(onnx_pads.size() == 2, "onnx_pads size does not equal 2 for Conv 1d"); + onnx_pads.insert(onnx_pads.begin() + 1, 0); + onnx_pads.push_back(0); + } + } - // Add bias if present - if (input_defs.size() > 2) { - coreml_conv->set_hasbias(true); - const auto& bias_tensor = *model_builder.GetInitializerTensors().at(input_defs[2]->Name()); - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); - } + auto* coreml_conv = layer->mutable_convolution(); - if (is_1d_conv) { - std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); - *layer->mutable_input()->Add() = expand_output_name; - *layer->mutable_output()->Add() = conv_output_name; - model_builder.AddLayer(std::move(layer)); - - // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, - // we need to squeeze it back from NxCxHx1->NxCxH. - const auto squeeze_layer_name = model_builder.GetUniqueName(MakeString(node.Name(), "_Conv_squeeze")); - std::unique_ptr squeeze_layer = CreateNNLayer(squeeze_layer_name); - squeeze_layer->mutable_squeeze()->add_axes(-1); - *squeeze_layer->mutable_input()->Add() = conv_output_name; - *squeeze_layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(squeeze_layer)); - } else { - *layer->mutable_input()->Add() = input_name; - *layer->mutable_output()->Add() = output_name; - model_builder.AddLayer(std::move(layer)); + std::string expand_output_name = model_builder.GetUniqueName(node.Name() + "_expandDims"); + + if (is_1d_conv) { + // Add an expanddims layer here. CoreML only supports 2d convolution, so for 1d Conv case + // we need to add an additional dimension here to the input to make it "2d Conv" like. + // NxCxH -> NxCxHx1 + auto expand_layer = model_builder.CreateNNLayer(node, "_Conv_expand"); + expand_layer->mutable_expanddims()->add_axes(-1); + *expand_layer->mutable_input()->Add() = input_name; + *expand_layer->mutable_output()->Add() = expand_output_name; + model_builder.AddLayer(std::move(expand_layer)); + } + + coreml_conv->set_outputchannels(weight_shape[0]); // M + coreml_conv->set_kernelchannels(weight_shape[1]); // C/Group + coreml_conv->add_kernelsize(weight_shape[2]); // H + coreml_conv->add_kernelsize(weight_shape[3]); // W + coreml_conv->set_ngroups(group); + *coreml_conv->mutable_stride() = {strides.cbegin(), strides.cend()}; + *coreml_conv->mutable_dilationfactor() = {dilations.cbegin(), dilations.cend()}; + + coreml_conv->set_isdeconvolution(false); + + // Add Padding + // Usually using autopadding is more efficient than using explicit padding + // Try to see if we can map explicit padding to auto padding + AutoPadType auto_pad_type; + ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], + onnx_pads, strides, dilations, + StringToAutoPadType(helper.Get("auto_pad", "NOTSET")), + auto_pad_type)); + + if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { + auto* padding_type = coreml_conv->mutable_same(); + if (AutoPadType::SAME_LOWER == auto_pad_type) { // default is SAME_UPPER + padding_type->set_asymmetrymode(COREML_SPEC::SamePadding_SamePaddingMode_TOP_LEFT_HEAVY); + } + } else { + auto* padding_type = coreml_conv->mutable_valid(); + if (AutoPadType::NOTSET == auto_pad_type && onnx_pads != std::vector{0, 0, 0, 0}) { + // NOTSET is adding the explicit padding to the ValidPadding.paddingAmounts + auto* height_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + height_border->set_startedgesize(onnx_pads[0]); + height_border->set_endedgesize(onnx_pads[2]); + auto* width_border = padding_type->mutable_paddingamounts()->add_borderamounts(); + width_border->set_startedgesize(onnx_pads[1]); + width_border->set_endedgesize(onnx_pads[3]); + } + } + + // Add weight + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_weights(), weight_tensor)); + + // Add bias if present + if (input_defs.size() > 2) { + coreml_conv->set_hasbias(true); + const auto& bias_tensor = *model_builder.GetConstantInitializer(input_defs[2]->Name()); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*coreml_conv->mutable_bias(), bias_tensor)); + } + + if (is_1d_conv) { + std::string conv_output_name = model_builder.GetUniqueName(node.Name() + "_conv_output"); + *layer->mutable_input()->Add() = expand_output_name; + *layer->mutable_output()->Add() = conv_output_name; + model_builder.AddLayer(std::move(layer)); + + // Add a squeeze layer here. Since CoreML only supports 2d conv and we expanded the dimension by 1 before, + // we need to squeeze it back from NxCxHx1->NxCxH. + auto squeeze_layer = model_builder.CreateNNLayer(node, "_Conv_squeeze"); + squeeze_layer->mutable_squeeze()->add_axes(-1); + *squeeze_layer->mutable_input()->Add() = conv_output_name; + *squeeze_layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(squeeze_layer)); + } else { + *layer->mutable_input()->Add() = input_name; + *layer->mutable_output()->Add() = output_name; + model_builder.AddLayer(std::move(layer)); + } } return Status::OK(); } -#endif - -// Operator support related bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -186,23 +297,73 @@ bool ConvOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara const auto& input_defs = node.InputDefs(); const auto& weight_name = input_defs[1]->Name(); - const auto& initializers = input_params.graph_viewer.GetAllInitializedTensors(); - if (Contains(initializers, weight_name)) { - const auto& tensor = *initializers.at(weight_name); - if (tensor.dims().size() != 4 && tensor.dims().size() != 3) { - LOGS(logger, VERBOSE) << "Conv [" << name << "] dimension: " << tensor.dims().size() - << " Only conv 2d and conv 1d are supported."; + const auto* weight = input_params.graph_viewer.GetConstantInitializer(weight_name, true); + +#if defined(COREML_ENABLE_MLPROGRAM) + if (input_params.create_mlprogram) { + // ML Program supports non-const weight, 1D, 2D and 3D. + // keep to 1D and 2D for consistency with the NeuralNetwork implementation for now. + // add 3D support as/when needed. + } else +#endif // defined (COREML_ENABLE_MLPROGRAM) + { + if (!weight) { + LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be a constant initializer"; return false; } - } else { - LOGS(logger, VERBOSE) << "The weight of Conv [" << name << "] must be known"; + } + + // use the weight for the shape as it should always be known + const auto* weight_shape = input_defs[1]->Shape(); + int64_t num_dims = weight_shape ? weight_shape->dim_size() : -1; + + // ONNX spec requires N and C as first 2 dims + if (num_dims != 3 && num_dims != 4) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] is " << num_dims - 2 << "D. " + << "Only 1D and 2D Conv are supported currently."; return false; } - if (input_defs.size() > 2) { - const auto& bias_name = input_defs[2]->Name(); - if (!Contains(initializers, bias_name)) { - LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + if (input_defs.size() > 2 && !input_params.graph_viewer.GetConstantInitializer(input_defs[2]->Name(), true)) { + LOGS(logger, VERBOSE) << "The bias of Conv [" << name << "] must be a constant initializer"; + return false; + } + + NodeAttrHelper helper(node); + +#if defined(COREML_ENABLE_MLPROGRAM) + // spec says same_lower is supported in CoreML 5. it lies. CoreML 6 is required otherwise you get + // `Unexpected value for parameter pad_type[0] "same_lower" not in ("custom", "same", "valid").` + // We _could_ manually calculate the pads, but not implementing that until we have a real use case to justify + // the effort as it's not clear how common usage of same_lower is. + if (input_params.create_mlprogram && input_params.coreml_version < 6) { + if (StringToAutoPadType(helper.Get("auto_pad", "NOTSET")) == AutoPadType::SAME_LOWER) { + LOGS(logger, VERBOSE) << "Pad type of SAME_LOWER [" << name << "] is not supported until CoreML 6." + << "Available version is CoreML " << input_params.coreml_version; + return false; + } + } +#endif + + // there's no equivalent to allow a manual kernel shape in CoreML. + // it's OK if a specified kernel_shape matches kH and kW dims of the weight input. + auto kernel_shape = helper.GetInt64s("kernel_shape"); + if (kernel_shape) { + bool valid = true; + if (static_cast(kernel_shape->size()) == num_dims - 2) { + for (int i = 0; i < num_dims - 2; ++i) { + // check the specified kernel shape matches the weight shape. skip the initial N and C dims in the latter. + if ((*kernel_shape)[i] != weight_shape->dim()[i + 2].dim_value()) { + valid = false; + break; + } + } + } else { + valid = false; + } + + if (!valid) { + LOGS(logger, VERBOSE) << "Conv [" << name << "] kernel_shape attribute does not match the weight shape"; return false; } } diff --git a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc index a4ad1c31b5027..1eba312b2577b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/depthtospace_op_builder.cc @@ -4,37 +4,26 @@ #include "core/common/safeint.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class DepthToSpaceOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& output_defs = node.OutputDefs(); @@ -54,9 +43,6 @@ Status DepthToSpaceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool DepthToSpaceOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc index b303fe7884cb1..f0adb70587bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/flatten_op_builder.cc @@ -3,39 +3,26 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class FlattenOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + std::unique_ptr layer = model_builder.CreateNNLayer(node); // Note: ONNX Flatten corresponds to CoreML FlattenTo2DLayerParams auto* coreml_flatten = layer->mutable_flattento2d(); @@ -51,9 +38,6 @@ Status FlattenOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool FlattenOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc index 9c7ec306ca093..7d32675e3e510 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gather_op_builder.cc @@ -2,34 +2,24 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class GatherOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) namespace { int64_t GetAxisAttribute(const Node& node) { NodeAttrHelper node_attr_helper{node}; @@ -38,8 +28,8 @@ int64_t GetAxisAttribute(const Node& node) { } // namespace Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_gather()->set_axis(GetAxisAttribute(node)); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); // data *layer->mutable_input()->Add() = node.InputDefs()[1]->Name(); // indices @@ -47,10 +37,9 @@ Status GatherOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc index 71b08db6d44d8..48f77354d7c30 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/gemm_op_builder.cc @@ -7,38 +7,25 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class GemmOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& /* node */, const OpBuilderInputParams& /* input_params */, const logging::Logger& /* logger */) const override; }; -// Add operator related - -#ifdef __APPLE__ void GemmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& op = node.OpType(); const auto& input_defs(node.InputDefs()); @@ -71,7 +58,7 @@ static Status GetTensorFloatDataTransposed(const ONNX_NAMESPACE::TensorProto& te Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); @@ -120,9 +107,6 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool GemmOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc index ba12600e8bc40..99d6f01cb8c5b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pad_op_builder.cc @@ -7,30 +7,20 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PadOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -64,9 +54,6 @@ static InlinedVector GetPaddingAxesData(const InitializedTensorSet& ini return axes_tensor_data; } -// Add operator related - -#ifdef __APPLE__ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // pads model_builder.AddInitializerToSkip(node.InputDefs()[2]->Name()); // constant_value @@ -78,7 +65,7 @@ void PadOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pad = layer->mutable_padding(); auto* constant_padding_type = coreml_pad->mutable_constant(); // CoreML::Specification::PaddingLayerParams_PaddingConstant @@ -122,9 +109,6 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related bool PadOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc index fd1c77c851e6f..01aced739b36d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/pool_op_builder.cc @@ -4,38 +4,27 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/impl/builder_utils.h" -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class PoolOpBuilder : public BaseOpBuilder { - // Add operator related - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_pool = layer->mutable_pooling(); const auto& op_type = node.OpType(); @@ -108,9 +97,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif -// Operator support related bool PoolOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc index 6a2014e7952a2..32378b1f654d8 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reduction_op_builder.cc @@ -1,36 +1,27 @@ // Copyright (c) Shukant Pal. // Licensed under the MIT License. +#include "core/optimizer/initializer.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "core/optimizer/initializer.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class ReductionOpBuilder : public BaseOpBuilder { -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - private: + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -#ifdef __APPLE__ namespace { template void AddReductionParams(T* params, const std::vector& axes, bool keepdims, bool noop_with_empty_axes) { @@ -76,7 +67,7 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co const bool keepdims = helper.Get("keepdims", 1) != 0; const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "ReduceSum") { AddReductionParams(layer->mutable_reducesum(), axes, keepdims, noop_with_empty_axes); @@ -93,7 +84,6 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif bool ReductionOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { @@ -124,4 +114,4 @@ void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc index 67aee73630cdb..7ae1746be3122 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/reshape_op_builder.cc @@ -6,31 +6,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ReshapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -38,9 +28,6 @@ class ReshapeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 5; } }; -// Add operator related - -#ifdef __APPLE__ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); } @@ -48,7 +35,7 @@ void ReshapeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); @@ -69,9 +56,6 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ReshapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc index 5f963dc30dd8f..35dcde41a6bcf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/resize_op_builder.cc @@ -8,31 +8,21 @@ #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/reshape_helper.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class ResizeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -41,7 +31,7 @@ class ResizeOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } }; -// Helper functions +namespace { bool GetResizeScales(const InitializedTensorSet& initializers, const Node& node, std::vector& scales, const logging::Logger&) { @@ -73,10 +63,8 @@ bool GetResizeOutputSizes(const InitializedTensorSet& initializers, sizes = std::vector(sizes_data.begin(), sizes_data.end()); return true; } +} // namespace -// Add operator related - -#ifdef __APPLE__ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { // We don't really use ROI here, so add it to skipped list if it's an initializer tensor model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // ROI @@ -96,7 +84,7 @@ void ResizeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const N Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_upsample = layer->mutable_upsample(); NodeAttrHelper helper(node); @@ -131,9 +119,6 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool ResizeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc index fd64153ffd283..a86e3d9538d87 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/shape_op_builder.cc @@ -2,44 +2,30 @@ // Licensed under the MIT License. #include "core/providers/coreml/builders/impl/base_op_builder.h" - +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/shared/utils/utils.h" // for NodeAttrHelper -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class ShapeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related -#if defined(__APPLE__) Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, - const logging::Logger& logger) const { - auto layer = CreateNNLayer(model_builder, node); + const logging::Logger& /*logger*/) const { + auto layer = model_builder.CreateNNLayer(node); layer->mutable_getshape(); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif // defined(__APPLE__) -// Operator support related bool ShapeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, const logging::Logger& logger) const { NodeAttrHelper node_attr_helper{node}; diff --git a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc index 2c250b3cc9f5a..b716af738e1b1 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/slice_op_builder.cc @@ -1,39 +1,31 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/cpu/tensor/slice_helper.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime::coreml { class SliceOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: int GetMinSupportedOpSet(const Node& /* node */) const override { // Before Slice-10, some inputs were attributes instead. We don't support that for now. return 10; } - bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& input_params, + const logging::Logger& logger) const override; + bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& builder_params, const logging::Logger& logger) const override; }; @@ -107,9 +99,6 @@ bool ValidateSliceComputeMetadataForCoreML(const SliceOp::PrepareForComputeMetad } } // namespace -// Add operator related -#if defined(__APPLE__) - void SliceOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -132,7 +121,7 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const ORT_RETURN_IF_ERROR(PrepareSliceComputeMetadataFromConstantInitializers(node, model_builder.GetGraphViewer(), compute_metadata)); - auto layer = CreateNNLayer(model_builder, node); + auto layer = model_builder.CreateNNLayer(node); *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); *layer->mutable_output()->Add() = node.OutputDefs()[0]->Name(); auto* slice_static = layer->mutable_slicestatic(); @@ -163,10 +152,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -#endif // defined(__APPLE__) - -// Operator support related -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const { +bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const OpBuilderInputParams& /*input_params*/, + const logging::Logger& logger) const { int32_t input_type; if (!GetType(*node.InputDefs()[0], input_type, logger)) return false; diff --git a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc index c454a2a779f6e..266396a0fe90e 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/softmax_op_builder.cc @@ -1,43 +1,29 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/coreml/shape_utils.h" -#include "core/providers/shared/utils/utils.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" +#include "core/providers/coreml/shape_utils.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace coreml { class SoftmaxOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ - Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); const auto& input_name = node.InputDefs()[0]->Name(); const auto& output_name = node.OutputDefs()[0]->Name(); @@ -68,9 +54,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto reshape1_output_name = model_builder.GetUniqueName(MakeString(node.Name(), "reshape1_output")); { // Add reshape layer - const auto softmax_reshape1_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape1")); - auto reshape_layer = CreateNNLayer(softmax_reshape1_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape1"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {target_shape.cbegin(), target_shape.cend()}; *reshape_layer->mutable_input()->Add() = input_name; *reshape_layer->mutable_output()->Add() = reshape1_output_name; @@ -86,9 +70,7 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } { // Add reshape back layer - const auto softmax_reshape2_layer_name = - model_builder.GetUniqueName(MakeString(node.Name(), "_Softmax_reshape2")); - auto reshape_layer = CreateNNLayer(softmax_reshape2_layer_name); + auto reshape_layer = model_builder.CreateNNLayer(node, "_Softmax_reshape2"); *reshape_layer->mutable_reshapestatic()->mutable_targetshape() = {data_shape.cbegin(), data_shape.cend()}; *reshape_layer->mutable_input()->Add() = softmax_output_name; *reshape_layer->mutable_output()->Add() = output_name; @@ -99,10 +81,6 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SoftmaxOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& /* input_params */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 56c87c883156b..0497357c45c54 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -1,35 +1,24 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/coreml/builders/impl/base_op_builder.h" - #include "core/optimizer/initializer.h" #include "core/providers/common.h" #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#if defined(__APPLE__) -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class SplitOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; @@ -37,10 +26,6 @@ class SplitOpBuilder : public BaseOpBuilder { int GetMinSupportedOpSet(const Node& /* node */) const override { return 13; } }; -// Add operator related - -#ifdef __APPLE__ - void SplitOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { const auto& input_defs = node.InputDefs(); @@ -63,7 +48,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // attribute introduced since opset 18 uint64_t num_outputs; - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_splitnd = layer->mutable_splitnd(); coreml_splitnd->set_axis(axis); @@ -82,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.GetInt("num_outputs").value()); + num_outputs = narrow(helper.GetInt64("num_outputs").value()); // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); @@ -111,10 +96,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -#endif - -// Operator support related - bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -159,7 +140,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } } else { if (node.SinceVersion() >= 18) { - const auto num_outputs = helper.GetInt("num_outputs"); + const auto num_outputs = helper.GetInt64("num_outputs"); if (!num_outputs.has_value()) { LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; @@ -169,9 +150,10 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); return false; } - if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { - LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." - << "The value should be smaller or equal to the size of dimension being split. num_outputs: " + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || + num_outputs.value() > split_dims_at_axis) { + LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n. The value should be smaller or equal to the size " + "of dimension being split. num_outputs: " << num_outputs.value(); return false; } diff --git a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc index 2e14c85ce69c1..e9cc1c2dbf638 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/squeeze_op_builder.cc @@ -1,48 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include + +#include "core/common/safeint.h" #include "core/framework/tensorprotoutils.h" #include "core/providers/common.h" -#include "core/providers/shared/utils/utils.h" -#include "core/optimizer/initializer.h" - -#ifdef __APPLE__ +#include "core/providers/coreml/builders/impl/base_op_builder.h" #include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/op_builder_factory.h" - -#include "base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/optimizer/initializer.h" namespace onnxruntime { namespace coreml { class SqueezeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - public: void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif - // Operator support related - private: bool IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const override; }; -// Add operator related - -#ifdef __APPLE__ -void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { - if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { - model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); - } -} - -/* static */ Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { +namespace { +Status GetAxes(ModelBuilder& model_builder, const Node& node, std::vector& axes) { // Squeeze opset 13 use input as axes if (node.SinceVersion() > 12) { // If axes is not provided, return an empty axes as default to squeeze all @@ -62,11 +44,18 @@ void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const return Status::OK(); } +} // namespace + +void SqueezeOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + if (node.SinceVersion() > 12 && node.InputDefs().size() > 1) { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); + } +} Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); auto* coreml_squeeze = layer->mutable_squeeze(); std::vector axes; @@ -84,9 +73,6 @@ Status SqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related bool SqueezeOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& /*logger*/) const { diff --git a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc index 7d5018a19f74c..f6a61d55a3d63 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/transpose_op_builder.cc @@ -3,33 +3,23 @@ #include "core/providers/coreml/builders/helper.h" #include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/shape_utils.h" #include "core/providers/shared/utils/utils.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif - namespace onnxruntime { namespace coreml { class TransposeOpBuilder : public BaseOpBuilder { - // Add operator related -#ifdef __APPLE__ - private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -// Add operator related - -#ifdef __APPLE__ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); NodeAttrHelper helper(node); std::vector perm = helper.Get("perm", std::vector()); @@ -51,7 +41,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc index 660755b43c043..3403378d59114 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/unary_op_builder.cc @@ -3,32 +3,25 @@ #include "core/providers/common.h" -#ifdef __APPLE__ -#include "core/providers/coreml/builders/model_builder.h" -#endif #include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/impl/base_op_builder.h" +#include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/builders/op_builder_factory.h" -#include "base_op_builder.h" - namespace onnxruntime { namespace coreml { class UnaryOpBuilder : public BaseOpBuilder { - private: -#ifdef __APPLE__ Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override; -#endif }; -#ifdef __APPLE__ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { const auto& op_type(node.OpType()); const auto& input_defs(node.InputDefs()); - std::unique_ptr layer = CreateNNLayer(model_builder, node); + std::unique_ptr layer = model_builder.CreateNNLayer(node); if (op_type == "Sqrt") { layer->mutable_unary()->set_type(COREML_SPEC::UnaryFunctionLayerParams::SQRT); @@ -45,9 +38,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const model_builder.AddLayer(std::move(layer)); return Status::OK(); } -#endif - -// Operator support related void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); @@ -55,4 +45,4 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op } } // namespace coreml -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 9c8b7bce507e4..daab36f7b933d 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -2,56 +2,555 @@ // Licensed under the MIT License. #include -#include - -#include "model_builder.h" -#include "helper.h" -#include "op_builder_factory.h" +#include "core/common/safeint.h" +#include "core/framework/tensorprotoutils.h" +#include "core/platform/env.h" #include "core/providers/common.h" +#include "core/providers/coreml/builders/model_builder.h" +#include "core/providers/coreml/builders/helper.h" +#include "core/providers/coreml/builders/op_builder_factory.h" #include "core/providers/coreml/builders/impl/builder_utils.h" +#include "core/providers/coreml/coreml_provider_factory.h" #include "core/providers/coreml/model/host_utils.h" -#include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" +#if defined(COREML_ENABLE_MLPROGRAM) +// includes from coremltools-src in _deps +#include "modelpackage/src/ModelPackage.hpp" +#include "mlmodel/src/MILBlob/Blob/StorageWriter.hpp" +using MILBlob::Blob::StorageWriter; +#endif + +using namespace CoreML::Specification; + namespace onnxruntime { namespace coreml { -ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags) - : graph_viewer_(graph_viewer), - logger_(logger), - coreml_flags_(coreml_flags) { +namespace { +#if defined(COREML_ENABLE_MLPROGRAM) +// Should the initializer be written to file or kept as an immediate value +bool ShouldWriteInitializerToWeightsFile(const ONNX_NAMESPACE::TensorProto& tensor_proto) { + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/load.py#L51-L57 + + bool use_weight_file = false; + + switch (tensor_proto.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + auto num_elements = TensorShape(utils::GetTensorShapeFromTensorProto(tensor_proto)).Size(); + use_weight_file = num_elements >= 10; + break; + } + default: + break; + } + + return use_weight_file; +} + +// copy from the ONNX TensorProto to a CoreML field. +// T1 is the source type. T2 is the target type. If the types differ, T1 must be smaller than T2. +// e.g. uint32_t data can be written to RepeatedField +template +void CopyRawDataToRepeatedField(const ONNX_NAMESPACE::TensorProto& tensor_proto, + google::protobuf::RepeatedField& repeated_field) { + const auto& raw_data = tensor_proto.raw_data(); + const T1* data = reinterpret_cast(raw_data.data()); + const T1* data_end = data + (raw_data.size() / sizeof(T1)); + if constexpr (sizeof(T1) == sizeof(T2)) { + repeated_field.Add(data, data_end); + } else { + static_assert(sizeof(T1) < sizeof(T2)); + // we need to iterate over the data and copy to the repeated field, converting to T2 as we go. + repeated_field.Resize(data_end - data, T2(0)); + for (int i = 0; data != data_end; ++data, ++i) { + repeated_field[i] = static_cast(*data); + } + } +} + +// copy T data from the TensorProto.int32_t field to TensorValue.bytes +template +void CopyInt32DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.int32_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// copy T data from the TensorProto.uint64_data field to TensorValue.bytes +template +void CopyUInt64DataToBytes(const ONNX_NAMESPACE::TensorProto& tensor_proto, MILSpec::TensorValue tensor_value) { + const int num_entries = tensor_proto.uint64_data_size(); + std::string& bytes = *tensor_value.mutable_bytes()->mutable_values(); + bytes.resize(num_entries * sizeof(T)); + T* out = reinterpret_cast(bytes.data()); + + const uint64_t* in = tensor_proto.uint64_data().data(); + for (int i = 0; i < num_entries; ++i) { + out[i] = static_cast(in[i]); + } +} + +// NOTE: This supports all the ONNX data types. Weights in CoreML may not need all these +void CopyOnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILSpec::TensorValue& tensor_value) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + // handling based on + // ONNX TensorProto field usage + // https://github.com/onnx/onnx/blob/b86cc54efce19530fb953e4b21f57e6b3888534c/onnx/onnx.proto#L544-L572 + // CoreMLTools conversion implementation that maps data types to fields + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L98 + // along with some special cased types that are stored in bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L23 + // IMMEDIATE_VALUE_TYPES_IN_BYTES = (types.fp16, types.int8, types.uint8, types.uint32) + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_floats()->mutable_values()); + } else { + tensor_value.mutable_floats()->mutable_values()->CopyFrom(tensor_proto.float_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: { + // from: double_data/raw, to: doubles + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_doubles()->mutable_values()); + } else { + tensor_value.mutable_doubles()->mutable_values()->CopyFrom(tensor_proto.double_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT32: { + // from: int32_data/raw, to: ints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT64: { + // from: int64_data/raw, to: longints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + + } else { + tensor_value.mutable_longints()->mutable_values()->CopyFrom(tensor_proto.int64_data()); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // iterate the int32_data, taking the 16-bits from each entry, and copying to the bytes. + // we use uint16_t as only the size of the data type matters + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy from int32_data to bytes. uint8_t for both as only the size of the data type matters when copying + CopyInt32DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: { + // from: uint64_data/raw, to: bytes + if (has_raw_data) { + *tensor_value.mutable_bytes()->mutable_values() = tensor_proto.raw_data(); + } else { + // copy uint32_t values from TensorProto.uint64_data + CopyUInt64DataToBytes(tensor_proto, tensor_value); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: { + // from: uint64_data/raw, to: longints + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_longints()->mutable_values()); + } else { + // TODO: Is this safe? Need to check the CopyFrom implementation. As it's a straight copy of bytes this + // hopefully can do it as one block instead of iterating and potentially doing a static_cast of each + // individual value. + tensor_value.mutable_longints()->mutable_values()->CopyFrom( + reinterpret_cast&>(tensor_proto.uint64_data())); + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: { + // from: int32_data/raw, to: bools + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_bools()->mutable_values()); + } else { + const auto& int32s = tensor_proto.int32_data(); + auto& bools = *tensor_value.mutable_bools()->mutable_values(); + const int num_entries = int32s.size(); + bools.Reserve(num_entries); + const int32_t* in = int32s.data(); + for (int i = 0; i < num_entries; ++i) { + *bools.AddAlreadyReserved() = *in++; + } + } + + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_STRING: { + // from: string_data (which is protobuf type bytes), to: strings (protobuf type string) + // due to the protobuf type mismatch we need to iterate and copy + auto& in = tensor_proto.string_data(); + auto& out = *tensor_value.mutable_strings()->mutable_values(); + out.Reserve(in.size()); + for (const auto& iter : in) { + *out.Add() = iter; + } + + break; + } + /* Not clear if there's an actual use-case for 16-bit int data currently, so leaving commented out + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + // from: int32_data/raw, to: ints + // WARNING: This may change to write to mutable_bytes + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L113-L115 + if (has_raw_data) { + CopyRawDataToRepeatedField(tensor_proto, *tensor_value.mutable_ints()->mutable_values()); + } else { + tensor_value.mutable_ints()->mutable_values()->CopyFrom(tensor_proto.int32_data()); + } + break; + } */ + default: + ORT_THROW("AddTensorProtoDataToMILSpecTensorValue: Unsupported data type: ", data_type); + } +} + +template +uint64_t WriteRawDataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + MILBlob::Util::Span data(reinterpret_cast(tensor_proto.raw_data().data()), + tensor_proto.raw_data().size() / sizeof(T)); + return writer.WriteData(data); +} + +// Write T1 data from the TensorProto.int32_data field using StorageWriter. +// Currently int32_data can have any of these data types: +// INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, +// FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ +// T1 provides the size of the ONNX data type. T2 is the CoreML type. +// The sizes and layout of T1 and T2 must match as we simply cast the bytes to T2. +template +uint64_t WriteFromInt32DataUsingStorageWriter(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + static_assert(sizeof(T1) == sizeof(T2), "Data sizes must match"); + + // need to copy to temporary data as we have to extract a subset of bytes from each int32_t entry. + // works better to extract the ONNX type first with static_cast, and reinterpret_cast to the CoreML type at the end. + std::vector values; + const int num_values = tensor_proto.int32_data_size(); + values.resize(num_values); // resize so we're not updating the length inside the copy loop + + const int32_t* in = tensor_proto.int32_data().data(); + for (int i = 0; i < num_values; ++i) { + values[i] = static_cast(in[i]); + } + + MILBlob::Util::Span data(reinterpret_cast(values.data()), + num_values); + return writer.WriteData(data); +} + +// write the initializer to weight.bin and return the offset +// StorageWriter is currently limited to fp32, fp16, bfloat16, uint8/int8, uint16/int16. +// AFAIK we don't use bfloat16/int16/uint16 for weights in ONNX, so limit handling to fp32, fp16, uint8/int8 +uint64_t CopyOnnxTensorToCoreMLWeightsFile(const onnx::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& writer) { + bool has_raw_data = tensor_proto.has_raw_data(); + auto data_type = tensor_proto.data_type(); + + uint64_t offset = 0; + + // See AddTensorProtoDataToMILSpecTensorValue for links to sources for info on where the different typed data is + // stored for ONNX and CoreML + + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + // from: float_data/raw, to: floats + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + MILBlob::Util::Span data(tensor_proto.float_data().data(), tensor_proto.float_data().size()); + offset = writer.WriteData(data); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + + break; + } + + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + // from: int32_data/raw, to: bytes + if (has_raw_data) { + offset = WriteRawDataUsingStorageWriter(tensor_proto, writer); + + } else { + offset = WriteFromInt32DataUsingStorageWriter(tensor_proto, writer); + } + break; + } + default: + ORT_THROW("AddWeightToFile: Unsupported data type: ", data_type); + } + + return offset; +} + +MILSpec::Value OnnxTensorToCoreMLTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + MILBlob::Blob::StorageWriter& weights_file_writer) { + MILSpec::Value value; + + // populate ValueType with tensor data type, dims and rank + MILSpec::ValueType& value_type = *value.mutable_type(); + MILSpec::TensorType& tensor_type = *value_type.mutable_tensortype(); + tensor_type.set_datatype(OnnxDataTypeToMILSpec(tensor_proto.data_type())); + + tensor_type.set_rank(tensor_proto.dims().size()); + for (const auto& dim : tensor_proto.dims()) { + tensor_type.add_dimensions()->mutable_constant()->set_size(dim); + } + + // add data to either weights.bin or as an immediate value + if (ShouldWriteInitializerToWeightsFile(tensor_proto)) { + uint64_t offset = CopyOnnxTensorToCoreMLWeightsFile(tensor_proto, weights_file_writer); + + auto* file_value = value.mutable_blobfilevalue(); + // Filename copied from + // https://github.com/apple/coremltools/blob/dbb0094fd0cb936469e35320bf37e866ef7a1da4/coremltools/converters/mil/backend/mil/helper.py#L329 + file_value->set_filename("@model_path/weights/weight.bin"); + file_value->set_offset(offset); + } else { + MILSpec::TensorValue& tensor_value = *value.mutable_immediatevalue()->mutable_tensor(); + CopyOnnxTensorToCoreMLTensor(tensor_proto, tensor_value); + } + + return value; +} + +void CreateEmptyFile(const std::string& filename) { + std::ofstream file(filename, std::ofstream::out | std::ofstream::binary); + ORT_ENFORCE(file.is_open(), "Failed to open file ", filename); } -Status ModelBuilder::Initialize() { - coreml_model_ = std::make_unique(); - { // initialize CoreML model +#endif // defined(COREML_ENABLE_MLPROGRAM) + +std::string GetModelOutputPath(bool create_ml_program) { + // path is used to create the ML Package directory for ML Program, and for the model directly otherwise. + auto path = util::GetTemporaryFilePath(); + if (!create_ml_program) { + path += ".model.mlmodel"; + } + + return path; +} +} // namespace + +ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags) + : graph_viewer_(graph_viewer), + logger_(logger), + coreml_version_(coreml_version), + coreml_flags_(coreml_flags), + create_ml_program_((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0), + model_output_path_(GetModelOutputPath(create_ml_program_)), + coreml_model_(std::make_unique()) { + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + coreml_model_->set_specificationversion(CoreMLSpecVersion()); + MILSpec::Program& mlprogram = *coreml_model_->mutable_mlprogram(); + MILSpec::Function& main = (*mlprogram.mutable_functions())["main"]; + + const std::string coreml_opset = "CoreML" + std::to_string(CoreMLVersion()); + *main.mutable_opset() = coreml_opset; + mlprogram_main_ = &(*main.mutable_block_specializations())[coreml_opset]; + + // create the ModelPackage. this creates the output directory. + mlpackage_ = std::make_unique(model_output_path_, /* create */ true); + + // ModelPackage::addItem does a copy of the file. Due to this we 'add' an empty file first, + // and do the actual writes to the file created in the package. + // We can't use ModelPackage::createFile as we have to add a directory for the weights. + std::string tmp_dir = model_output_path_ + "/tmp"; + ORT_THROW_IF_ERROR(Env::Default().CreateFolder(ToPathString(tmp_dir))); + CreateEmptyFile(tmp_dir + "/weight.bin"); + + std::string weights_id = mlpackage_->addItem(tmp_dir, "weights", "com.microsoft.OnnxRuntime", + "CoreML Model Weights"); + auto weights_info = mlpackage_->findItem(weights_id); + weights_file_writer_ = std::make_unique(weights_info->path() + "/weight.bin"); +#else + // should never happen due to handling in coreml_execution_provider.cc + ORT_THROW("ML Program is not enabled in this build"); +#endif + } else { // We support CorelML Specification Version 4 (Core ML 3) coreml_model_->set_specificationversion(4); auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->set_arrayinputshapemapping(::CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); + neural_network->set_arrayinputshapemapping( + CoreML::Specification::NeuralNetworkMultiArrayShapeMapping::EXACT_ARRAY_MAPPING); } +} - PreprocessInitializers(); - ORT_RETURN_IF_ERROR(RegisterInitializers()); - ORT_RETURN_IF_ERROR(RegisterModelInputs()); - ORT_RETURN_IF_ERROR(AddOperations()); - ORT_RETURN_IF_ERROR(RegisterModelOutputs()); +ModelBuilder::~ModelBuilder() = default; - return Status::OK(); +/* + * NeuralNetwork related helpers + */ +std::unique_ptr ModelBuilder::CreateNNLayer(const Node& node, std::string_view suffix) { + auto layer_name = GetUniqueName(node, suffix); + + std::unique_ptr layer = std::make_unique(); + layer->set_name(layer_name); + return layer; +} + +void ModelBuilder::AddLayer(std::unique_ptr layer) { + auto* neural_network = coreml_model_->mutable_neuralnetwork(); + neural_network->mutable_layers()->AddAllocated(layer.release()); } -/* static */ const IOpBuilder* ModelBuilder::GetOpBuilder(const Node& node) { - const auto& op_builders = GetOpBuilders(); - const auto it = op_builders.find(node.OpType()); - if (it != op_builders.cend()) - return it->second; +#if defined(COREML_ENABLE_MLPROGRAM) + +/* + * ML Program related helpers + */ +std::unique_ptr ModelBuilder::CreateOperation(const Node& node, + std::string_view op_type, + std::string_view suffix) { + std::string operation_name = GetUniqueName(node, suffix); + + std::unique_ptr op = std::make_unique(); + op->set_type(std::string(op_type)); + (*op->mutable_attributes())["name"] = CreateScalarTensorValue(operation_name); + + return op; +} + +void ModelBuilder::AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer) { + MILSpec::Value coreml_tensor = OnnxTensorToCoreMLTensor(initializer, *weights_file_writer_); + AddConstantOperation(name, std::move(coreml_tensor)); +} + +void ModelBuilder::AddConstantOperation(std::string_view name, MILSpec::Value&& coreml_tensor) { + // Replicates coremltools/converters/mil/backend/mil/load.py translate_const logic + MILSpec::Operation& const_op = *mlprogram_main_->mutable_operations()->Add(); + const_op.set_type("const"); + + MILSpec::NamedValueType& output = *const_op.mutable_outputs()->Add(); + output.set_name(std::string(name)); + *output.mutable_type() = coreml_tensor.type(); + + auto& attr_map = *const_op.mutable_attributes(); + attr_map["name"] = CreateScalarTensorValue(std::string(name)); + attr_map["val"] = std::move(coreml_tensor); +} + +// Add operation to the Block for the main function in the ML Program +void ModelBuilder::AddOperation(std::unique_ptr operation) { + mlprogram_main_->mutable_operations()->AddAllocated(operation.release()); +} + +std::string ModelBuilder::AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + MILSpec::Value&& input_value) { + auto unique_value_name = GetUniqueName(MakeString(op_type, "_", value_type)); + AddConstantOperation(unique_value_name, std::move(input_value)); + return unique_value_name; +} + +template +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape) { + // add specialization below + static_assert(false_for_T, "Missing specialization for value type"); + return ""; // unreachable +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); // CoreML uses int32 + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} + +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); +} - return nullptr; +template <> +std::string ModelBuilder::AddConstantImpl(std::string_view op_type, std::string_view value_type, + gsl::span value, + std::optional> shape) { + auto input_value = CreateTensorValue(value, shape); + return AddTensorValueAsConstantOperation(op_type, value_type, std::move(input_value)); } +#endif // defined(COREML_ENABLE_MLPROGRAM) + +/* + * General implementation + */ void ModelBuilder::PreprocessInitializers() { - // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places + // TODO: We should be using GetConstantInitializer not GetAllInitializedTensors in all places. + // non-constant initializers need to be passed in as model inputs in case they're overridden at runtime. const auto& initializers = graph_viewer_.GetAllInitializedTensors(); const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); @@ -64,6 +563,7 @@ void ModelBuilder::PreprocessInitializers() { initializer_usage_[input->Name()]++; } } + if (const auto* op_builder = GetOpBuilder(node)) { op_builder->AddInitializersToSkip(*this, node); } @@ -77,27 +577,34 @@ Status ModelBuilder::RegisterInitializers() { // skip initializer if there is no remaining usage auto usage_count = initializer_usage_[name]; - if (usage_count == 0) + if (usage_count == 0) { continue; + } - std::unique_ptr layer = std::make_unique(); - layer->set_name(GetUniqueName("initializer_" + name)); - - // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer - auto* constant_tensor = layer->mutable_loadconstantnd(); - const auto& shape = tensor.dims(); - if (shape.empty()) { - // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor - constant_tensor->mutable_shape()->Add(1); + if (create_ml_program_) { +#if defined(COREML_ENABLE_MLPROGRAM) + AddConstant(name, tensor); +#endif } else { - std::transform(shape.cbegin(), shape.cend(), - google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), - [](int64_t dim) -> uint64_t { return SafeInt(dim); }); - } + std::unique_ptr layer = std::make_unique(); + layer->set_name(GetUniqueName("initializer_" + name)); + + // TODO,look at using LoadConstantLayer instead of LoadConstantNDLayer + auto* constant_tensor = layer->mutable_loadconstantnd(); + const auto& shape = tensor.dims(); + if (shape.empty()) { + // This is a scalar initializer, CoreML constant layer requires a shape, make this a {1} tensor + constant_tensor->mutable_shape()->Add(1); + } else { + std::transform(shape.cbegin(), shape.cend(), + google::protobuf::RepeatedFieldBackInserter(constant_tensor->mutable_shape()), + [](int64_t dim) -> uint64_t { return SafeInt(dim); }); + } - ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); - *layer->mutable_output()->Add() = name; - AddLayer(std::move(layer)); + ORT_RETURN_IF_ERROR(CreateCoreMLWeight(*constant_tensor->mutable_data(), tensor)); + *layer->mutable_output()->Add() = name; + AddLayer(std::move(layer)); + } } return Status::OK(); @@ -179,15 +686,15 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i data_type = type_proto->tensor_type().elem_type(); switch (data_type) { case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::FLOAT32); + multi_array->set_datatype(ArrayFeatureType::FLOAT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); break; case ONNX_NAMESPACE::TensorProto_DataType_INT64: // If we have an int64 input/output type, since COREML_SPEC:ArrayFeatureType does not support INT64 // we assign it to be INT32 here - multi_array->set_datatype(COREML_SPEC::ArrayFeatureType::INT32); + multi_array->set_datatype(ArrayFeatureType::INT32); if (!is_input) { // Record the output names and we need to change them back to Int64 when CoreML EP returns these values to ORT AddInt64Output(name); @@ -204,6 +711,19 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape}); +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + MILSpec::Function& main = (*coreml_model_->mutable_mlprogram()->mutable_functions())["main"]; + if (is_input) { + // the model inputs need to be wired up as args to the 'main' function + main.mutable_inputs()->Add(CreateNamedTensorValueType(node_arg)); + } else { + // the model outputs need to be set as outputs of the Block for the 'main' function + *mlprogram_main_->mutable_outputs()->Add() = node_arg.Name(); + } + } +#endif // defined(COREML_ENABLE_MLPROGRAM) + return Status::OK(); } @@ -215,16 +735,16 @@ Status ModelBuilder::RegisterModelInputs() { return Status::OK(); } -Status ModelBuilder::AddOperations() { - const auto builder_params = MakeOpBuilderParams(graph_viewer_, coreml_flags_); - const auto& node_indices = graph_viewer_.GetNodesInTopologicalOrder(); - for (size_t i = 0; i < node_indices.size(); i++) { - const auto* node(graph_viewer_.GetNode(node_indices[i])); - if (const auto* op_builder = GetOpBuilder(*node)) { - ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, *node, builder_params, logger_)); +Status ModelBuilder::ProcessNodes() { + for (const auto node_idx : graph_viewer_.GetNodesInTopologicalOrder()) { + const auto& node = *graph_viewer_.GetNode(node_idx); + if (const auto* op_builder = GetOpBuilder(node)) { + ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(*this, node, logger_)); } else { + // This shouldn't happen as this is called from CoreMLExecutionProvider::Compile and should only be processing + // nodes that we said were supported and were returned from CoreMLExecutionProvider::GetCapability. return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Node [", node->Name(), "], type [", node->OpType(), "] is not supported"); + "Node [", node.Name(), "], type [", node.OpType(), "] is not supported"); } } @@ -239,29 +759,72 @@ Status ModelBuilder::RegisterModelOutputs() { return Status::OK(); } -Status ModelBuilder::Compile(std::unique_ptr& model, const std::string& path) { - ORT_RETURN_IF_ERROR(SaveCoreMLModel(path)); - model.reset(new Model(path, logger_, coreml_flags_)); - model->SetScalarOutputs(std::move(scalar_outputs_)); - model->SetInt64Outputs(std::move(int64_outputs_)); - model->SetInputOutputInfo(std::move(input_output_info_)); - return model->LoadModel(); +Status ModelBuilder::CreateModel() { + PreprocessInitializers(); + + ORT_RETURN_IF_ERROR(RegisterInitializers()); + ORT_RETURN_IF_ERROR(RegisterModelInputs()); + ORT_RETURN_IF_ERROR(ProcessNodes()); + ORT_RETURN_IF_ERROR(RegisterModelOutputs()); + + return Status::OK(); } -Status ModelBuilder::SaveCoreMLModel(const std::string& path) { - ORT_RETURN_IF_ERROR(Initialize()); - std::ofstream stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Save the CoreML model failed"); +Status ModelBuilder::SaveModel() { + std::string output_path = model_output_path_; + +#if defined(COREML_ENABLE_MLPROGRAM) + if (create_ml_program_) { + std::string tmp_model_path = model_output_path_ + "/tmp/model.mlmodel"; + CreateEmptyFile(tmp_model_path); + + std::string model_id = mlpackage_->setRootModel(tmp_model_path, "model.mlmodel", "com.microsoft.OnnxRuntime", + "CoreML Model Specification"); + auto model_info = mlpackage_->findItem(model_id); + output_path = model_info->path(); + } +#endif - // TODO, Delete, debug only - if (const char* path = std::getenv("ORT_COREML_EP_CONVERTED_MODEL_PATH")) { - std::ofstream temp_stream(path, std::ofstream::out | std::ofstream::binary); - ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&temp_stream), "Save the CoreML model failed"); + // scope this so the stream is closed and flushed by the ofstream dtor + { + LOGS(logger_, INFO) << "Writing CoreML Model to " << output_path; + std::ofstream stream(output_path, std::ofstream::out | std::ofstream::binary); + ORT_RETURN_IF_NOT(coreml_model_->SerializeToOstream(&stream), "Saving the CoreML model failed. Path=", output_path); } +#if defined(COREML_ENABLE_MLPROGRAM) + // need to delete the ModelPackage instance for it to write out the manifest. clear out the other ML Program + // related types as well. + mlprogram_main_ = nullptr; + mlpackage_.reset(); + weights_file_writer_.reset(); +#endif + return Status::OK(); } +Status ModelBuilder::LoadModel(std::unique_ptr& model) { + model = std::make_unique(model_output_path_, + std::move(input_output_info_), + std::move(scalar_outputs_), + std::move(int64_outputs_), + logger_, coreml_flags_); + + return model->LoadModel(); // load using CoreML API, including compilation +} + +// static +Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::unique_ptr& model) { + ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_flags); + + ORT_RETURN_IF_ERROR(builder.CreateModel()); + ORT_RETURN_IF_ERROR(builder.SaveModel()); + + return builder.LoadModel(model); +} + void ModelBuilder::AddScalarOutput(const std::string& output_name) { scalar_outputs_.insert(output_name); } @@ -270,11 +833,6 @@ void ModelBuilder::AddInt64Output(const std::string& output_name) { int64_outputs_.insert(output_name); } -void ModelBuilder::AddLayer(std::unique_ptr layer) { - auto* neural_network = coreml_model_->mutable_neuralnetwork(); - neural_network->mutable_layers()->AddAllocated(layer.release()); -} - void ModelBuilder::AddInitializerToSkip(const std::string& tensor_name) { // decrement usage count if this is a known initializer. // For simplicity the OpBuilder::AddInitializersToSkip implementations may call this for arbitrary input names @@ -289,7 +847,7 @@ void ModelBuilder::AddInputToSkip(const std::string& input_name) { skipped_inputs_.insert(input_name); } -std::string ModelBuilder::GetUniqueName(const std::string& base_name) { +std::string ModelBuilder::GetUniqueName(std::string_view base_name) { std::string unique_name; do { std::ostringstream os; @@ -300,5 +858,12 @@ std::string ModelBuilder::GetUniqueName(const std::string& base_name) { return unique_name; } +std::string ModelBuilder::GetUniqueName(const Node& node, std::string_view suffix) { + if (node.Name().empty()) { + return GetUniqueName(MakeString("Node_", node.Index(), "_", node.OpType(), suffix)); + } else { + return GetUniqueName(node.Name() + std::string(suffix)); + } +} } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.h b/onnxruntime/core/providers/coreml/builders/model_builder.h index af2d5437be8d1..961ba647257b5 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.h +++ b/onnxruntime/core/providers/coreml/builders/model_builder.h @@ -3,57 +3,171 @@ #pragma once +#include "core/common/span_utils.h" #include "core/graph/graph_viewer.h" #include "core/providers/coreml/builders/coreml_spec.h" +#include "core/providers/coreml/model/model.h" + +#if defined(COREML_ENABLE_MLPROGRAM) +// coremltools classes +namespace MPL { +class ModelPackage; +} + +namespace MILBlob { +namespace Blob { +class StorageWriter; +} +} // namespace MILBlob +#endif namespace onnxruntime { namespace coreml { class IOpBuilder; class Model; -struct OnnxTensorInfo; class ModelBuilder { + private: + ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags); + public: - ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, uint32_t coreml_flags); - ~ModelBuilder() = default; + // Create the CoreML model, serialize to disk, load and compile using the CoreML API and return in `model` + static Status Build(const GraphViewer& graph_viewer, const logging::Logger& logger, + int32_t coreml_version, uint32_t coreml_flags, + std::unique_ptr& model); - Status Compile(std::unique_ptr& model, const std::string& path); - Status SaveCoreMLModel(const std::string& path); + ~ModelBuilder(); - // Accessors for members const GraphViewer& GetGraphViewer() const { return graph_viewer_; } const InitializedTensorSet& GetInitializerTensors() const { return graph_viewer_.GetAllInitializedTensors(); } - + const ONNX_NAMESPACE::TensorProto* GetConstantInitializer(const std::string& name) const { + return graph_viewer_.GetConstantInitializer(name, true); + } + + // Since CoreML 2 the spec version is +1 as CoreML 1.1 was spec version 2. + // We only support CoreML 3 and later so the spec version is always version + 1. + int32_t CoreMLVersion() const { return coreml_version_; } + int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; } + + // Returns true if we are creating an ML Program + bool CreateMLProgram() const { +#if defined(COREML_ENABLE_MLPROGRAM) + return create_ml_program_; +#else + return false; +#endif + } + + /* + * NeuralNetworkLayer helpers + */ + + // Create a NeuralNetwork layer using the node name and optional suffix for the name. + // If Node has no name a unique name will be generated from the node index and operator. + std::unique_ptr CreateNNLayer(const Node& node, std::string_view suffix = ""); + + // Add layer to the Core ML NeuralNetwork model void AddLayer(std::unique_ptr layer); - // The initializer will be processed separately, skip it as an initializer +#if defined(COREML_ENABLE_MLPROGRAM) + /* + * MLProgram helpers + */ + + // Create Operation, set type and the unique name attribute. + std::unique_ptr CreateOperation(const Node& node, std::string_view op_type, + std::string_view suffix = ""); + + // + // Helpers for adding attributes from ONNX nodes as inputs to an ML Program Operation + // + + /// + /// Add a value as a 'const' operation, generating a unique name for the value from op_type and value_type. + /// Use for values that were not initializers in the original ONNX model. e.g. attributes from ONNX nodes. + /// Add existing initializers using AddConstant with the TensorProto. + /// + /// e.g. adding the bias input of Gemm would have op_type='gemm' and value_type='bias'. + /// + /// Value type. + /// Typically MILSpec::Operation.type(). + /// Typically the input name of the operation that will consume the value. + /// Value to add. + /// Optional shape for the value. + /// If T is a primitive type `shape` is ignored and the value is treated as a scalar. + /// For a container type, if `shape` is not provided the shape is inferred to be 1-D of {value.size()}. + /// + /// Unique name generated for value. + template + std::string AddConstant(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt) { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v, + // add specialization in AddConstantImpl for new types if needed + "AddConstant currently supports float, int64_t, std::string and bool."); + return AddConstantImpl(op_type, value_type, value, shape); + } + + template + std::string AddConstant(std::string_view op_type, std::string_view value_type, const std::vector& value, + std::optional> shape = std::nullopt) { + return AddConstant(op_type, value_type, AsSpan(value), shape); + } + + /// + /// Add a scalar value as a 'const' operation. See AddConstant for details. + /// + template + std::string AddScalarConstant(std::string_view op_type, std::string_view value_type, const T& value) { + return AddConstant(op_type, value_type, AsSpan({value}), AsSpan({})); + } + + /// + /// Add an existing a constant ONNX initializer to the ML Program as a 'const' operation + /// + /// Initializer name + /// Initializer data + void AddConstant(std::string_view name, const ONNX_NAMESPACE::TensorProto& initializer); + + // add the operation to the main function + void AddOperation(std::unique_ptr operation); +#endif + + /* + * General helpers + */ + + // The initializer is processed separately (e.g. layout is transformed) by the operator builder, + // so we don't do a copy of the original initializer into the model. void AddInitializerToSkip(const std::string& tensor_name); // There are some input which will not be used, add it to a list which will not // be added to CoreML model, since CoreML does not like input unused void AddInputToSkip(const std::string& input_name); - std::string GetUniqueName(const std::string& base_name); + std::string GetUniqueName(std::string_view base_name); + std::string GetUniqueName(const Node& node, std::string_view suffix); private: - const GraphViewer& graph_viewer_; - const logging::Logger& logger_; - uint32_t coreml_flags_; - - std::unique_ptr coreml_model_; - std::unordered_set scalar_outputs_; - std::unordered_set int64_outputs_; - std::unordered_map input_output_info_; - - std::unordered_map initializer_usage_; - std::unordered_set skipped_inputs_; - - uint32_t name_token_{0}; - std::unordered_set unique_names_; - - // Convert the onnx model to CoreML::Specification::Model - Status Initialize(); +#if defined(COREML_ENABLE_MLPROGRAM) + template + std::string AddConstantImpl(std::string_view op_type, std::string_view value_type, gsl::span value, + std::optional> shape = std::nullopt); + + void AddConstantOperation(std::string_view name, COREML_SPEC::MILSpec::Value&& initializer); + std::string AddTensorValueAsConstantOperation(std::string_view op_type, std::string_view value_type, + COREML_SPEC::MILSpec::Value&& input_value); +#endif + + // Convert the ONNX model in graph_viewer_ to a CoreML::Specification::Model and serialize to disk. + // We then load it using CoreML in order compile it. + Status CreateModel(); + Status SaveModel(); + Status LoadModel(std::unique_ptr& model); // If a CoreML operation will use initializers directly, we will add the initializers to the skip list void PreprocessInitializers(); @@ -61,7 +175,7 @@ class ModelBuilder { // Copy and process all the initializers to CoreML model Status RegisterInitializers(); - Status AddOperations(); + Status ProcessNodes(); Status RegisterModelInputs(); Status RegisterModelOutputs(); Status RegisterModelInputOutput(const NodeArg& node_arg, bool is_input); @@ -72,7 +186,32 @@ class ModelBuilder { // Record the onnx int64 type output names void AddInt64Output(const std::string& output_name); - static const IOpBuilder* GetOpBuilder(const Node& node); + const GraphViewer& graph_viewer_; + const logging::Logger& logger_; + const int32_t coreml_version_; + const uint32_t coreml_flags_; + const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old) + const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel + + std::unique_ptr coreml_model_; + std::unordered_set scalar_outputs_; + std::unordered_set int64_outputs_; + std::unordered_map input_output_info_; + + std::unordered_map initializer_usage_; + std::unordered_set skipped_inputs_; + + uint32_t name_token_{0}; + std::unordered_set unique_names_; + +#if defined(COREML_ENABLE_MLPROGRAM) + // mlprogram_main_ is the main block of the CoreML ML Program. + // It is set in CreateModel to the CoreML Model.mlprogram.functions['main'].block_specializations['CoreML'] + // entry we create. + COREML_SPEC::MILSpec::Block* mlprogram_main_{nullptr}; + std::unique_ptr mlpackage_; + std::unique_ptr weights_file_writer_; +#endif }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder.h b/onnxruntime/core/providers/coreml/builders/op_builder.h index 79de6438c9700..0bb7f280c33e6 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder.h @@ -11,36 +11,39 @@ namespace coreml { class ModelBuilder; struct OpBuilderInputParams { - OpBuilderInputParams(const GraphViewer& graph_viewer, bool only_allow_static_input_shapes) + OpBuilderInputParams(const GraphViewer& graph_viewer, + int32_t coreml_version, + bool only_allow_static_input_shapes, + bool create_mlprogram) : graph_viewer(graph_viewer), - only_allow_static_input_shapes(only_allow_static_input_shapes) {} + coreml_version(coreml_version), + only_allow_static_input_shapes(only_allow_static_input_shapes), + create_mlprogram(create_mlprogram) {} const GraphViewer& graph_viewer; + const int32_t coreml_version; // required to determine which version of an operation can be used. const bool only_allow_static_input_shapes; + const bool create_mlprogram; // whether to create ML Program (Core ML 5+) or NeuralNetwork (Core ML 3+) }; class IOpBuilder { public: virtual ~IOpBuilder() = default; - // Add operator related -#ifdef __APPLE__ - public: // Check if the initializers of this operator need preprocess // which will not be copied virtual void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const = 0; // Add the operator to CoreML model virtual Status AddToModelBuilder(ModelBuilder& model_builder, const Node& node, - const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; -#endif - // Operator support related - public: // Check if an operator is supported virtual bool IsOpSupported(const Node& node, const OpBuilderInputParams& input_params, const logging::Logger& logger) const = 0; + + // Does the builder implementation support creating an ML Program? + virtual bool SupportsMLProgram() const = 0; }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h index d72420bcfff88..6469b4cefa5ea 100644 --- a/onnxruntime/core/providers/coreml/builders/op_builder_factory.h +++ b/onnxruntime/core/providers/coreml/builders/op_builder_factory.h @@ -3,7 +3,7 @@ #pragma once -#include "op_builder.h" +#include "core/providers/coreml/builders/op_builder.h" namespace onnxruntime { namespace coreml { diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index c133f7b82aba4..8e718da07703c 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -2,9 +2,11 @@ // Licensed under the MIT License. #include "core/providers/coreml/coreml_execution_provider.h" +#include "core/providers/coreml/coreml_provider_factory.h" // defines flags #include +#include "core/common/logging/logging.h" #include "core/framework/compute_capability.h" #include "core/framework/tensorprotoutils.h" #include "core/graph/graph_viewer.h" @@ -12,12 +14,10 @@ #include "core/providers/partitioning_utils.h" #include "core/session/onnxruntime_cxx_api.h" -#ifdef __APPLE__ #include "core/providers/coreml/builders/model_builder.h" #include "core/providers/coreml/model/host_utils.h" #include "core/providers/coreml/model/model.h" #include "core/providers/coreml/shape_utils.h" -#endif namespace onnxruntime { @@ -25,7 +25,24 @@ constexpr const char* COREML = "CoreML"; CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags) : IExecutionProvider{onnxruntime::kCoreMLExecutionProvider}, - coreml_flags_(coreml_flags) { + coreml_flags_(coreml_flags), + coreml_version_(coreml::util::CoreMLVersion()) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { + LOGS_DEFAULT(ERROR) << "CoreML EP is not supported on this platform."; + } + +#if defined(COREML_ENABLE_MLPROGRAM) + if (coreml_version_ < MINIMUM_COREML_MLPROGRAM_VERSION && + (coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported on this OS version. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#else + if ((coreml_flags_ & COREML_FLAG_CREATE_MLPROGRAM) != 0) { + LOGS_DEFAULT(WARNING) << "ML Program is not supported in this build. Falling back to NeuralNetwork."; + coreml_flags_ ^= COREML_FLAG_CREATE_MLPROGRAM; + } +#endif } CoreMLExecutionProvider::~CoreMLExecutionProvider() {} @@ -35,28 +52,34 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const IKernelLookup& /*kernel_lookup*/) const { std::vector> result; - // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes - // TODO investigate whether we want to support subgraph using CoreML EP - if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + if (coreml_version_ < MINIMUM_COREML_VERSION) { return result; } const auto& logger = *GetLogger(); + // We do not run CoreML EP on subgraph, instead we cover this in the control flow nodes + // TODO investigate whether we want to support subgraph using CoreML EP. May simply require processing the + // implicit inputs of the control flow node that contains the subgraph as inputs to the CoreML model we generate. + if (graph_viewer.IsSubgraph() && !(coreml_flags_ & COREML_FLAG_ENABLE_ON_SUBGRAPH)) { + return result; + } + const bool has_neural_engine = coreml::HasNeuralEngine(logger); if ((coreml_flags_ & COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE) && !has_neural_engine) { - LOGS(logger, VERBOSE) << "The current system does not have Apple Neural Engine"; + LOGS(logger, WARNING) << "The current system does not have Apple Neural Engine. CoreML EP will not be used."; return result; } - const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_flags_); + const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_flags_); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); - const auto gen_metadef_name = [&]() { - HashValue model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(COREML, "_", model_hash, "_", metadef_id); - }; + const auto gen_metadef_name = + [&]() { + HashValue model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(COREML, "_", model_hash, "_", metadef_id); + }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {}, gen_metadef_name, COREML, kCoreMLExecutionProvider); @@ -86,17 +109,16 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return result; } -#ifdef __APPLE__ +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, std::vector& node_compute_funcs) { for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { Node& fused_node = fused_node_and_graph.fused_node; const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); - coreml::ModelBuilder builder(graph_viewer, *GetLogger(), coreml_flags_); std::unique_ptr coreml_model; - const std::string coreml_model_file_path = coreml::util::GetTemporaryFilePath(); - ORT_RETURN_IF_ERROR(builder.Compile(coreml_model, coreml_model_file_path)); + ORT_RETURN_IF_ERROR(coreml::ModelBuilder::Build(graph_viewer, *GetLogger(), coreml_version_, coreml_flags_, + coreml_model)); { const auto& input_defs = fused_node.InputDefs(); @@ -241,22 +263,6 @@ common::Status CoreMLExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { - for (const auto& fused_node_and_graph : fused_nodes_and_graphs) { - ORT_UNUSED_PARAMETER(fused_node_and_graph); - NodeComputeInfo compute_info; - compute_info.create_state_func = [](ComputeContext* /*context*/, FunctionState* /*state*/) { return 0; }; - compute_info.release_state_func = [](FunctionState /*state*/) {}; - compute_info.compute_func = [](FunctionState /* state */, const OrtApi* /* api */, - OrtKernelContext* /* context */) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Compute is not supported in this build."); - }; - node_compute_funcs.push_back(compute_info); - } - return Status::OK(); -} -#endif //__APPLE__ +#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.h b/onnxruntime/core/providers/coreml/coreml_execution_provider.h index 0201739547dd1..24a001280eef5 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.h +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.h @@ -3,9 +3,9 @@ #pragma once +#include "core/common/inlined_containers.h" #include "core/framework/execution_provider.h" #include "core/framework/model_metadef_id_generator.h" -#include "core/providers/coreml/coreml_provider_factory.h" namespace onnxruntime { namespace coreml { @@ -26,15 +26,14 @@ class CoreMLExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs) override; #endif + private: // The bit flags which define bool options for COREML EP, bits are defined as // COREMLFlags in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h - const uint32_t coreml_flags_; - - private: -// > -#ifdef __APPLE__ - std::unordered_map> coreml_models_; -#endif + uint32_t coreml_flags_; + const int32_t coreml_version_; ModelMetadefIdGenerator metadef_id_generator_; + + // map of fused_node_name to compiled_coreml_model + InlinedHashMap> coreml_models_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/host_utils.h b/onnxruntime/core/providers/coreml/model/host_utils.h index f7f45bce087bc..4f9a014c4d885 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.h +++ b/onnxruntime/core/providers/coreml/model/host_utils.h @@ -8,10 +8,50 @@ #include -#define API_AVAILABLE_OS_VERSIONS API_AVAILABLE(macos(10.15), ios(13)) +#if defined(__APPLE__) +// See https://apple.github.io/coremltools/mlmodel/Format/Model.html for the info on each CoreML specification version. +// See https://apple.github.io/coremltools/source/coremltools.converters.mil.mil.ops.defs.html for the list of ops +// in each CoreML specification version. -// Base requireed OS to run CoreML Specification Version 4 (Core ML 3) -#define HAS_VALID_BASE_OS_VERSION @available(macOS 10.15, iOS 13, *) +// Specification Versions : OS Availability(Core ML Version) +// +// 4 : iOS 13, macOS 10.15, tvOS 13, watchOS 6 (Core ML 3) +// - initial version of CoreML EP +// 5 : iOS 14, macOS 11, tvOS 14, watchOS 7 (Core ML 4) +// - additional layers in NeuralNetwork but currently none are implemented by the CoreML EP +// 6 : iOS 15, macOS 12, tvOS 15, watchOS 8 (Core ML 5) +// - adds MLProgram (MILSpec.Program) +// - iOS 15 ops +// 7 : iOS 16, macOS 13, tvOS 16, watchOS 9 (Core ML 6) +// - iOS 16 ops +// 8 : iOS 17, macOS 14, tvOS 17, watchOS 10 (Core ML 7) +// - iOS 17 ops +// +// **NOTE** We use the Core ML version not the spec version. +// +// e.g. iOS 13 has Core ML 3 (which is Core ML Specification version 4), and the related macros are +// API_AVAILABLE_COREML3, HAS_COREML3_OR_LATER and onnxruntime::coreml::util::CoreMLVersion() will return 3. + +// https://developer.apple.com/documentation/swift/marking-api-availability-in-objective-c +// API_AVAILABLE is used to decorate Objective-C APIs +#define API_AVAILABLE_COREML3 API_AVAILABLE(macos(10.15), ios(13)) +#define API_AVAILABLE_COREML4 API_AVAILABLE(macos(11), ios(14)) +#define API_AVAILABLE_COREML5 API_AVAILABLE(macos(12), ios(15)) +#define API_AVAILABLE_COREML6 API_AVAILABLE(macos(13), ios(16)) +#define API_AVAILABLE_COREML7 API_AVAILABLE(macos(14), ios(17)) + +// @available is used in implementation code +// Base required OS to run CoreML Specification Version 4 (Core ML 3) +#define HAS_COREML3_OR_LATER @available(macOS 10.15, iOS 13, *) +#define HAS_COREML4_OR_LATER @available(macOS 11, iOS 14, *) +#define HAS_COREML5_OR_LATER @available(macOS 12, iOS 15, *) +#define HAS_COREML6_OR_LATER @available(macOS 13, iOS 16, *) +#define HAS_COREML7_OR_LATER @available(macOS 14, iOS 17, *) + +#endif + +#define MINIMUM_COREML_VERSION 3 // first version we support +#define MINIMUM_COREML_MLPROGRAM_VERSION 5 // first version where ML Program was available namespace onnxruntime { namespace coreml { @@ -21,6 +61,9 @@ namespace util { // This corresponds to [CoreML Specification Version 4 (Core ML 3)] bool HasRequiredBaseOS(); +// Return the CoreML version if 3 or higher. Otherwise returns -1. +int CoreMLVersion(); + // Get a temporary macOS/iOS temp file path std::string GetTemporaryFilePath(); diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm index 4c394386cd37a..0ae0cf8f0d207 100644 --- a/onnxruntime/core/providers/coreml/model/host_utils.mm +++ b/onnxruntime/core/providers/coreml/model/host_utils.mm @@ -10,19 +10,33 @@ namespace util { bool HasRequiredBaseOS() { - // This may look strange, but it is required "@available(macOS ....)" to safe-guard some code - // otherwise the compiler will spit -Wunsupported-availability-guard - if (HAS_VALID_BASE_OS_VERSION) - return true; - else - return false; + return CoreMLVersion() >= 3; +} + +int32_t CoreMLVersion() { + if (HAS_COREML7_OR_LATER) + return 7; + if (HAS_COREML6_OR_LATER) + return 6; + if (HAS_COREML5_OR_LATER) + return 5; + if (HAS_COREML4_OR_LATER) + return 4; + if (HAS_COREML3_OR_LATER) + return 3; + + return -1; } std::string GetTemporaryFilePath() { - // Get temporary directory. + // Get temporary directory for user. NSURL* temporary_directory_url = [NSURL fileURLWithPath:NSTemporaryDirectory() isDirectory:YES]; // Generate a Unique file name to use. NSString* temporary_filename = [[NSProcessInfo processInfo] globallyUniqueString]; + + // make it easy to see who generated it + temporary_filename = [@"onnxruntime-" stringByAppendingString:temporary_filename]; + // Create URL to that file. NSURL* temporary_file_url = [temporary_directory_url URLByAppendingPathComponent:temporary_filename]; diff --git a/onnxruntime/core/providers/coreml/model/host_utils_stub.cc b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc new file mode 100644 index 0000000000000..5c383b0274e8c --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/host_utils_stub.cc @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "core/platform/env.h" +#include "core/providers/coreml/model/host_utils.h" + +namespace onnxruntime { +namespace coreml { +namespace util { + +bool HasRequiredBaseOS() { + return true; +} + +int CoreMLVersion() { + return 7; // CoreML 7 is the latest we support. +} + +std::string GetTemporaryFilePath() { + static std::atomic counter = 0; + + // we want to avoid creating endless directories/names whilst avoiding clashes if tests run in parallel so cycle + // through 20 potential output names. + auto dir_name = "coreml_ep_test_run." + std::to_string(counter++ % 20); + + // to replicate the iOS/macOS host_utils.mm behavior where the output is / + // we want to return the name of something that does not exist. this is required for ML Package creation. + auto& env = Env::Default(); + if (env.FolderExists(dir_name)) { + ORT_THROW_IF_ERROR(env.DeleteFolder(ToPathString(dir_name))); + } + + return dir_name; +} + +} // namespace util +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model.h b/onnxruntime/core/providers/coreml/model/model.h index 105b6a0333b15..b940c4b768aec 100644 --- a/onnxruntime/core/providers/coreml/model/model.h +++ b/onnxruntime/core/providers/coreml/model/model.h @@ -33,19 +33,29 @@ using GetOutputTensorMutableRawDataFn = std::function static_shape)>; class Model { - friend class ModelBuilder; - public: + Model(const std::string& path, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, uint32_t coreml_flags); + ~Model(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Model); + Status LoadModel(); + Status Predict(const std::unordered_map& inputs, const std::unordered_map& outputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn); - bool IsScalarOutput(const std::string& output_name) const; + bool IsScalarOutput(const std::string& output_name) const { + return Contains(scalar_outputs_, output_name); + } - bool IsInt64Output(const std::string& output_name) const; + bool IsInt64Output(const std::string& output_name) const { + return Contains(int64_outputs_, output_name); + } // Mutex for exclusive lock to this model object OrtMutex& GetMutex() { return mutex_; } @@ -57,35 +67,27 @@ class Model { const std::vector& GetOnnxOutputs() const { return onnx_outputs_; } void SetOnnxOutputs(std::vector&& outputs) { onnx_outputs_ = std::move(outputs); } - const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const; - const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const; + const OnnxTensorInfo* TryGetInputOutputInfo(const std::string& name) const { + const auto info_it = input_output_info_.find(name); + return info_it != input_output_info_.end() ? &info_it->second : nullptr; + } + + const OnnxTensorInfo& GetInputOutputInfo(const std::string& name) const { + const auto* info = TryGetInputOutputInfo(name); + ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); + return *info; + } private: std::unique_ptr execution_; + std::unordered_map input_output_info_; std::unordered_set scalar_outputs_; std::unordered_set int64_outputs_; std::vector onnx_inputs_; std::vector onnx_outputs_; - std::unordered_map input_output_info_; - OrtMutex mutex_; - - Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags); - Status LoadModel(); - - void SetInputOutputInfo(std::unordered_map&& input_output_info) { - input_output_info_ = std::move(input_output_info); - } - - void SetScalarOutputs(std::unordered_set&& scalar_outputs) { - scalar_outputs_ = std::move(scalar_outputs); - } - - void SetInt64Outputs(std::unordered_set&& int64_outputs) { - int64_outputs_ = std::move(int64_outputs); - } }; } // namespace coreml diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm index 155201ad4c39c..d5cd70bff9479 100644 --- a/onnxruntime/core/providers/coreml/model/model.mm +++ b/onnxruntime/core/providers/coreml/model/model.mm @@ -252,14 +252,14 @@ - (instancetype)initWithPath:(const std::string&)path coreml_flags:(uint32_t)coreml_flags; - (void)cleanup; - (void)dealloc; -- (Status)loadModel API_AVAILABLE_OS_VERSIONS; +- (Status)loadModel API_AVAILABLE_COREML3; - (Status)predict:(const std::unordered_map&)inputs outputs:(const std::unordered_map&)outputs getOutputTensorDataFn:(const GetOutputTensorMutableRawDataFn&) get_output_tensor_mutable_raw_data_fn - API_AVAILABLE_OS_VERSIONS; + API_AVAILABLE_COREML3; -@property(nullable) MLModel* model API_AVAILABLE_OS_VERSIONS; +@property(nullable) MLModel* model API_AVAILABLE_COREML3; @end @@ -308,6 +308,10 @@ - (Status)loadModel { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create model URL from path"); } + // TODO: Update this to version with callback handler as the API used here is deprecated. + // https://developer.apple.com/documentation/coreml/mlmodel/3929553-compilemodelaturl + // As we call loadModel during EP Compile there shouldn't be an issue letting the actual compile run in the + // background. We will have to check for completion in `predict` and block until it is done. NSError* error = nil; NSURL* compileUrl = [MLModel compileModelAtURL:modelUrl error:&error]; @@ -454,7 +458,7 @@ Status Predict(const std::unordered_map& inputs, return Status::OK(); } - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { Status status{}; @autoreleasepool { status = [execution_ loadModel]; @@ -471,7 +475,7 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { ORT_RETURN_IF_NOT(model_loaded, "Execution::Predict requires Execution::LoadModel"); - if (HAS_VALID_BASE_OS_VERSION) { + if (HAS_COREML3_OR_LATER) { @autoreleasepool { return [execution_ predict:inputs outputs:outputs @@ -482,8 +486,16 @@ Status Predict(const std::unordered_map& inputs, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Execution::Predict requires macos 10.15+ or ios 13+"); } -Model::Model(const std::string& path, const logging::Logger& logger, uint32_t coreml_flags) - : execution_(std::make_unique(path, logger, coreml_flags)) { +Model::Model(const std::string& path, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& logger, + uint32_t coreml_flags) + : execution_(std::make_unique(path, logger, coreml_flags)), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { } Model::~Model() {} @@ -497,25 +509,5 @@ Status Predict(const std::unordered_map& inputs, const GetOutputTensorMutableRawDataFn& get_output_tensor_mutable_raw_data_fn) { return execution_->Predict(inputs, outputs, get_output_tensor_mutable_raw_data_fn); } - -bool Model::IsScalarOutput(const std::string& output_name) const { - return Contains(scalar_outputs_, output_name); -} - -bool Model::IsInt64Output(const std::string& output_name) const { - return Contains(int64_outputs_, output_name); -} - -const OnnxTensorInfo* Model::TryGetInputOutputInfo(const std::string& name) const { - const auto info_it = input_output_info_.find(name); - return info_it != input_output_info_.end() ? &info_it->second : nullptr; -} - -const OnnxTensorInfo& Model::GetInputOutputInfo(const std::string& name) const { - const auto* info = TryGetInputOutputInfo(name); - ORT_ENFORCE(info != nullptr, "Failed to get info for input/output: ", name); - return *info; -} - } // namespace coreml } // namespace onnxruntime diff --git a/onnxruntime/core/providers/coreml/model/model_stub.cc b/onnxruntime/core/providers/coreml/model/model_stub.cc new file mode 100644 index 0000000000000..087c9f8c05d5f --- /dev/null +++ b/onnxruntime/core/providers/coreml/model/model_stub.cc @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/coreml/model/model.h" + +namespace onnxruntime { +namespace coreml { + +class Execution {}; + +Model::Model(const std::string& /*path*/, + std::unordered_map&& input_output_info, + std::unordered_set&& scalar_outputs, + std::unordered_set&& int64_outputs, + const logging::Logger& /*logger*/, + uint32_t /*coreml_flags*/) + : execution_(std::make_unique()), + input_output_info_(std::move(input_output_info)), + scalar_outputs_(std::move(scalar_outputs)), + int64_outputs_(std::move(int64_outputs)) { +} + +Model::~Model() { +} + +Status Model::LoadModel() { + // return OK so we hit more CoreML EP code. + return Status::OK(); +} + +Status Model::Predict(const std::unordered_map& /*inputs*/, + const std::unordered_map& /*outputs*/, + const GetOutputTensorMutableRawDataFn& /*get_output_tensor_mutable_raw_data_fn*/) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Executing a CoreML model is not supported on this platform."); +} + +} // namespace coreml +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 77e682e05a2a4..48a952e6dd98f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Sqrt); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Sqrt); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Log); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Log); @@ -1882,6 +1883,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 655877f425054..fd8b69d7bd2f5 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -160,7 +160,7 @@ UNARY_OP_CSILHFD(Neg, 13) UNARY_OP_HFD(Floor, 13) UNARY_OP_HFD(Ceil, 13) UNARY_OP_HFD(Reciprocal, 13) -UNARY_OP_HFD(Sqrt, 13) +UNARY_OP_HFDX(Sqrt, 13) UNARY_OP_HFD(Log, 13) UNARY_OP_HFD(Exp, 13) UNARY_OP_HFD(Erf, 13) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 5c3db4a499972..73c5ac80756be 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -83,7 +83,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_CSILHFD(Neg) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index 82f3503919237..a417be5a86c32 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -326,7 +326,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), - CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType())); + CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType(), + UseTF32())); if (context->InputCount() >= 3) { const Tensor* B = context->Input(2); @@ -351,8 +352,13 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) { // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionFwdAlgoPerf_t perf; int algo_count = 1; @@ -399,6 +405,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); if (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; + } else if (std::is_same::value && !UseTF32()) { + perf.mathType = CUDNN_FMA_MATH; } else { perf.mathType = CUDNN_DEFAULT_MATH; } @@ -480,7 +488,8 @@ Status CudnnConvolutionDescriptor::Set( const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type) { + cudnnDataType_t data_type, + bool use_tf32) { if (!desc_) CUDNN_RETURN_IF_ERROR(cudnnCreateConvolutionDescriptor(&desc_)); @@ -513,6 +522,8 @@ Status CudnnConvolutionDescriptor::Set( CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_DEFAULT_MATH)); if (data_type == CUDNN_DATA_HALF) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_TENSOR_OP_MATH)); + } else if (data_type == CUDNN_DATA_FLOAT && !use_tf32) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(desc_, CUDNN_FMA_MATH)); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index bcaa4d855b81e..181fbc99fd8e9 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -29,7 +29,8 @@ class CudnnConvolutionDescriptor final { const gsl::span& dilations, int groups, cudnnConvolutionMode_t mode, - cudnnDataType_t data_type); + cudnnDataType_t data_type, + bool use_tf32); operator cudnnConvolutionDescriptor_t() const { return desc_; } diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc index 55dceaa2698e8..939b9959af818 100644 --- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc @@ -167,7 +167,8 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION; ORT_RETURN_IF_ERROR(s_.conv_desc.Set(p.kernel_shape.size(), p.pads, p.strides, p.dilations, gsl::narrow_cast(conv_transpose_attrs_.group), mode, - CudnnTensor::GetDataType())); + CudnnTensor::GetDataType(), + UseTF32())); if (has_bias) { const auto& b_shape = p.B->Shape(); @@ -187,8 +188,13 @@ Status ConvTranspose::DoConvTranspose(OpKernelContext* context, bool dy GetScratchBuffer(AlgoSearchWorkspaceSize, context->GetComputeStream()); // set math type to tensor core before algorithm search - if constexpr (std::is_same::value) + if constexpr (std::is_same::value) { CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH)); + } else if constexpr (std::is_same::value) { + if (!UseTF32()) { + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_FMA_MATH)); + } + } cudnnConvolutionBwdDataAlgoPerf_t perf; int algo_count = 1; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 30c339b845b36..44004b5d77f70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -43,6 +43,10 @@ class DmlOperatorRotaryEmbedding : public DmlOperator ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 4); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + // When the input is 4D, it has the shape [batchSize, numHeads, sequenceLength, headSize]. Otherwise, + // it has the shape [batchSize, sequenceLength, hiddenSize] + const bool inputIs4D = kernelInfo.GetInputTensorDimensionCount(inputDataIndex) == 4; + // When positionIds is a scalar, it represents the start offset for each sequence const bool positionIdsIsOffset = kernelInfo.GetInputTensorDimensionCount(positionIdsIndex) == 1; @@ -63,9 +67,9 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // We resize the data to be of shape [batchSize, sequenceLength, numHeads, headSize] const auto inputDataSizes = m_inputTensorDescs[inputDataIndex].GetSizes(); - const uint32_t batchSize = inputDataSizes[1]; + const uint32_t batchSize = inputIs4D ? inputDataSizes[0] : inputDataSizes[1]; const uint32_t sequenceLength = inputDataSizes[2]; - const uint32_t numHeads = inputDataSizes[3] / headSize; + const uint32_t numHeads = inputIs4D ? inputDataSizes[1] : inputDataSizes[3] / headSize; const auto cosCacheSizes = m_inputTensorDescs[cosCacheIndex].GetSizes(); const uint32_t maxSequenceLength = cosCacheSizes[cosCacheSizes.size() - 2]; @@ -80,16 +84,24 @@ class DmlOperatorRotaryEmbedding : public DmlOperator std::vector inputDescs = GetDmlInputDescs(); const MLOperatorTensorDataType dataType = kernelInfo.GetInputEdgeDescription(inputDataIndex).tensorDataType; - // Splitting the hiddenSize into numHeads and headSize dimensions makes it easier for DML to handle const std::array inputOutputShape = {batchSize, sequenceLength, numHeads, headSize}; TensorDesc inputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + TensorDesc stridedInputOutputTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputOutputShape); + + if (inputIs4D) + { + const std::array inputOutputStrides = {headSize * numHeads * sequenceLength, headSize, sequenceLength * headSize, 1}; + stridedInputOutputTensorDesc.SetStrides(inputOutputStrides); + } + const DML_TENSOR_DESC inputOutputDmlTensorDesc = inputOutputTensorDesc.GetDmlDesc(); + const DML_TENSOR_DESC stridedInputOutputDmlTensorDesc = stridedInputOutputTensorDesc.GetDmlDesc(); // Copy the input to preserve its real input shape in the graph without reshaping it. This will disappear during DML's graph compilation phase. DML_SCALE_BIAS scaleBias = {1.0f, 0.0f}; DML_ELEMENT_WISE_IDENTITY_OPERATOR_DESC copyInputDesc{}; - copyInputDesc.InputTensor = &inputOutputDmlTensorDesc; + copyInputDesc.InputTensor = &stridedInputOutputDmlTensorDesc; copyInputDesc.OutputTensor = &inputOutputDmlTensorDesc; copyInputDesc.ScaleBias = &scaleBias; const DML_OPERATOR_DESC copyInputDmlDesc = {DML_OPERATOR_ELEMENT_WISE_IDENTITY, ©InputDesc}; @@ -104,8 +116,12 @@ class DmlOperatorRotaryEmbedding : public DmlOperator : std::vector({batchSize, sequenceLength, numHeads, 1, headSize / 2}); TensorDesc inputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC inputDataDmlTensorDesc = inputDataTensorDesc.GetDmlDesc(); + TensorDesc joinedDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, inputDataTensorShape); + const DML_TENSOR_DESC joinedDataDmlTensorDesc = joinedDataTensorDesc.GetDmlDesc(); + TensorDesc splitInputDataTensorDesc = TensorDesc::ConstructDefaultTensorDesc(dataType, splitInputDataTensorShape); const std::array splitInputDataDmlTensorDescs = {splitInputDataTensorDesc.GetDmlDesc(), splitInputDataTensorDesc.GetDmlDesc()}; @@ -122,7 +138,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator // Swap the 2 halves and join them together DML_JOIN_OPERATOR_DESC joinInputDesc{}; joinInputDesc.InputTensors = splitInputDataDmlTensorDescs.data(); - joinInputDesc.OutputTensor = &inputDataDmlTensorDesc; + joinInputDesc.OutputTensor = &joinedDataDmlTensorDesc; joinInputDesc.Axis = splitInputDesc.Axis; joinInputDesc.InputCount = gsl::narrow_cast(splitInputDataDmlTensorDescs.size()); const DML_OPERATOR_DESC joinInputDmlDesc = {DML_OPERATOR_JOIN, &joinInputDesc}; @@ -212,23 +228,23 @@ class DmlOperatorRotaryEmbedding : public DmlOperator const DML_TENSOR_DESC broadcastedSignDmlTensorDesc = broadcastedSignCosSinTensorDesc.GetDmlDesc(); DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulSignDesc{}; - mulSignDesc.ATensor = &inputDataDmlTensorDesc; + mulSignDesc.ATensor = &joinedDataDmlTensorDesc; mulSignDesc.BTensor = &broadcastedSignDmlTensorDesc; - mulSignDesc.OutputTensor = &inputDataDmlTensorDesc; + mulSignDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulSignDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulSignDesc}; // Multiply the non-rotated data with the cos and the rotated data with the sin DML_ELEMENT_WISE_MULTIPLY_OPERATOR_DESC mulCosSinDesc{}; - mulCosSinDesc.ATensor = &inputDataDmlTensorDesc; + mulCosSinDesc.ATensor = &joinedDataDmlTensorDesc; mulCosSinDesc.BTensor = &broadcastedCosSinDmlTensorDesc; - mulCosSinDesc.OutputTensor = &inputDataDmlTensorDesc; + mulCosSinDesc.OutputTensor = &joinedDataDmlTensorDesc; const DML_OPERATOR_DESC mulCosSinDmlDesc = {DML_OPERATOR_ELEMENT_WISE_MULTIPLY, &mulCosSinDesc}; // Add the multiplied cos and sin values together DML_ELEMENT_WISE_ADD_OPERATOR_DESC addDesc{}; addDesc.ATensor = &inputOutputDmlTensorDesc; addDesc.BTensor = &inputOutputDmlTensorDesc; - addDesc.OutputTensor = &inputOutputDmlTensorDesc; + addDesc.OutputTensor = &stridedInputOutputDmlTensorDesc; const DML_OPERATOR_DESC addDmlDesc = {DML_OPERATOR_ELEMENT_WISE_ADD, &addDesc}; // Construct the graph diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc index b2225643b788e..edee298ad1ccf 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/split_op_builder.cc @@ -67,7 +67,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const int32_t num_outputs; if (node_unit.SinceVersion() >= 18) { - num_outputs = SafeInt(*helper.GetInt("num_outputs")); + num_outputs = SafeInt(*helper.GetInt64("num_outputs")); } else { num_outputs = SafeInt(node_unit.Outputs().size()); } @@ -127,7 +127,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const No } else { uint32_t num_outputs; if (node_unit.SinceVersion() >= 18) { - auto num_outputs_attr = helper.GetInt("num_outputs"); + auto num_outputs_attr = helper.GetInt64("num_outputs"); if (!num_outputs_attr.has_value()) { LOGS_DEFAULT(VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index e6c093d584031..0779940983aea 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -70,10 +70,13 @@ BasicBackend::BasicBackend(const ONNX_NAMESPACE::ModelProto& model_proto, LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } #else - if (global_context_.disable_dynamic_shapes && dev_prec != "CPU_FP16") { - const std::string model = model_proto.SerializeAsString(); - exe_network_ = global_context_.ie_core.LoadNetwork( - model, hw_target, device_config, subgraph_context_.subgraph_name); + if (!subgraph_context_.has_dynamic_input_shape && + global_context_.onnx_model_path_name != "" && + dev_prec != "CPU_FP16") { + exe_network_ = global_context_.ie_core.LoadNetwork(global_context_.onnx_model_path_name, + hw_target, + device_config, + subgraph_context_.subgraph_name); LOGS_DEFAULT(INFO) << log_tag << "Loaded model to the plugin"; } else { ie_cnn_network_ = CreateOVModel(model_proto, global_context_, subgraph_context_, const_outputs_map_); diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 931173fd7ef47..ea481791111fc 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -87,13 +87,13 @@ OVExeNetwork OVCore::LoadNetwork(std::shared_ptr& ie_cnn_network, } } -OVExeNetwork OVCore::LoadNetwork(const std::string& model, +OVExeNetwork OVCore::LoadNetwork(const std::string onnx_model_path, std::string& hw_target, ov::AnyMap& device_config, std::string name) { ov::CompiledModel obj; try { - obj = oe.compile_model(model, ov::Tensor(), hw_target, device_config); + obj = oe.compile_model(onnx_model_path, hw_target, device_config); OVExeNetwork exe(obj); return exe; } catch (const Exception& e) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 3db19463809cf..cf4d867d4df55 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -45,7 +45,7 @@ class OVCore { std::string& hw_target, ov::AnyMap& device_config, std::string name); - OVExeNetwork LoadNetwork(const std::string& model_stream, + OVExeNetwork LoadNetwork(const std::string model_path, std::string& hw_target, ov::AnyMap& device_config, std::string name); diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index d95e2baa9457f..4a9106f0c06af 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -94,5 +94,28 @@ void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_r void CreateExpandOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +struct HandleConvertResult { + Status status; // Indicates an unexpected error. Check if q_node_unit != nullptr to determine + // whether a DQ -> Q sequence was successfully merged into a Convert. + const NodeUnit* q_node_unit; // Non-null if successfully merged DQ -> Q sequence. + // Set to nullptr if this node unit could not be merged into a Convert. +}; + +/** + * Tries to merge a DQ -> Q sequence into a QNN Convert operator. The DQ -> Q must be converting from + * one quantization type (e.g., uint8_t) to another (e.g., uint16_t). + * + * \param qnn_model_wrapper The QNN model that is being built. + * \param maybe_dq_node_unit The node unit that could potentially start the DQ -> Q sequence. + * \param logger The logger. + * \param do_op_validation True if should call QNN operator validation APIs. + * \return An qnn::HandleConvertResult object that indicates success/failure and provides a pointer + * to the Q node unit that was successfully merged with the provided DQ node unit. + */ +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc new file mode 100644 index 0000000000000..977a9e0b3d9d0 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/convert_op_builder.cc @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/graph/graph_utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/shared/utils/utils.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/common/safeint.h" +#include "onnx/defs/data_type_utils.h" + +#include "QnnOpDef.h" // From QNN SDK: contains QNN constants (e.g., op names, param values). + +namespace onnxruntime { +namespace qnn { + +class ConvertOpBuilder : public BaseOpBuilder { + public: + ConvertOpBuilder() : BaseOpBuilder("ConvertOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConvertOpBuilder); + + Status AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const ORT_MUST_USE_RESULT; +}; + +Status ConvertOpBuilder::AddConvertToModelBuilder(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& dq_node_unit, + const NodeUnit& q_node_unit, + const logging::Logger& logger, + bool do_op_validation) const { + std::vector input_names; + + // Process the input from the DQ node + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, dq_node_unit.Inputs()[0], logger, input_names)); + + // Process the output from the Q node. Override the QNN operator type to "Convert". + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, q_node_unit, std::move(input_names), {}, + logger, do_op_validation, QNN_OP_CONVERT)); + return Status::OK(); +} + +HandleConvertResult TryHandleConvertSequence(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& maybe_dq_node_unit, + const std::unordered_map& node_unit_map, + const logging::Logger& logger, + bool do_op_validation) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + // Looking for a standalone DQ to start the sequence. + if (maybe_dq_node_unit.OpType() != QDQ::DQOpName || maybe_dq_node_unit.UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + const Node& dq_node = maybe_dq_node_unit.GetNode(); + + // DQ must have a single Q child. DQ must not produce a graph output. + auto children = graph_utils::FindChildrenByType(dq_node, QDQ::QOpName); + if (children.size() != 1 || dq_node.GetOutputEdgesCount() != 1 || graph_viewer.NodeProducesGraphOutput(dq_node)) { + return {}; + } + + const Node& q_node = *children[0]; + const auto q_node_unit_it = node_unit_map.find(&q_node); + + if (q_node_unit_it == node_unit_map.end()) { + return {ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Node does not have a corresponding NodeUnit"), nullptr}; + } + + const NodeUnit* q_node_unit = q_node_unit_it->second; + + // Q child must not already be part of a QDQ NodeUnit (i.e., be standalone). + if (q_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return {}; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + // DQ and Q must have equal scale type and different zp type. + if (!QDQ::IsDQQConversion(dq_node, q_node, get_const_initializer, graph_viewer.ModelPath())) { + return {}; + } + + ConvertOpBuilder op_builder; + + LOGS(logger, VERBOSE) << " Adding QNN Convert. dq_node name: [" << dq_node.Name() + << "] dq_node optype: [" << dq_node.OpType() + << "] q_node name: [" << q_node_unit->Name() + << "] q_node optype: [" << q_node_unit->OpType() + << "]"; + + auto status = op_builder.AddConvertToModelBuilder(qnn_model_wrapper, maybe_dq_node_unit, *q_node_unit, logger, + do_op_validation); + return status.IsOK() ? HandleConvertResult{status, q_node_unit} : HandleConvertResult{status, nullptr}; +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 314cab4a36ca9..dc91b9dfa199e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -114,6 +114,8 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } + std::unordered_set handled_node_units; + // Op builer const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); for (size_t i = 0; i < node_indices.size(); i++) { @@ -122,20 +124,43 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // Check whether it's part of NodeUnit const NodeUnit& node_unit = GetNodeUnit(node, node_unit_map); // Q, DQ nodes in the node unit only carry the quantization parameters - // Add the QNN node when it is the target node (It's a normal node or a singel Q/DQ node) + // Add the QNN node when it is the target node (It's a normal node or a single Q/DQ node) const std::string& op_type = node_unit.OpType(); + + if (node != &node_unit.GetNode()) { + continue; + } + + if (handled_node_units.count(&node_unit) != 0) { + continue; // Already handled. + } + + // Try to convert particular DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + node_unit, + node_unit_map, + logger_, + false /*do_op_validation*/); + ORT_RETURN_IF_ERROR(convert_result.status); + + if (convert_result.q_node_unit) { + // Successfully merged DQ -> Q sequence into a QNN Convert op. + // Mark both of these node units as handled. + handled_node_units.insert(&node_unit); + handled_node_units.insert(convert_result.q_node_unit); + continue; + } + LOGS(logger_, VERBOSE) << " node name: [" << node->Name() << "] node optype: [" << op_type << "] as part of the NodeUnit type: [" << node_unit.OpType() << "] name: [" << node_unit.Name() << "]"; - if (node != &node_unit.GetNode()) { - continue; - } - if (const auto* op_builder = GetOpBuilder(op_type)) { ORT_RETURN_IF_ERROR(op_builder->AddToModelBuilder(qnn_model_wrapper, node_unit, logger_)); } + + handled_node_units.insert(&node_unit); } ORT_RETURN_IF_NOT(qnn_model_wrapper.ComposeQnnGraph(), "Failed to compose Qnn graph."); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index b58f6e10df94c..f5a166d36b15a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -286,33 +286,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const { - // If we have visited one of the nodes in the node_unit, use the result directly - const auto it = node_unit_supported_result.find(&node_unit); - if (it != node_unit_supported_result.cend()) { - return it->second; + const std::string& op_type = node_unit.OpType(); + bool supported = false; + const auto* op_builder = qnn::GetOpBuilder(op_type); + if (op_builder == nullptr) { + LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." + << node_unit.OpType() << " node `" << node_unit.Name() + << "` will not be assigned to QNN EP."; } else { - const std::string& op_type = node_unit.OpType(); - - bool supported = false; - const auto* op_builder = qnn::GetOpBuilder(op_type); - if (op_builder == nullptr) { - LOGS(logger, WARNING) << "Operators of type `" << node_unit.OpType() << "` are not supported by QNN EP." - << node_unit.OpType() << " node `" << node_unit.Name() - << "` will not be assigned to QNN EP."; - } else { - auto status = op_builder->IsOpSupported(qnn_model_wrapper, - node_unit, logger); - if (Status::OK() != status) { - LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() - << "` is not supported: " << status.ErrorMessage(); - } - supported = (Status::OK() == status); + auto status = op_builder->IsOpSupported(qnn_model_wrapper, + node_unit, logger); + if (Status::OK() != status) { + LOGS(logger, WARNING) << node_unit.OpType() << " node `" << node_unit.Name() + << "` is not supported: " << status.ErrorMessage(); } - node_unit_supported_result[&node_unit] = supported; - return supported; + supported = (Status::OK() == status); } + return supported; } std::unordered_set @@ -391,24 +382,51 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, if (node != &node_unit->GetNode()) { continue; } - const bool supported = IsNodeSupported(qnn_model_wrapper, - *node_unit, - node_unit_supported_result, - logger); - LOGS(logger, VERBOSE) << "Node supported: [" << supported - << "] index: [" << node->Index() - << "] name: [" << node->Name() - << "] Operator type: [" << node->OpType() - << "] as part of the NodeUnit type: [" << node_unit->OpType() - << "] index: [" << node_unit->Index() - << "] name: [" << node_unit->Name() - << "]"; + + if (node_unit_supported_result.count(node_unit) != 0) { + continue; // Already handled this node unit + } + + // Try to convert certain standalone DQ -> Q sequences into QNN Convert op + auto convert_result = TryHandleConvertSequence(qnn_model_wrapper, + *node_unit, + node_unit_map, + logger, + true /*do_op_validation*/); + if (!convert_result.status.IsOK()) { + LOGS(logger, WARNING) << "Failed to convert DQ -> Q sequence to QNN Convert. " + << "Type: " << node_unit->OpType() << ", Node name: " << node_unit->Name() << ", " + << "Message: " << convert_result.status.ErrorMessage(); + } + + bool supported = false; + + if (convert_result.status.IsOK() && convert_result.q_node_unit) { // Merged DQ -> Q sequence into QNN Convert op + supported = true; + + // Mark the Q node unit as handled and supported here so that we don't try to process it again. + node_unit_supported_result.insert({convert_result.q_node_unit, true}); + supported_nodes.insert(&convert_result.q_node_unit->GetNode()); + } else { + supported = IsNodeSupported(qnn_model_wrapper, *node_unit, logger); + LOGS(logger, VERBOSE) << "Node supported: [" << supported + << "] index: [" << node->Index() + << "] name: [" << node->Name() + << "] Operator type: [" << node->OpType() + << "] as part of the NodeUnit type: [" << node_unit->OpType() + << "] index: [" << node_unit->Index() + << "] name: [" << node_unit->Name() + << "]"; + } + if (supported) { // If the node_unit is supported, add all of its nodes to the supported list. for (const auto* node_in_group : node_unit->GetAllNodesInGroup()) { supported_nodes.insert(node_in_group); } } + + node_unit_supported_result.insert({node_unit, supported}); } return supported_nodes; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 09bcb24db4dc2..0bcaa39b22f6d 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -42,7 +42,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, - std::unordered_map& node_unit_supported_result, const logging::Logger& logger) const; std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 37ad14ac2e9b1..c07a0929353b1 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -118,84 +118,134 @@ NodeAttrHelper::NodeAttrHelper(const NodeUnit& node_unit) : node_attributes_(node_unit.GetNode().GetAttributes()) {} float NodeAttrHelper::Get(const std::string& key, float def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.f(); + } - return node_attributes_.at(key).f(); + return def_val; } int32_t NodeAttrHelper::Get(const std::string& key, int32_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(entry->second.i()); + } - return SafeInt(node_attributes_.at(key).i()); + return def_val; } uint32_t NodeAttrHelper::Get(const std::string& key, uint32_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return narrow(entry->second.i()); + } - return SafeInt(node_attributes_.at(key).i()); + return def_val; } int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.i(); + } - return node_attributes_.at(key).i(); + return def_val; } const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + return entry->second.s(); + } - return node_attributes_.at(key).s(); + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& attr = entry->second; + std::vector v; + v.reserve(static_cast(attr.ints_size())); + std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), + [](int64_t val) -> int32_t { return narrow(val); }); + return v; + } - const auto& attr(node_attributes_.at(key)); - std::vector v; - v.reserve(static_cast(attr.ints_size())); - std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), - [](int64_t val) -> int32_t { return SafeInt(val); }); - return v; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& attr = entry->second; + std::vector v; + v.reserve(static_cast(attr.ints_size())); + std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), + [](int64_t val) -> uint32_t { return narrow(val); }); + return v; + } - const auto& attr(node_attributes_.at(key)); - std::vector v; - v.reserve(static_cast(attr.ints_size())); - std::transform(attr.ints().cbegin(), attr.ints().cend(), std::back_inserter(v), - [](int64_t val) -> uint32_t { return SafeInt(val); }); - return v; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.ints(); + return std::vector{values.cbegin(), values.cend()}; + } - const auto& source(node_attributes_.at(key).ints()); - return std::vector{source.cbegin(), source.cend()}; + return def_val; } std::vector NodeAttrHelper::Get(const std::string& key, const std::vector& def_val) const { - if (!HasAttr(key)) - return def_val; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.floats(); + return std::vector{values.cbegin(), values.cend()}; + } - const auto& source(node_attributes_.at(key).floats()); - return std::vector{source.cbegin(), source.cend()}; + return def_val; } -std::optional NodeAttrHelper::GetInt(const std::string& key) const { - if (!HasAttr(key)) - return std::nullopt; - return node_attributes_.at(key).i(); +std::optional NodeAttrHelper::GetFloat(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.f(); + } + + return result; +} + +std::optional NodeAttrHelper::GetInt64(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.i(); + } + + return result; +} + +std::optional> NodeAttrHelper::GetFloats(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.floats(); + result = std::vector(values.begin(), values.end()); + } + + return result; +} + +std::optional> NodeAttrHelper::GetInt64s(const std::string& key) const { + std::optional> result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + const auto& values = entry->second.ints(); + result = std::vector(values.begin(), values.end()); + } + + return result; +} + +std::optional NodeAttrHelper::GetString(const std::string& key) const { + std::optional result; + if (auto entry = node_attributes_.find(key); entry != node_attributes_.end()) { + result = entry->second.s(); + } + + return result; } bool NodeAttrHelper::HasAttr(const std::string& key) const { diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 31b1aba2e1a63..5813dcc48d72b 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -47,15 +47,17 @@ class NodeAttrHelper { // Get the attributes from the target node of the node_unit explicit NodeAttrHelper(const NodeUnit& node_unit); + /* + * Get with default + */ float Get(const std::string& key, float def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; int64_t Get(const std::string& key, int64_t def_val) const; + std::vector Get(const std::string& key, const std::vector& def_val) const; const std::string& Get(const std::string& key, const std::string& def_val) const; - std::vector Get(const std::string& key, const std::vector& def_val) const; - std::vector Get(const std::string& key, const std::vector& def_val) const; - // Convert the i() or ints() of the attribute from int64_t to int32_t int32_t Get(const std::string& key, int32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; @@ -64,7 +66,16 @@ class NodeAttrHelper { uint32_t Get(const std::string& key, uint32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; - std::optional GetInt(const std::string& key) const; + /* + * Get without default. + */ + std::optional GetFloat(const std::string& key) const; + std::optional> GetFloats(const std::string& key) const; + + std::optional GetInt64(const std::string& key) const; + std::optional> GetInt64s(const std::string& key) const; + + std::optional GetString(const std::string& key) const; bool HasAttr(const std::string& key) const; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index bb8732784945d..3bec9aa146f76 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1682,7 +1682,11 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O if (legacy_ov_options->device_type != nullptr) ov_options_converted_map["device_type"] = legacy_ov_options->device_type; - ov_options_converted_map["enable_npu_fast_compile"] = legacy_ov_options->enable_npu_fast_compile; + if (legacy_ov_options->enable_npu_fast_compile) { + ov_options_converted_map["enable_npu_fast_compile"] = "false"; + } else { + ov_options_converted_map["enable_npu_fast_compile"] = "true"; + } if (legacy_ov_options->device_id != nullptr) ov_options_converted_map["device_id"] = legacy_ov_options->device_id; @@ -1701,14 +1705,12 @@ ProviderOptions OrtOpenVINOProviderOptionsToOrtOpenVINOProviderOptionsV2(const O ov_options_converted_map["enable_opencl_throttling"] = legacy_ov_options->enable_opencl_throttling; - if (legacy_ov_options->enable_dynamic_shapes != '\0') { - std::string enable_dynamic_shapes = reinterpret_cast(legacy_ov_options->enable_dynamic_shapes); - if (enable_dynamic_shapes == "true" || enable_dynamic_shapes == "True") { - ov_options_converted_map["disable_dynamic_shapes"] = "false"; - } else if (enable_dynamic_shapes == "false" || enable_dynamic_shapes == "False") { - ov_options_converted_map["disable_dynamic_shapes"] = "true"; - } + if (legacy_ov_options->enable_dynamic_shapes) { + ov_options_converted_map["disable_dynamic_shapes"] = "false"; + } else { + ov_options_converted_map["disable_dynamic_shapes"] = "true"; } + // Add new provider option below ov_options_converted_map["num_streams"] = "1"; return ov_options_converted_map; diff --git a/onnxruntime/core/util/thread_utils.cc b/onnxruntime/core/util/thread_utils.cc index a5a165e150cf1..2a6c14ff1b058 100644 --- a/onnxruntime/core/util/thread_utils.cc +++ b/onnxruntime/core/util/thread_utils.cc @@ -93,22 +93,31 @@ static std::unique_ptr CreateThreadPoolHelper(Env* env, OrtThreadPoolParams options) { ThreadOptions to; if (options.thread_pool_size <= 0) { // default - auto default_affinities = Env::Default().GetDefaultThreadAffinities(); - if (default_affinities.size() <= 1) { - return nullptr; - } - options.thread_pool_size = static_cast(default_affinities.size()); if (options.auto_set_affinity) { #ifdef _WIN32 // Only set thread affinity on Server with auto affinity. // On client best to let OS scheduler handle. // On big (P-Core) / little (E-Core) CPU designs affinity overrides QoS and has high power usage if (IsWindowsServer()) { + auto default_affinities = Env::Default().GetDefaultThreadAffinities(); + if (default_affinities.size() <= 1) { + return nullptr; + } + options.thread_pool_size = static_cast(default_affinities.size()); to.affinities = std::move(default_affinities); + } else { + options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); } #else + auto default_affinities = Env::Default().GetDefaultThreadAffinities(); + if (default_affinities.size() <= 1) { + return nullptr; + } + options.thread_pool_size = static_cast(default_affinities.size()); to.affinities = std::move(default_affinities); #endif + } else { + options.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); } } if (options.thread_pool_size <= 1) { diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py index e32cb032798fc..8334d20e47c86 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/groupnorm_test.py @@ -35,7 +35,11 @@ def sigmoid_function(x): return 1.0 / (1.0 + np.exp(-x)) -def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): +def group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, with_silu, has_skip): + add_output = None + if has_skip: + input_x = input_x + skip_x + bias_x + add_output = input_x n, h, w, c = input_x.shape input_x = input_x.transpose([0, 3, 1, 2]) assert c % num_groups == 0 @@ -45,46 +49,70 @@ def group_norm(input_x, gamma, beta, num_groups, epsilon, with_swish): x = x.transpose([0, 2, 3, 1]) x = x * gamma + beta - if with_swish: + if with_silu: x = x * sigmoid_function(x) - return x + return x, add_output -def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func): +def run_group_norm( + batch_size: int, height: int, num_channels: int, num_groups: int, dtype: str, silu: bool, has_skip: bool, func +): np.random.seed(0) width = height input_x = np.random.rand(batch_size, height, width, num_channels).astype(np.float32) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) # the size of workspace is defined in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h L18 - workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * 32 * 32).astype(np.float32) + workspace = np.random.rand((np.dtype(np.float32).itemsize * 2) * batch_size * num_groups).astype(np.float32) epsilon = 1e-05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish - host_x = input_x.astype(dtype) - input_d = ke.DeviceArray(host_x) + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(np.float32) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(np.float32) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization + + input_d = ke.DeviceArray(input_x.astype(dtype)) + skip_d = ke.DeviceArray(skip_x.astype(dtype)) + bias_d = ke.DeviceArray(bias_x.astype(dtype)) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) - y_ref = group_norm(input_x, gamma, beta, num_groups, epsilon, use_swish).astype(dtype) + y_ref, y_add_d_ref = group_norm(input_x, skip_x, bias_x, gamma, beta, num_groups, epsilon, use_silu, has_skip) + y_ref = y_ref.astype(dtype) for impl in my_op.ListOps(): if not my_op.SelectOp(impl): @@ -95,6 +123,10 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: y_d.UpdateHostNumpyArray() np.testing.assert_allclose(y_ref, output_y, atol=1e-02) + if has_skip: + y_add_d_ref = y_add_d_ref.astype(dtype) + y_add_d.UpdateHostNumpyArray() + np.testing.assert_allclose(y_add_d_ref, add_output, atol=1e-02) dtypes = ["float32", "float16"] @@ -102,19 +134,21 @@ def run_group_norm(batch_size: int, height: int, num_channels: int, num_groups: @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm(sd_sizes, dtype, swish): +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [True, False]) +def test_group_norm(sd_sizes, dtype, silu, has_skip): for func in dtype_to_funcs(dtype): - run_group_norm(*sd_sizes, dtype, swish, func) + run_group_norm(*sd_sizes, dtype, silu, has_skip, func) @pytest.mark.parametrize("sd_sizes", get_sd_sizes()) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("swish", [True]) -def test_group_norm_ck(sd_sizes, dtype, swish): - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - run_group_norm(*sd_sizes, dtype, swish, ck_f_name) +@pytest.mark.parametrize("silu", [True]) +@pytest.mark.parametrize("has_skip", [False]) +def test_group_norm_ck(sd_sizes, dtype, silu, has_skip): + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + run_group_norm(*sd_sizes, dtype, silu, has_skip, ck_f_name) @dataclass @@ -136,37 +170,67 @@ def report(self): def profile_group_norm_func( - batch_size: int, height: int, width: int, num_channels: int, num_groups: int, dtype: str, swish: bool, func + batch_size: int, + height: int, + width: int, + num_channels: int, + num_groups: int, + dtype: str, + silu: bool, + has_skip: bool, + func, ): np.random.seed(0) input_x = np.random.rand(batch_size, height, width, num_channels).astype(dtype) gamma = np.random.rand(num_channels).astype(np.float32) beta = np.random.rand(num_channels).astype(np.float32) - workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * 32 * 32).astype(np.float32) + workspace = np.random.rand(np.dtype(np.float32).itemsize * 2 * batch_size * num_groups).astype(np.float32) epsilon = 0.05 output_y = np.random.rand(batch_size, height, width, num_channels).astype(dtype) - use_swish = swish + + skip_x = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + bias_x = np.random.rand(num_channels).astype(dtype) if has_skip else np.empty((0), dtype=dtype) + add_output = ( + np.random.rand(batch_size, height, width, num_channels).astype(dtype) + if has_skip + else np.empty((0), dtype=dtype) + ) + use_silu = silu + broadcast_skip = False + channels_per_block = 0 # Compute in params initialization input_d = ke.DeviceArray(input_x) + skip_d = ke.DeviceArray(skip_x) + bias_d = ke.DeviceArray(bias_x) gamma_d = ke.DeviceArray(gamma) beta_d = ke.DeviceArray(beta) workspace_d = ke.DeviceArray(workspace) y_d = ke.DeviceArray(output_y) + y_add_d = ke.DeviceArray(add_output) f = getattr(ke, func) my_op = f( y_d, - workspace_d, + y_add_d, input_d, + skip_d, + bias_d, gamma_d, beta_d, + workspace_d, + epsilon, batch_size, + num_channels, height, width, - num_channels, num_groups, - epsilon, - use_swish, + use_silu, + broadcast_skip, + channels_per_block, ) for impl in my_op.ListOps(): duration_ms = -1 @@ -181,14 +245,14 @@ def profile_group_norm_func( ) -def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, swish=True, sort=True): +def profile_with_args(batch_size, height, width, num_channels, num_groups, dtype, silu=True, has_skip=True, sort=True): with ke.benchmark(sort): for func in dtype_to_funcs(dtype): - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, func) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, func) # ck function - swish_suffix = "Swish" if swish else "Pass" - ck_f_name = "CKGroupNormNHWC" + swish_suffix + "_" + dtype_to_suffix(dtype) - profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, swish, ck_f_name) + silu_suffix = "Silu" if silu else "Pass" + ck_f_name = "CKGroupNormNHWC" + silu_suffix + "_" + dtype_to_suffix(dtype) + profile_group_norm_func(batch_size, height, width, num_channels, num_groups, dtype, silu, has_skip, ck_f_name) sd_profile_sizes = [ @@ -227,7 +291,8 @@ def profile(): group.add_argument("num_channels", type=int) group.add_argument("num_groups", type=int) group.add_argument("dtype", choices=dtypes) - group.add_argument("--swish", action="store_true") + group.add_argument("--silu", action="store_true") + group.add_argument("--has_skip", action="store_true") group.add_argument("--sort", action="store_true") if len(sys.argv) == 1: @@ -241,6 +306,7 @@ def profile(): args.num_channels, args.num_groups, args.dtype, - args.swish, + args.silu, + args.has_skip, args.sort, ) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu index 0bd47b2c0387e..6af163ab94b10 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu @@ -12,17 +12,21 @@ #include "python/tools/kernel_explorer/kernel_explorer_interface.h" namespace py = pybind11; - +using onnxruntime::contrib::rocm::GetGroupNormWorkspaceSizeInBytes; namespace onnxruntime { template class GroupNormNHWC : public IKernelExplorer { public: - GroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias, + DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon, + int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, + bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize); } @@ -40,7 +44,7 @@ class GroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCOp op_{}; std::string type_string_{}; @@ -49,11 +53,15 @@ class GroupNormNHWC : public IKernelExplorer { template class GroupNormNHWCStaticSelection : public IKernelExplorer { public: - GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { type_string_ = "GroupNormNHWCStaticSelection"; } @@ -71,7 +79,7 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; std::string type_string_{}; }; @@ -79,11 +87,15 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer { template class GroupNormNHWCTunable : public IKernelExplorer { public: - GroupNormNHWCTunable(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { + GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { params_.TuningContext()->EnableTunableOpAndTuning(); } @@ -100,21 +112,25 @@ class GroupNormNHWCTunable : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; ParamsT params_{}; contrib::rocm::GroupNormNHWCTunableOp op_{}; }; #ifdef USE_COMPOSABLE_KERNEL -template +template class CKGroupNormNHWC : public IKernelExplorer { public: - CKGroupNormNHWC(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { + CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps()) { type_strings_.emplace_back(std::move(type_string)); ops_.emplace_back(std::move(op)); } @@ -141,7 +157,7 @@ class CKGroupNormNHWC : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -151,15 +167,19 @@ class CKGroupNormNHWC : public IKernelExplorer { #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL -template +template class GroupNormNHWCTriton : public IKernelExplorer { public: - GroupNormNHWCTriton(DeviceArray& output, DeviceArray& workspace, DeviceArray& input, DeviceArray& gamma, DeviceArray& beta, - int batch_size, int height, int width, int num_channels, int num_groups, float epsilon, bool use_swish) - : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(workspace.ptr()), - static_cast(input.ptr()), static_cast(gamma.ptr()), static_cast(beta.ptr()), - batch_size, height, width, num_channels, num_groups, epsilon, use_swish) { - for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { + GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, + DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, + float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, + bool use_silu, bool broadcast_skip, int channels_per_block) + : params_(TuningContext(), Stream(), static_cast(output.ptr()), static_cast(add_output.ptr()), + static_cast(input.ptr()), static_cast(skip.ptr()), static_cast(bias.ptr()), + static_cast(gamma.ptr()), static_cast(beta.ptr()), static_cast(workspace.ptr()), + epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, + channels_per_block) { + for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps()) { name_strings_.emplace_back(name); ops_.emplace_back(std::move(op)); } @@ -186,7 +206,7 @@ class GroupNormNHWCTriton : public IKernelExplorer { } private: - using ParamsT = contrib::rocm::GroupNormNHWCParams; + using ParamsT = contrib::rocm::GroupNormNHWCTunableParams; using OpT = rocm::tunable::Op; ParamsT params_{}; std::vector ops_; @@ -198,7 +218,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP(name, type, threads_per_block, vec_size) \ py::class_>(m, #name "_" #type "_" #threads_per_block "_" #vec_size) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &name::SetRepeats) \ .def("Profile", &name::Profile) \ .def("Run", &name::Run) \ @@ -220,7 +241,8 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_COMMON(name, type, ...) \ py::class_>(m, name) \ .def(py::init()) \ + DeviceArray&, DeviceArray&, DeviceArray&, float, \ + int, int, int, int, int, bool, bool, int>()) \ .def("SetRepeats", &type<__VA_ARGS__>::SetRepeats) \ .def("Profile", &type<__VA_ARGS__>::Profile) \ .def("Run", &type<__VA_ARGS__>::Run) \ @@ -230,11 +252,11 @@ class GroupNormNHWCTriton : public IKernelExplorer { #define REGISTER_OP_TYPED(name, type) \ REGISTER_COMMON(#name "_" #type, name, type) -#define REGISTER_CK(type, with_swish, swish_suffix) \ - REGISTER_COMMON("CKGroupNormNHWC" swish_suffix "_" #type, CKGroupNormNHWC, type, with_swish) +#define REGISTER_CK(type, with_silu, silu_suffix) \ + REGISTER_COMMON("CKGroupNormNHWC" silu_suffix "_" #type, CKGroupNormNHWC, type, with_silu) -#define REGISTER_TRITON(type, with_swish, swish_suffix) \ - REGISTER_COMMON("GroupNormNHWCTriton" swish_suffix "_" #type, GroupNormNHWCTriton, type, with_swish) +#define REGISTER_TRITON(type, with_silu, silu_suffix) \ + REGISTER_COMMON("GroupNormNHWCTriton" silu_suffix "_" #type, GroupNormNHWCTriton, type, with_silu) KE_REGISTER(m) { REGISTER_OP_FOR_ALL_THREADS_PER_BLOCK_ALL_VEC_SIZE(GroupNormNHWC, half); @@ -248,16 +270,16 @@ KE_REGISTER(m) { #ifdef USE_COMPOSABLE_KERNEL REGISTER_CK(half, false, "Pass"); - REGISTER_CK(half, true, "Swish"); + REGISTER_CK(half, true, "Silu"); REGISTER_CK(float, false, "Pass"); - REGISTER_CK(float, true, "Swish"); + REGISTER_CK(float, true, "Silu"); #endif // USE_COMPOSABLE_KERNEL #ifdef USE_TRITON_KERNEL REGISTER_TRITON(half, false, "Pass"); - REGISTER_TRITON(half, true, "Swish"); + REGISTER_TRITON(half, true, "Silu"); REGISTER_TRITON(float, false, "Pass"); - REGISTER_TRITON(float, true, "Swish"); + REGISTER_TRITON(float, true, "Silu"); #endif } diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py index 3e9f9a6544a71..eb7bbec997d59 100644 --- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py +++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py @@ -349,6 +349,10 @@ def process(self): self.int4_quant_algo() +def ort_convert_str_to_bool(value): + return value.lower() in ("true", "1") + + def parse_args(): parser = argparse.ArgumentParser( description="""Blockwise int4 quantization for MatMul 2D weight matrices. @@ -366,7 +370,10 @@ def parse_args(): "--symmetric", required=False, default=True, - type=bool, + const=True, + nargs="?", + type=ort_convert_str_to_bool, + choices=[True, False], help="Indicate whether to quantize the model symmetrically", ) parser.add_argument( diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index a72d21c03a8a6..9450426f12444 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -385,7 +385,7 @@ def add_new_nodes(self, nodes): def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( - "Please check if the model is already quantized." + "Please check if the model is already quantized. " "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." ) @@ -442,6 +442,23 @@ def is_valid_quantize_weight(self, weight_name): return False return self.parent.is_valid_quantize_weight(weight_name) + def _get_default_tensor_type(self, tensor_name): + if "DefaultTensorType" in self.extra_options: + logging.info( + "get_tensor_type returns DefaultTensorType for tensor name %r, use %d", + tensor_name, + self.extra_options["DefaultTensorType"], + ) + return self.extra_options["DefaultTensorType"] + raise RuntimeError( + f"Unable to find data type for weight_name={tensor_name!r}. " + f"shape_inference failed to return a type probably this node is " + f"from a different domain or using an input produced by such an operator. " + f"This may happen if you quantize a model already quantized. " + f"You may use extra_options `DefaultTensorType` to indicate " + f"the default weight type, usually `onnx.TensorProto.FLOAT`." + ) + def get_tensor_type(self, tensor_name, mandatory=False): weight = find_by_name(tensor_name, self.model.initializer()) if weight is not None: @@ -450,11 +467,11 @@ def get_tensor_type(self, tensor_name, mandatory=False): vi = self.value_infos[tensor_name] if vi.type.HasField("tensor_type"): if mandatory and vi.type.tensor_type.elem_type == 0: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return vi.type.tensor_type.elem_type if (not self.enable_subgraph_quantization) or (self.parent is None): if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None otype = self.parent.is_valid_quantize_weight(tensor_name) if otype is not None: @@ -464,7 +481,7 @@ def get_tensor_type(self, tensor_name, mandatory=False): if res is not None: return res if mandatory: - raise RuntimeError(f"Unable to find data type for weight_name={tensor_name!r}") + return self._get_default_tensor_type(tensor_name) return None def is_float_tensor(self, tensor_name): @@ -1332,9 +1349,15 @@ def _dequantize_value(self, value_name): if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names): quantized_value = self.quantized_value_map[value_name] # Add DequantizeLinear Node for this input + scale_init = find_by_name(quantized_value.scale_name, self.model.initializer()) - # axis is not specified so scale_init must be a scalar. - assert onnx.numpy_helper.to_array(scale_init).size == 1 + + # In case we are working with subgraphs, the graph `producer_name` is set to `"onnx-quantizer"` in the `quantize_subgraph` method. In this case, the scale initializer may be on the top level graph, so the check below can not be done. + if self.model.model.producer_name != "onnx-quantizer" or ( + self.model.model.producer_name == "onnx-quantizer" and scale_init is not None + ): + # axis is not specified so scale_init must be a scalar. + assert onnx.numpy_helper.to_array(scale_init).size == 1 dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) diff --git a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py index b98aafc27579a..2ae64a72d08fe 100644 --- a/onnxruntime/python/tools/tensorrt/perf/build/build_image.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/build_image.py @@ -45,7 +45,7 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: :return: A list of common 'docker build' arguments. """ - return [ + command = [ "--no-cache", "-t", f"{args.image_name}", @@ -54,6 +54,14 @@ def get_common_docker_build_args(args: argparse.Namespace) -> List[str]: "--build-arg", f"ONNXRUNTIME_BRANCH={args.branch}", ] + if args.use_tensorrt_oss_parser: + command.extend( + [ + "--build-arg", + "PARSER_CONFIG=--use_tensorrt_oss_parser", + ] + ) + return command def is_valid_ver_str(version: str, min_comps: int = 0, max_comps: int = 0) -> bool: @@ -187,7 +195,7 @@ def parse_arguments() -> argparse.Namespace: parser.add_argument("-r", "--repo_path", required=True, help="Path to the onnxruntime repository") parser.add_argument("-i", "--image_name", required=True, help="The resulting Docker image name") parser.add_argument("-b", "--branch", default="main", help="Name of the onnxruntime git branch to checkout") - parser.add_argument("-t", "--trt_version", default="8.4.1.5", help="TensorRT version (e.g., 8.4.1.5)") + parser.add_argument("-t", "--trt_version", default="8.6.1.6", help="TensorRT version (e.g., 8.6.1.6)") parser.add_argument("-a", "--cuda_arch", default="75", help="CUDA architecture (e.g., 75)") # Command-line options for installing TensorRT from binaries. @@ -208,6 +216,12 @@ def parse_arguments() -> argparse.Namespace: help="CUDA version (e.g., 8.6) used to find TensorRT EA binary tar.gz package", ) parser.add_argument("--trt_bins_dir", default="", help="Directory containing TensorRT tar.gz package") + parser.add_argument( + "--use_tensorrt_oss_parser", + action="store_true", + default=False, + help="Use TensorRT OSS Parser", + ) return parser.parse_args() diff --git a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py index 6e20071683d90..c7d4a7836132a 100755 --- a/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py +++ b/onnxruntime/python/tools/tensorrt/perf/build/ort_build_latest.py @@ -13,6 +13,12 @@ def parse_arguments(): parser.add_argument("-b", "--branch", required=False, default="master", help="Github branch to test perf off of") parser.add_argument("-s", "--save", required=False, help="Directory to archive wheel file") parser.add_argument("-a", "--use_archived", required=False, help="Archived wheel file") + parser.add_argument( + "--use_tensorrt_oss_parser", + action="store_true", + default=False, + help="Use TensorRT OSS Parser", + ) args = parser.parse_args() return args @@ -35,14 +41,14 @@ def install_new_ort_wheel(ort_master_path): def main(): args = parse_arguments() - cmake_tar = "cmake-3.18.4-Linux-x86_64.tar.gz" + cmake_tar = "cmake-3.28.3-linux-x86_64.tar.gz" if not os.path.exists(cmake_tar): - subprocess.run(["wget", "-c", "https://cmake.org/files/v3.18/" + cmake_tar], check=True) + subprocess.run(["wget", "-c", "https://cmake.org/files/v3.28/" + cmake_tar], check=True) tar = tarfile.open(cmake_tar) tar.extractall() tar.close() - os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.18.4-Linux-x86_64"), "bin") + ":" + os.environ["PATH"] + os.environ["PATH"] = os.path.join(os.path.abspath("cmake-3.28.3-linux-x86_64"), "bin") + ":" + os.environ["PATH"] os.environ["CUDACXX"] = os.path.join(args.cuda_home, "bin", "nvcc") ort_master_path = args.ort_master_path @@ -57,24 +63,24 @@ def main(): subprocess.run(["git", "fetch"], check=True) subprocess.run(["git", "checkout", args.branch], check=True) subprocess.run(["git", "pull", "origin", args.branch], check=True) - subprocess.run( - [ - "./build.sh", - "--config", - "Release", - "--use_tensorrt", - "--tensorrt_home", - args.tensorrt_home, - "--cuda_home", - args.cuda_home, - "--cudnn", - "/usr/lib/x86_64-linux-gnu", - "--build_wheel", - "--skip_tests", - "--parallel", - ], - check=True, - ) + command = [ + "./build.sh", + "--config", + "Release", + "--use_tensorrt", + "--tensorrt_home", + args.tensorrt_home, + "--cuda_home", + args.cuda_home, + "--cudnn", + "/usr/lib/x86_64-linux-gnu", + "--build_wheel", + "--skip_tests", + "--parallel", + ] + if args.use_tensorrt_oss_parser: + command.append("--use_tensorrt_oss_parser") + subprocess.run(command, check=True) ort_wheel_file = install_new_ort_wheel(ort_master_path) diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index 22f2cfb8a01ca..89f9947688583 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -36,6 +36,8 @@ python benchmark.py -e torchscript onnxruntime -p "int8" -o Run OnnxRuntime with the ROCM provider and graph optimization script: python benchmark.py -g -m bert-base-cased --provider rocm --optimizer_info by_script --disable_embed_layer_norm + Run OnnxRuntime with bfloat16 fastmath mode kernels on aarch64 platforms with bfloat16 support: + python benchmark.py --enable_arm64_bfloat16_fastmath_mlas_gemm It is recommended to use run_benchmark.sh to launch benchmark. """ @@ -106,6 +108,7 @@ def run_onnxruntime( use_raw_attention_mask, model_fusion_statistics, model_source, + enable_arm64_bfloat16_fastmath_mlas_gemm, args, ): import onnxruntime @@ -209,6 +212,7 @@ def run_onnxruntime( enable_all_optimization=True, num_threads=num_threads, verbose=verbose, + enable_mlas_gemm_fastmath_arm64_bfloat16=enable_arm64_bfloat16_fastmath_mlas_gemm, ) if ort_session is None: continue @@ -760,6 +764,14 @@ def parse_arguments(): help="Manually set the model's layer number", ) + parser.add_argument( + "--enable_arm64_bfloat16_fastmath_mlas_gemm", + required=False, + action="store_true", + help="Enable bfloat16 mlas gemm kernels on aarch64. Supported only for CPU EP ", + ) + parser.set_defaults(enable_arm64_bfloat16_fastmath_mlas_gemm=False) + FusionOptions.add_arguments(parser) args = parser.parse_args() @@ -905,6 +917,7 @@ def main(): use_raw_attention_mask, model_fusion_statistics, args.model_source, + args.enable_arm64_bfloat16_fastmath_mlas_gemm, args, ) except Exception: diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index c5edae42590fd..c7d93470a729e 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -85,6 +85,7 @@ def create_onnxruntime_session( num_threads=-1, enable_profiling=False, verbose=False, + enable_mlas_gemm_fastmath_arm64_bfloat16=False, provider_options={}, # map execution provider name to its option # noqa: B006 ): session = None @@ -136,6 +137,9 @@ def create_onnxruntime_session( if provider_options: providers = [(name, provider_options[name]) if name in provider_options else name for name in providers] + if enable_mlas_gemm_fastmath_arm64_bfloat16: + sess_options.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1") + session = onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers) except Exception: logger.error("Exception", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 4c43e4487bfb1..edac1989e4e9e 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -29,6 +29,13 @@ class AttentionOpType(Enum): def __str__(self): return self.value + # Override __eq__ to return string comparison + def __hash__(self): + return hash(self.value) + + def __eq__(self, other): + return other.value == self.value + class FusionOptions: """Options of fusion in graph optimization""" diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index b7881d064067d..796d6ec55ef80 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -138,9 +138,6 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow. node_block_list = ( [ - "GroupQueryAttention_29", - "GroupQueryAttention_30", - "GroupQueryAttention_31", "Attention_29", "Attention_30", "Attention_31", diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 02100266200f8..7a678f2734ade 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,5 +1,22 @@ # Whisper +## Prerequisites + +Please note the package versions needed for using Whisper in the `requirements.txt` file that fits your scenario. +- `requirements-cpu.txt` + - For running Whisper on CPU +- `requirements-cuda.txt` + - For running Whisper on CUDA + - Note that `torch` with CUDA enabled is not installed automatically. This is because `torch` should be installed with the CUDA version used on your machine. Please visit [the PyTorch website](https://pytorch.org/get-started/locally/) to download the `torch` version that is used with the CUDA version installed on your machine and satisfies the requirement listed in the file. +- `requirements.txt` + - Package versions needed in each of the above files + +In addition to the above packages, you will need to install `ffmpeg` on your machine. Visit the [FFmpeg website](https://ffmpeg.org/) for details. You can also install it natively using package managers. + +- Linux: `sudo apt-get install ffmpeg` +- MacOS: `sudo brew install ffmpeg` +- Windows: Download from website + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search (using Whisper tiny as an example). @@ -10,10 +27,10 @@ There are several ways to export Whisper with beam search (using Whisper tiny as # From source $ git clone https://github.com/microsoft/onnxruntime $ cd onnxruntime/onnxruntime/python/tools/transformers/ -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format # From wheel -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format ``` ### Option 2: end-to-end model from [Olive](https://github.com/microsoft/Olive/tree/main/examples/whisper) @@ -39,40 +56,49 @@ model.save_pretrained(model_name.split("/")[-1] + "-onnx") Here are some additional examples for exporting Whisper with beam search. +To see all available options +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx --help + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx --help +``` + Export with Forced Decoder Input Ids ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --use_forced_decoder_ids +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --use_forced_decoder_ids ``` Export + Optimize for FP32 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp32 +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp32 ``` Export + Optimize for FP16 and GPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda --disable_auto_mixed_precision ``` Export + Quantize for INT8 ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-tiny --output whispertiny --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3 --output whisperlargev3 --use_external_data_format --precision int8 --quantize_embedding_layer ``` ## Benchmark Whisper diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 759ae6d14f184..e57385aa6db8f 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import ast import datetime @@ -54,6 +60,8 @@ def load_via_numpy(): inputs["decoder_input_ids"] = np.array([args.decoder_input_ids], dtype=np.int32) if args.has_logits_processor: inputs["logits_processor"] = np.array([args.logits_processor], dtype=np.int32) + if args.has_temperature: + inputs["temperature"] = np.array([args.temperature], dtype=np.float32) # Measure time taken to load audio file logger.info(f"Load audio: {args.audio_path}") @@ -163,6 +171,7 @@ def get_model(args: argparse.Namespace): def time_fn(args, fn, inputs): warmup_inputs = inputs[0] if type(inputs) is tuple else inputs benchmark_inputs = inputs[1] if type(inputs) is tuple else inputs + torch_device = torch.device(args.target_device) # Warm up warmup_range = ( @@ -180,7 +189,7 @@ def time_fn(args, fn, inputs): # Benchmark if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) start_time = time.time() bench_range = ( @@ -192,7 +201,7 @@ def time_fn(args, fn, inputs): fn(benchmark_inputs) if args.device != "cpu": - torch.cuda.synchronize() + torch.cuda.synchronize(torch_device) end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line @@ -500,7 +509,13 @@ def parse_args(): "--logits-processor", type=int, default=1, - help="Type of logits processor to use. See `BeamSearch` in https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/graph/contrib_ops/contrib_defs.cc for details.", + help="Whether to use timestamps logits processor or not (0 for false, 1 for true).", + ) + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="Temperature value for generation.", ) # Args for accessing detailed info @@ -581,6 +596,7 @@ def main(): args.has_audio_stream = "audio_stream" in ort_model_inputs setattr(args, "has_decoder_input_ids", "decoder_input_ids" in ort_model_inputs) # noqa: B010 setattr(args, "has_logits_processor", "logits_processor" in ort_model_inputs) # noqa: B010 + setattr(args, "has_temperature", "temperature" in ort_model_inputs) # noqa: B010 if args.decoder_input_ids == []: args.decoder_input_ids = [config.decoder_start_token_id] diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d205a2d340721..814b0dd1ef6ac 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import argparse import datetime import json diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index bb697fe1e1506..35211aab272e4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -28,17 +28,25 @@ def parse_arguments(argv=None): parser = argparse.ArgumentParser() - pretrained_models = PRETRAINED_WHISPER_MODELS - parser.add_argument( + conversion_args = parser.add_argument_group("Conversion Process Args") + optional_inputs = parser.add_argument_group("Optional Inputs (for WhisperBeamSearch op)") + optional_outputs = parser.add_argument_group("Optional Outputs (for WhisperBeamSearch op)") + quant_args = parser.add_argument_group("INT8 Quantization Args") + + ################################# + # Conversion options for Whisper + ################################# + + conversion_args.add_argument( "-m", "--model_name_or_path", required=False, default=PRETRAINED_WHISPER_MODELS[0], type=str, - help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models), + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_WHISPER_MODELS), ) - parser.add_argument( + conversion_args.add_argument( "--model_impl", required=False, default="hf", @@ -47,7 +55,7 @@ def parse_arguments(argv=None): help="Select implementation for export of encoder and decoder subgraphs", ) - parser.add_argument( + conversion_args.add_argument( "--cache_dir", required=False, type=str, @@ -55,7 +63,7 @@ def parse_arguments(argv=None): help="Directory to cache pre-trained models", ) - parser.add_argument( + conversion_args.add_argument( "--output", required=False, type=str, @@ -63,19 +71,24 @@ def parse_arguments(argv=None): help="Output directory", ) - parser.add_argument( + conversion_args.add_argument( "-o", "--optimize_onnx", required=False, action="store_true", help="Use optimizer.py to optimize onnx model", ) - parser.set_defaults(optimize_onnx=False) + conversion_args.set_defaults(optimize_onnx=False) - parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") - parser.set_defaults(use_gpu=False) + conversion_args.add_argument( + "--use_gpu", + required=False, + action="store_true", + help="Use GPU for model inference", + ) + conversion_args.set_defaults(use_gpu=False) - parser.add_argument( + conversion_args.add_argument( "-p", "--precision", required=False, @@ -85,221 +98,226 @@ def parse_arguments(argv=None): help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", ) - parser.add_argument("--verbose", required=False, action="store_true") - parser.set_defaults(verbose=False) - - parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") - parser.set_defaults(use_external_data_format=False) - - parser.add_argument( - "-s", - "--use_decoder_start_token", + conversion_args.add_argument( + "--use_int64_inputs", required=False, action="store_true", - help="Use config.decoder_start_token_id. Otherwise, add an extra graph input to \ - the encoder-decoder-init subgraph for decoder_input_ids.", + help="Use int64 instead of int32 for input_ids and attention_mask.", ) - parser.set_defaults(use_decoder_start_token=False) + conversion_args.set_defaults(use_int64_inputs=False) - parser.add_argument( - "-f", - "--use_forced_decoder_ids", + conversion_args.add_argument( + "--disable_auto_mixed_precision", required=False, action="store_true", - help="Use decoder_input_ids as an extra graph input to the beam search op", + help="Use pure fp16 instead of mixed precision", ) - parser.set_defaults(use_forced_decoder_ids=False) + conversion_args.set_defaults(disable_auto_mixed_precision=False) - parser.add_argument( - "-l", - "--use_logits_processor", + conversion_args.add_argument( + "-r", + "--provider", required=False, - action="store_true", - help="Use logits_processor as an extra graph input to enable specific logits processing", + type=str, + default="cpu", + choices=list(PROVIDERS.keys()), + help="Provider to benchmark. Default is CPUExecutionProvider.", ) - parser.set_defaults(use_specific_logits_processor=False) - parser.add_argument( - "-v", - "--use_vocab_mask", + conversion_args.add_argument( + "--verbose", required=False, action="store_true", - help="Use vocab_mask as an extra graph input to enable specific logits processing", + help="Enable verbose logging", ) - parser.set_defaults(use_vocab_mask=False) + conversion_args.set_defaults(verbose=False) - parser.add_argument( - "-u", - "--use_prefix_vocab_mask", + conversion_args.add_argument( + "-e", + "--use_external_data_format", required=False, action="store_true", - help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", + help="Save weights in external file. Necessary for 'small', 'medium', and 'large' models. Optional for 'tiny' and 'base' models.", ) - parser.set_defaults(use_prefix_vocab_mask=False) + conversion_args.set_defaults(use_external_data_format=False) - parser.add_argument( + conversion_args.add_argument( "-w", "--overwrite", required=False, action="store_true", - help="overwrite existing ONNX model", + help="Overwrite existing ONNX model", ) - parser.set_defaults(overwrite=False) + conversion_args.set_defaults(overwrite=False) - parser.add_argument( - "--disable_auto_mixed_precision", + conversion_args.add_argument( + "--separate_encoder_and_decoder_init", required=False, action="store_true", - help="use pure fp16 instead of mixed precision", + help="Do not merge encoder and decoder init to initialize past KV caches. Output 3 instead of 2 ONNX models.", ) - parser.set_defaults(disable_auto_mixed_precision=False) + conversion_args.set_defaults(separate_encoder_and_decoder_init=False) - parser.add_argument( - "--separate_encoder_and_decoder_init", + conversion_args.add_argument( + "--no_beam_search_op", required=False, action="store_true", - help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + help="Do not produce model with WhisperBeamSearch op, which chains encdecinit and decoder models into one op.", ) - parser.set_defaults(separate_encoder_and_decoder_init=False) + conversion_args.set_defaults(no_beam_search_op=False) - parser.add_argument( - "--use_int64_inputs", + conversion_args.add_argument( + "--state_dict_path", + type=str, + default="", + help="Filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", + ) + + ############################################################# + # Optional inputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_inputs.add_argument( + "-v", + "--use_vocab_mask", required=False, action="store_true", - help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + help="Use vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(use_int64_inputs=False) + optional_inputs.set_defaults(use_vocab_mask=False) - parser.add_argument( - "--chain_model", + optional_inputs.add_argument( + "-u", + "--use_prefix_vocab_mask", required=False, action="store_true", - help="Produce beam search model with chained encdecinit and decoder.", + help="Use prefix_vocab_mask as an extra graph input to enable specific logits processing", ) - parser.set_defaults(chain_model=True) + optional_inputs.set_defaults(use_prefix_vocab_mask=False) - parser.add_argument( - "--use_whisper_beamsearch", + optional_inputs.add_argument( + "-f", + "--use_forced_decoder_ids", required=False, action="store_true", - help="When chain_model, using WhisperBeamSearch operator rather than BeamSearch operator. \ - It will be set to true when collect_cross_qk, extra_decoding_ids or output_no_speech_probs is set.", + help="Use decoder_input_ids as an extra graph input to the beam search op", ) - parser.set_defaults(use_whisper_beamsearch=False) + optional_inputs.set_defaults(use_forced_decoder_ids=False) - parser.add_argument( - "--extra_decoding_ids", + optional_inputs.add_argument( + "-l", + "--use_logits_processor", required=False, action="store_true", - help="Need extra starting decoding ids for some feature like cross qk. Default if false.", + help="Use logits_processor as an extra graph input to enable specific logits processing", ) - parser.set_defaults(extra_decoding_ids=False) + optional_inputs.set_defaults(use_specific_logits_processor=False) - parser.add_argument( + optional_inputs.add_argument( "--collect_cross_qk", required=False, action="store_true", help="Beam search model collect stacked cross QK.", ) - parser.set_defaults(collect_cross_qk=False) + optional_inputs.set_defaults(collect_cross_qk=False) - parser.add_argument( - "--output_cross_qk", + optional_inputs.add_argument( + "--extra_decoding_ids", required=False, action="store_true", - help="Beam search model output collected qk as output. Also hint collect_cross_qk", + help="Need extra starting decoding ids for some feature like cross qk. Default if false.", ) - parser.set_defaults(output_cross_qk=False) + optional_inputs.set_defaults(extra_decoding_ids=False) - parser.add_argument( - "--no_speech_token_id", - default=50362, + optional_inputs.add_argument( + "-t", + "--use_temperature", + required=False, + action="store_true", + help="Use temperature as an extra graph input for the WhisperBeamSearch op", + ) + optional_inputs.set_defaults(use_temperature=False) + + optional_inputs.add_argument( + "--no_repeat_ngram_size", type=int, - help="specify no_speech_token_id. Default is 50362. if >= 0, will be add into beam search attr. \ - Note that default value maybe different between the multilingual and English-only models.", + default=0, + help="default to 0", ) - parser.add_argument( - "--output_no_speech_probs", + ############################################################# + # Optional outputs for Whisper + # (listed below in the order that WhisperBeamSearch expects) + ############################################################# + + optional_outputs.add_argument( + "--output_sequence_scores", required=False, action="store_true", - help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", + help="Beam search model output scores for each generated sequence.", ) - parser.set_defaults(output_no_speech_probs=False) + optional_outputs.set_defaults(output_sequence_scores=False) - parser.add_argument( + optional_outputs.add_argument( "--output_scores", required=False, action="store_true", help="Beam search model output scores over vocab per generated token.", ) - parser.set_defaults(output_scores=False) + optional_outputs.set_defaults(output_scores=False) - parser.add_argument( - "--output_sequence_scores", + optional_outputs.add_argument( + "--output_cross_qk", required=False, action="store_true", - help="Beam search model output scores for each generated sequence.", + help="Beam search model output collected qk as output. Also hint collect_cross_qk", ) - parser.set_defaults(output_sequence_scores=False) + optional_outputs.set_defaults(output_cross_qk=False) - parser.add_argument( + optional_outputs.add_argument( "--cross_qk_onnx_model", required=False, type=str, default=None, - help="the model which consume cross_qk.", + help="The model which consumes cross_qk outputs.", ) - parser.add_argument( - "--beam_output_model", - type=str, - default="whisper_beamsearch.onnx", - help="default name is whisper_beamsearch.onnx.", + optional_outputs.add_argument( + "--output_no_speech_probs", + required=False, + action="store_true", + help="Beam search model output no speech probs which is computed from the encoder/context-decoder graph.", ) + optional_outputs.set_defaults(output_no_speech_probs=False) - parser.add_argument( + ################################### + # Quantization options for Whisper + ################################### + + quant_args.add_argument( "--quantize_embedding_layer", required=False, action="store_true", help="Quantize MatMul, GEMM, and Gather.", ) - parser.set_defaults(quantize_embedding_layer=False) + quant_args.set_defaults(quantize_embedding_layer=False) - parser.add_argument( + quant_args.add_argument( "--quantize_per_channel", required=False, action="store_true", help="Quantize weights per each channel.", ) - parser.set_defaults(quantize_per_channel=False) + quant_args.set_defaults(quantize_per_channel=False) - parser.add_argument( + quant_args.add_argument( "--quantize_reduce_range", required=False, action="store_true", help="Quantize weights with 7 bits.", ) - parser.set_defaults(quantize_reduce_range=False) - - parser.add_argument("--no_repeat_ngram_size", type=int, default=0, help="default to 0") - - parser.add_argument( - "--state_dict_path", - type=str, - default="", - help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)", - ) - - parser.add_argument( - "-r", - "--provider", - required=False, - type=str, - default="cpu", - choices=list(PROVIDERS.keys()), - help="Provider to benchmark. Default is CPUExecutionProvider.", - ) + quant_args.set_defaults(quantize_reduce_range=False) args = parser.parse_args(argv) args.collect_cross_qk = args.collect_cross_qk or args.output_cross_qk @@ -317,7 +335,7 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = False, + use_forced_decoder_ids: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, @@ -362,7 +380,6 @@ def export_onnx_models( onnx_path, verbose, use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token, use_int32_inputs=use_int32_inputs, ) else: @@ -406,7 +423,7 @@ def export_onnx_models( extra_options={"MatMulConstBOnly": True}, ) else: - logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") + logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") else: output_path = onnx_path @@ -449,7 +466,7 @@ def main(argv=None): args.optimize_onnx, args.precision, args.verbose, - args.use_decoder_start_token, + args.use_forced_decoder_ids, not args.separate_encoder_and_decoder_init, args.overwrite, args.disable_auto_mixed_precision, @@ -462,7 +479,7 @@ def main(argv=None): ) max_diff = 0 - if args.chain_model: + if not args.no_beam_search_op: logger.info("Chaining model ... :") args.beam_model_output_dir = WhisperHelper.get_onnx_path( output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt new file mode 100644 index 0000000000000..db2cd95324328 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cpu.txt @@ -0,0 +1,2 @@ +-r requirements.txt +onnxruntime>=1.17.1 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt new file mode 100644 index 0000000000000..9bd215de9bc09 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements-cuda.txt @@ -0,0 +1,4 @@ +-r requirements.txt +# Please manually install torch>=1.13.0 with CUDA enabled for the CUDA version installed in your system. +# Instructions can be found here: https://pytorch.org/get-started/locally/ +onnxruntime-gpu>=1.17.1 diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt new file mode 100644 index 0000000000000..c307a3665f8a0 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -0,0 +1,11 @@ +torch>=1.13.0 +transformers>=4.24.0 +openai-whisper +ffmpeg-python +datasets +soundfile +librosa +optimum +onnxruntime-extensions>=0.9.0 +protobuf==3.20.2 +numpy==1.23.3 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index a74666b7af297..14691da4ad643 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + import logging import os @@ -9,7 +15,7 @@ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha, ) from onnx import TensorProto, helper -from transformers import WhisperConfig +from transformers import WhisperConfig, WhisperTokenizer logger = logging.getLogger(__name__) @@ -23,11 +29,22 @@ def verify_inputs(beam_inputs, graph_inputs): assert graph_input.name in beam_input +def clean_list(arr, remove_all_strings=True): + if remove_all_strings: + # Remove all empty strings in list + return list(filter(lambda elm: elm != "", arr)) + + # Remove empty strings at end of list + while len(arr) > 0: + if arr[-1] == "": + arr.pop() + else: + break + return arr + + def chain_model(args): - # Load encoder/decoder and insert necessary (but unused) graph inputs expected by BeamSearch op or WhisperBeamSearch op - args.use_whisper_beamsearch = ( - args.use_whisper_beamsearch or args.collect_cross_qk or args.output_no_speech_probs or args.extra_decoding_ids - ) + # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op encoder_model = onnx.load_model(args.encoder_path, load_external_data=True) encoder_model.graph.name = "encoderdecoderinit subgraph" @@ -35,7 +52,10 @@ def chain_model(args): decoder_model.graph.name = "decoder subgraph" config = WhisperConfig.from_pretrained(args.model_name_or_path) + tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path) + # Create inputs/outputs for WhisperBeamSearch op + temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" beam_inputs = [ "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", "max_length", @@ -44,38 +64,27 @@ def chain_model(args): "num_return_sequences", "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", - "vocab_mask" if args.use_prefix_vocab_mask else "", + "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask "decoder_input_ids" if args.use_forced_decoder_ids else "", "logits_processor" if args.use_logits_processor else "", + "cross_qk_layer_head" if args.collect_cross_qk else "", + "extra_decoding_ids" if args.extra_decoding_ids else "", + temperature_name if args.use_temperature else "", ] - beam_outputs = ["sequences"] - if args.output_sequence_scores: - beam_outputs.append("sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores") - if args.output_scores: - beam_outputs.append("scores_fp16" if args.precision == Precision.FLOAT16 else "scores") - - if args.use_whisper_beamsearch: - assert len(beam_inputs) == 12 - beam_inputs.extend( - [ - "cross_qk_layer_head" if args.collect_cross_qk else "", - "extra_decoding_ids" if args.extra_decoding_ids else "", - ] - ) - if args.collect_cross_qk: - while len(beam_outputs) < 3: - beam_outputs.extend([""]) - beam_outputs.extend(["cross_qk"]) - if args.output_no_speech_probs: - while len(beam_outputs) < 4: - beam_outputs.extend([""]) - beam_outputs.extend(["no_speech_probs_beam"]) - - input_features_cast_node, len_pen_cast_node, rep_pen_cast_node = None, None, None - output_scores_cast_node = output_sequence_scores_cast_node = None + sequence_scores_name = "sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" + scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores" + beam_outputs = [ + "sequences", + sequence_scores_name if args.output_sequence_scores else "", + scores_name if args.output_scores else "", + "cross_qk" if args.collect_cross_qk else "", + "no_speech_probs_beam" if args.output_no_speech_probs else "", + ] + + graph_nodes = [] if args.precision == Precision.FLOAT16: input_features_cast_node = helper.make_node( "Cast", @@ -98,6 +107,18 @@ def chain_model(args): name="CastRepetitionPenaltyToFp16", to=TensorProto.FLOAT16, ) + graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node]) + + if args.use_temperature: + temp_cast_node = helper.make_node( + "Cast", + inputs=["temperature"], + outputs=["temperature_fp16"], + name="temperature_to_fp16", + to=TensorProto.FLOAT16, + ) + graph_nodes.append(temp_cast_node) + if args.output_sequence_scores: output_sequence_scores_cast_node = helper.make_node( "Cast", @@ -106,6 +127,8 @@ def chain_model(args): name="CastOutputSequenceScoresToFp32", to=TensorProto.FLOAT, ) + graph_nodes.append(output_sequence_scores_cast_node) + if args.output_scores: output_scores_cast_node = helper.make_node( "Cast", @@ -114,26 +137,38 @@ def chain_model(args): name="CastScoresToFp32", to=TensorProto.FLOAT, ) - - operator_type = "WhisperBeamSearch" if args.use_whisper_beamsearch else "BeamSearch" - node = helper.make_node(operator_type, inputs=beam_inputs, outputs=beam_outputs, name="BeamSearch_zcode") - node.domain = "com.microsoft" - node.attribute.extend( - [ - helper.make_attribute("eos_token_id", config.eos_token_id), - helper.make_attribute("pad_token_id", config.pad_token_id), - helper.make_attribute("decoder_start_token_id", config.decoder_start_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", True), - helper.make_attribute("model_type", 2), - ] + graph_nodes.append(output_scores_cast_node) + + # Create WhisperBeamSearch op + beam_search_attrs = [ + helper.make_attribute("eos_token_id", config.eos_token_id), + helper.make_attribute("pad_token_id", config.pad_token_id), + helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id + ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0] + helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]), + helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]), + helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]), + helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0]) + if args.output_no_speech_probs + else "", + helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]), + helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", True), + helper.make_attribute("model_type", 2), + helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "", + ] + node = helper.make_node( + "WhisperBeamSearch", + inputs=clean_list(beam_inputs, remove_all_strings=False), + outputs=clean_list(beam_outputs, remove_all_strings=False), + name="BeamSearch", + domain="com.microsoft", ) - if args.use_whisper_beamsearch: - if args.collect_cross_qk: - node.attribute.extend([helper.make_attribute("decoder_output_cross_qk", 1)]) - if args.no_speech_token_id >= 0: - node.attribute.extend([helper.make_attribute("no_speech_token", args.no_speech_token_id)]) + node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True)) + # Graph inputs input_features = helper.make_tensor_value_info( "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"] ) @@ -143,73 +178,63 @@ def chain_model(args): num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] + ) + decoder_input_ids = helper.make_tensor_value_info( + "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] + ) + logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) + cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2]) + extra_decoding_ids = helper.make_tensor_value_info( + "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] + ) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - graph_inputs = [ - input_features, - max_length, - min_length, - num_beams, - num_return_sequences, - length_penalty, - repetition_penalty, - ] - if args.use_vocab_mask: - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size]) - graph_inputs.append(vocab_mask) - - if args.use_prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( - "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size] - ) - graph_inputs.append(prefix_vocab_mask) - - if args.use_forced_decoder_ids: - decoder_input_ids = helper.make_tensor_value_info( - "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"] - ) - graph_inputs.append(decoder_input_ids) - - if args.use_logits_processor: - logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1]) - graph_inputs.append(logits_processor) - - if args.collect_cross_qk: - cross_qk_layer_head = helper.make_tensor_value_info( - "cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2] - ) - graph_inputs.append(cross_qk_layer_head) - - if args.extra_decoding_ids: - extra_decoding_ids = helper.make_tensor_value_info( - "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"] - ) - graph_inputs.append(extra_decoding_ids) + graph_inputs = clean_list( + [ + input_features, + max_length, + min_length, + num_beams, + num_return_sequences, + length_penalty, + repetition_penalty, + vocab_mask if args.use_vocab_mask else "", + prefix_vocab_mask if args.use_prefix_vocab_mask else "", + decoder_input_ids if args.use_forced_decoder_ids else "", + logits_processor if args.use_logits_processor else "", + cross_qk_layer_head if args.collect_cross_qk else "", + extra_decoding_ids if args.extra_decoding_ids else "", + temperature if args.use_temperature else "", + ] + ) - # graph outputs + # Graph outputs sequences = helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"] ) - graph_outputs = [sequences] - if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk): - cross_qk = helper.make_tensor_value_info( - "cross_qk", - TensorProto.FLOAT, - ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], - ) - graph_outputs.extend([cross_qk]) - - if args.output_no_speech_probs: - no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([no_speech_probs]) - - if args.output_sequence_scores: - sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([sequence_scores]) + sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"]) + scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) + cross_qk = helper.make_tensor_value_info( + "cross_qk", + TensorProto.FLOAT, + ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"], + ) + no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"]) - if args.output_scores: - scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"]) - graph_outputs.extend([scores]) + graph_outputs = clean_list( + [ + sequences, + sequence_scores if args.output_sequence_scores else "", + scores if args.output_scores else "", + cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "", + no_speech_probs if args.output_no_speech_probs else "", + ] + ) + # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference if hasattr(args, "use_gpu") and args.use_gpu: if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph): logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!") @@ -230,19 +255,7 @@ def chain_model(args): opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)] - graph_nodes = ( - [ - input_features_cast_node, - len_pen_cast_node, - rep_pen_cast_node, - node, - output_sequence_scores_cast_node, - output_scores_cast_node, - ] - if args.precision == Precision.FLOAT16 - else [node] - ) - graph_nodes = [node for node in graph_nodes if node is not None] + graph_nodes.append(node) if args.output_no_speech_probs: prob_cast_node = helper.make_node( "Cast", @@ -251,9 +264,16 @@ def chain_model(args): name="no_speech_probs_cast_to_fp32", to=TensorProto.FLOAT, ) - graph_nodes.extend([prob_cast_node]) - - beam_graph = helper.make_graph(graph_nodes, "beam-search-test", graph_inputs, graph_outputs, initializers) + graph_nodes.append(prob_cast_node) + + # Make graph with WhisperBeamSearch op + beam_graph = helper.make_graph( + graph_nodes, + name="WhisperBeamSearch Graph", + inputs=graph_inputs, + outputs=graph_outputs, + initializer=initializers, + ) beam_graph_input_names = [gi.name for gi in graph_inputs] beam_graph_output_names = [go.name for go in graph_outputs] @@ -287,10 +307,12 @@ def chain_model(args): ir_version=decoder_model.ir_version, ) + # Save WhisperBeamSearch graph and external data if os.path.isfile(args.beam_model_output_dir): logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}") os.remove(args.beam_model_output_dir) os.remove(args.beam_model_output_dir + ".data") + onnx.save( beam_model, args.beam_model_output_dir, diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index 0d69960a095ac..93fd64c9eb7d3 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -170,7 +170,7 @@ def create_dummy( cross_attention_past_shape = [ batch_size, num_attention_heads, - past_decode_sequence_length, + encode_sequence_length, head_size, ] diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index 351173f525727..832f692e9980d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -75,7 +75,7 @@ def create_dummy( config: WhisperConfig, batch_size: int, encode_sequence_length: int, - use_decoder_input_ids: int, + use_decoder_input_ids: bool, device: torch.device, use_int32_inputs: bool = False, ): # -> WhisperEncoderDecoderInitInputs: @@ -125,7 +125,7 @@ def export_onnx( model.config, batch_size=2, encode_sequence_length=3000, - use_decoder_input_ids=use_decoder_input_ids, + use_decoder_input_ids=True, device=device, use_int32_inputs=use_int32_inputs, ) @@ -159,7 +159,7 @@ def export_onnx( hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.encoder_attention_heads) dynamic_axes = { - "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_input_ids": {0: "batch_size", 1: "feature_size"}, "encoder_hidden_states": { 0: "batch_size", 1: "encode_sequence_length", diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py index e2dc79ca247ce..1b47b9426d983 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_helper.py @@ -6,12 +6,14 @@ import logging import os -import sys from pathlib import Path from typing import Dict, Tuple, Union import numpy as np import torch +from float16 import float_to_float16_max_diff +from onnx_model import OnnxModel +from optimizer import optimize_model from packaging import version from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor from transformers import __version__ as transformers_version @@ -21,24 +23,20 @@ from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel -from optimizer import optimize_model - logger = logging.getLogger(__name__) PRETRAINED_WHISPER_MODELS = [ "whisper-tiny", "whisper-tiny.en", + "whisper-base", + "whisper-base.en", "whisper-small", "whisper-small.en", "whisper-medium", "whisper-medium.en", - "whisper-base", - "whisper-base.en", "whisper-large", "whisper-large-v2", + "whisper-large-v3", ] @@ -346,7 +344,12 @@ def verify_onnx( ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features - batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 26, 0, 5, 1 + start_id = [config.decoder_start_token_id] # ex: [50258] + prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") + prompt_ids = list(map(lambda token: token[1], prompt_ids)) # ex: [50259, 50358, 50363] + forced_decoder_ids = start_id + prompt_ids # ex: [50258, 50259, 50358, 50363] + + batch_size, max_length, min_length, num_beams, num_return_sequences = 1, 30, 0, 1, 1 length_penalty, repetition_penalty = 1.0, 1.0 inputs = { "input_features": input_features.to(device), @@ -383,43 +386,51 @@ def verify_onnx( elif name == "prefix_vocab_mask": inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype]) elif name == "decoder_input_ids": - raw_input_ids = ( - [[config.decoder_start_token_id]] - if use_extra_decoding_ids - else [[config.decoder_start_token_id, 50259, 50359, 50363]] - ) + raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids] inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype]) elif name == "logits_processor": inputs[name] = np.array([1], dtype=ort_to_np[dtype]) elif name == "cross_qk_layer_head": inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype]) elif name == "extra_decoding_ids": - inputs[name] = np.repeat(np.array([[50259, 50359, 50363]], dtype=ort_to_np[dtype]), batch_size, 0) + inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0) + elif name == "temperature": + inputs[name] = np.array([1.0], dtype=ort_to_np[dtype]) else: inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype]) ort_outputs = ort_session.run(None, inputs)[0][0] - if pt_outputs.shape != ort_outputs.shape: - logger.warning("PyTorch and ONNX Runtime outputs do not have the same shape") + expected_transcription_no_comma = ( + " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." + ) + expected_transcription_with_comma = ( + " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." + ) + expected_transcription_with_quote_and_comma = ( + ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.' + ) + expected_transcription_options = { + expected_transcription_no_comma, + expected_transcription_with_comma, + expected_transcription_with_quote_and_comma, + } + pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True)[0] + ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)[0] - diff = pt_outputs - ort_outputs - max_diff = max(diff.min(), diff.max(), key=abs) + parity = ( + pt_transcription in expected_transcription_options and ort_transcription in expected_transcription_options + ) + max_diff = 0 - if max_diff > 0: - # For ONNX Runtime INT8 model - pt_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel." - ) - pt_transcription = processor.batch_decode(pt_outputs, skip_special_tokens=True) - ort_expected_transcription = ( - " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel." - ) - ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True) + if not parity: + if pt_outputs.shape != ort_outputs.shape: + diff = pt_outputs - ort_outputs[:, : len(pt_outputs[0])] + else: + diff = pt_outputs - ort_outputs + max_diff = max(diff.min(), diff.max(), key=abs) - parity = ( - pt_expected_transcription == pt_transcription[0] and ort_expected_transcription == ort_transcription[0] - ) - if parity: - max_diff = 0 + if max_diff != 0: + logger.warning(f"PyTorch outputs: {pt_transcription}") + logger.warning(f"ONNX Runtime outputs: {ort_transcription}") return max_diff diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py index e68c3120e3f09..0fdce29ae0fa0 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_phi.py +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -80,14 +80,17 @@ def set_attention_op_type(self, attn_op_type: AttentionOpType): def get_uname(self, layer_id, name): return name + "_" + str(layer_id) - def get_io_by_name(self, node, name): - for input in node.input: - if input == name or input.endswith(name) or input.startswith(name): - return input - for output in node.output: - if output == name or output.endswith(name) or output.startswith(name): - return output - raise Exception(f"input {name} not found in node {node.name}") + def get_edge_by_name(self, edges, name): + for edge in edges: + if edge == name or edge.endswith(name) or edge.startswith(name): + return edge + raise ValueError(f"Edge {name} not found") + + def get_input_by_name(self, node, name): + return self.get_edge_by_name(node.input, name) + + def get_output_by_name(self, node, name): + return self.get_edge_by_name(node.output, name) def process_initializer(self, initializer_name, functor, custom_name=None): i = self.model.get_initializer(initializer_name) @@ -287,7 +290,6 @@ def __init__(self, model: ModelProto, num_heads: int, hidden_size: int): self.num_attention_heads = num_heads self.hidden_size = hidden_size - self.phi2_edge_dict = self.get_phi2_edge_dict() self.func_name = "modeling_phi_PhiModel_model_1" def get_phi2_edge_dict(self) -> dict: @@ -296,11 +298,20 @@ def get_phi2_edge_dict(self) -> dict: edge_dict["l_input_ids_"] = "input_ids" edge_dict["key_states"] = "past_key_0" edge_dict["value_states"] = "past_value_0" - for i in range(self.num_hidden_layers): + for i in range(1, self.num_hidden_layers, 1): edge_dict[f"key_states_{i}"] = f"past_key_{i}" edge_dict[f"value_states_{i}"] = f"past_value_{i}" edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + + outputs = [o.name for o in self.model.graph.output] + if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs: + edge_dict["model_layers_0_1_1"] = "present_key_0" + edge_dict["model_layers_0_1_2"] = "present_value_0" + else: + assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs + edge_dict["model_layers_0_1"] = "present_key_0" + edge_dict["model_layers_0_1_1"] = "present_value_0" return edge_dict def simplify_phi2_op_type(self): @@ -441,7 +452,7 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType): break assert function_name is not None self.unroll_function(function_name) - self.update_edges(self.phi2_edge_dict) + self.update_edges(self.get_phi2_edge_dict()) self.simplify_phi2_op_type() self.remove_dropout_layer() if attn_op_type == AttentionOpType.PagedAttention: @@ -465,7 +476,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[0] output = node.output[0] - embedding = self.get_io_by_name(node, "embed_tokens.weight") + embedding = self.get_input_by_name(node, "embed_tokens.weight") layer_known_edges_names = [input, output, embedding] @@ -499,8 +510,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[0] output = node.output[0] - ln_weight = self.get_io_by_name(node, "final_layernorm.weight") - ln_bias = self.get_io_by_name(node, "final_layernorm.bias") + ln_weight = self.get_input_by_name(node, "final_layernorm.weight") + ln_bias = self.get_input_by_name(node, "final_layernorm.bias") layer_known_edges_names = [input, output, ln_weight, ln_bias] @@ -532,8 +543,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[2] output = node.output[0] - fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) - fc_bias = self.get_io_by_name(node, "lm_head.bias") + fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_input_by_name(node, "lm_head.bias") layer_known_edges_names = [input, output, fc_weight, fc_bias] @@ -670,15 +681,15 @@ def fuse( layer_id = self.get_layer_id(node) i_hidden_states = node.input[0] - i_key_cache = self.get_io_by_name(node, "past_key") - i_value_cache = self.get_io_by_name(node, "past_value") + i_key_cache = self.get_input_by_name(node, "past_key") + i_value_cache = self.get_input_by_name(node, "past_value") - o_hidden_states = node.output[3] - o_key_cache = self.get_io_by_name(node, "present_key") - o_value_cache = self.get_io_by_name(node, "present_value") + o_hidden_states = node.output[-1] + o_key_cache = self.get_output_by_name(node, "present_key") + o_value_cache = self.get_output_by_name(node, "present_value") - ln_weight = self.get_io_by_name(node, "input_layernorm.weight") - ln_bias = self.get_io_by_name(node, "input_layernorm.bias") + ln_weight = self.get_input_by_name(node, "input_layernorm.weight") + ln_bias = self.get_input_by_name(node, "input_layernorm.bias") attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( None, @@ -693,45 +704,45 @@ def fuse( if self.attn_op_type != AttentionOpType.Attention: attn_q_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() ) attn_k_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() ) attn_v_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() ) - attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias") - attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias") - attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias") + attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias") cos_cache = self.process_initializer( - self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() ) sin_cache = self.process_initializer( - self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() ) else: attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm( - self.get_io_by_name(node, "self_attn.q_proj.weight"), - self.get_io_by_name(node, "self_attn.k_proj.weight"), - self.get_io_by_name(node, "self_attn.v_proj.weight"), - self.get_io_by_name(node, "self_attn.q_proj.bias"), - self.get_io_by_name(node, "self_attn.k_proj.bias"), - self.get_io_by_name(node, "self_attn.v_proj.bias"), + self.get_input_by_name(node, "self_attn.q_proj.weight"), + self.get_input_by_name(node, "self_attn.k_proj.weight"), + self.get_input_by_name(node, "self_attn.v_proj.weight"), + self.get_input_by_name(node, "self_attn.q_proj.bias"), + self.get_input_by_name(node, "self_attn.k_proj.bias"), + self.get_input_by_name(node, "self_attn.v_proj.bias"), self.get_uname(layer_id, "attn_qkv_weight"), self.get_uname(layer_id, "attn_qkv_bias"), ) attn_out_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() + self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() ) - attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias") + attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias") - mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) - mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) - mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias") - mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias") + mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) + mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) + mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias") + mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias") layer_known_edges_names = [] layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) @@ -771,6 +782,7 @@ def fuse( subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + # vllm engine requires full position ids as the input pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step" subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_")) subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_")) diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index a449e881ad361..6a25196dbc24c 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -7,7 +7,7 @@ import logging import os -import onnx # noqa: F401 +import onnx import torch from transformers.modeling_utils import Conv1D @@ -69,6 +69,7 @@ def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data onnx_model_path, quantized_model_path, use_external_data_format=use_external_data_format, + extra_options={"DefaultTensorType": onnx.TensorProto.FLOAT}, ) logger.info(f"quantized model saved to:{quantized_model_path}") # TODO: inlcude external data in total model size. diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh old mode 100644 new mode 100755 index f0422839c11eb..64d6ecde618f6 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -34,6 +34,9 @@ run_gpu_fp16=true run_cpu_fp32=false run_cpu_int8=false +# Set this to true to enable bfloat16 fastmath gemm kernels on aarch64 platforms with bfloat16 support +arm64_bfloat16_fastmath_mode=false + average_over=1000 # CPU takes longer time to run, only run 100 inferences to get average latency. if [ "$run_cpu_fp32" = true ] || [ "$run_cpu_int8" = true ]; then @@ -63,7 +66,7 @@ models_to_test="bert-base-cased roberta-base distilbert-base-uncased" # export CUDA_VISIBLE_DEVICES=1 # This script will generate a logs file with a list of commands used in tests. -echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" >> benchmark.log +echo echo "ort=$run_ort torch=$run_torch torch2=$run_torch2 torchscript=$run_torchscript tensorflow=$run_tensorflow gpu_fp32=$run_gpu_fp32 gpu_fp16=$run_gpu_fp16 cpu=$run_cpu optimizer=$use_optimizer batch=$batch_sizes sequence=$sequence_length models=$models_to_test" arm64_bfloat16_fastmath_mode=$arm64_bfloat16_fastmath_mode >> benchmark.log # Set it to false to skip testing. You can use it to dry run this script with the log file. run_tests=true @@ -127,6 +130,10 @@ if [ "$force_layer_number" = true ] ; then benchmark_options="$benchmark_options --force_num_layers $layer_number" fi +if [ "$arm64_bfloat16_fastmath_mode" = true ] ; then + benchmark_options="$benchmark_options --enable_arm64_bfloat16_fastmath_mlas_gemm" +fi + # ------------------------------------------- run_one_test() { if [ "$run_ort" = true ] ; then diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index f3e67930adbff..66f24c47f6cdb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -4,6 +4,7 @@ # -------------------------------------------------------------------------- import torch +from torch._C._onnx import OperatorExportTypes TrainingMode = torch.onnx.TrainingMode from packaging.version import Version # noqa: E402 @@ -18,7 +19,7 @@ def torch_onnx_export( training=TrainingMode.EVAL, input_names=None, output_names=None, - operator_export_type=None, + operator_export_type=OperatorExportTypes.ONNX, opset_version=None, _retain_param_name=None, do_constant_folding=True, diff --git a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc index 6afb61bd1f0a1..8ea37ad054ed0 100644 --- a/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/decoder_masked_multihead_attention_op_test.cc @@ -640,122 +640,139 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp32) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); - - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(float); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); - - tester.AddInput("past", past_dims, reordered_kv_cache); - - // Rel - tester.AddOptionalInputEdge(); - - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); - - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); - - auto k_merged = pair.first; - auto k_transpose = pair.second; - - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); - - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); - - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0~30, 31~2046, >=2047 (so that total_sequence_length: 1~31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 384}, + {2, 30, 768}, + {3, 31, 1536}, + {4, 512, 384}, + {1, 1024, 768}, + {1, 2046, 1536}, + {2, 2047, 384}, + {3, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); + + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(float); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); + + tester.AddInput("past", past_dims, reordered_kv_cache); + + // Rel + tester.AddOptionalInputEdge(); + + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); + + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); + + auto k_merged = pair.first; + auto k_transpose = pair.second; + + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); + + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); + + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); + + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); - - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); - - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -766,122 +783,138 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { return; } - // Vary batch size - for (int batch_size = 1; batch_size <= 5; batch_size += 2) { - // Vary kv_lengths - for (int past_sequence_length = 1; past_sequence_length <= 3000; past_sequence_length += 150) { - int sequence_length = 1; - int number_of_heads = 12; - - // Vary head_size / hidden_size - int hidden_sizes[3] = {384, 768, 1536}; - for (int hidden_size : hidden_sizes) { - int head_size = (hidden_size / number_of_heads); - int total_sequence_length = sequence_length + past_sequence_length; - int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length - - OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); - tester.AddAttribute("num_heads", static_cast(number_of_heads)); - tester.AddAttribute("past_present_share_buffer", static_cast(1)); - - std::vector input_dims = {batch_size, sequence_length, hidden_size}; - std::vector weights_dims = {hidden_size, 3 * hidden_size}; - std::vector bias_dims = {3 * hidden_size}; - std::vector output_dims = {batch_size, sequence_length, hidden_size}; - - auto input = CreateRandom(batch_size * sequence_length * hidden_size); - tester.AddInput("input", input_dims, input); - - auto weight = CreateRandom(hidden_size * 3 * hidden_size); - tester.AddInput("weight", weights_dims, weight); - - auto bias = CreateRandom(3 * hidden_size); - tester.AddInput("bias", bias_dims, bias); - - // Mask - tester.AddOptionalInputEdge(); - - // Past - std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; - int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; - - auto kv_cache = CreateRandom(past_present_size); - - auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, - number_of_heads, past_sequence_length, head_size, max_sequence_length); + // Buckets for test data: + // batch_size: 1, >=2 + // past_sequence_length 0, 1~30, 31~2046, >=2047 (so that total_sequence_length: 1, 2-31, 32~2047, >=2048) + // head_size: 32, 64, 128 + struct MyTestCase { + int batch_size; + int past_sequence_length; + int hidden_size; + } test_cases[] = { + {1, 0, 768}, + {1, 1, 768}, + {3, 30, 384}, + {8, 31, 1536}, + {4, 256, 384}, + {3, 1024, 768}, + {2, 2046, 1536}, + {1, 2047, 384}, + {2, 3000, 768}, + }; + + constexpr int sequence_length = 1; + constexpr int number_of_heads = 12; + + for (MyTestCase test_case : test_cases) { + int batch_size = test_case.batch_size; + int past_sequence_length = test_case.past_sequence_length; + int hidden_size = test_case.hidden_size; + + int head_size = (hidden_size / number_of_heads); + int total_sequence_length = sequence_length + past_sequence_length; + int max_sequence_length = past_sequence_length + 1; // Always keep > past_sequence_length + + OpTester tester("DecoderMaskedSelfAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", static_cast(number_of_heads)); + tester.AddAttribute("past_present_share_buffer", static_cast(1)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector output_dims = {batch_size, sequence_length, hidden_size}; + + auto input = CreateRandom(batch_size * sequence_length * hidden_size); + tester.AddInput("input", input_dims, input); + + auto weight = CreateRandom(hidden_size * 3 * hidden_size); + tester.AddInput("weight", weights_dims, weight); + + auto bias = CreateRandom(3 * hidden_size); + tester.AddInput("bias", bias_dims, bias); + + // Mask + tester.AddOptionalInputEdge(); + + // Past + std::vector past_dims = {2, batch_size, number_of_heads, max_sequence_length, head_size}; + int past_present_size = 2 * batch_size * number_of_heads * max_sequence_length * head_size; + + auto kv_cache = CreateRandom(past_present_size); + + auto reordered_kv_cache = ReorderKVCache(kv_cache, batch_size, + number_of_heads, past_sequence_length, head_size, max_sequence_length); - // Validate if reordering went well - by transposing and checking equality - int chunk_size = 16 / sizeof(MLFloat16); - int num_chunks = head_size / chunk_size; - auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); - CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, - max_sequence_length, past_sequence_length, chunk_size); + // Validate if reordering went well - by transposing and checking equality + int chunk_size = 16 / sizeof(MLFloat16); + int num_chunks = head_size / chunk_size; + auto transposed = Transpose(kv_cache.data(), batch_size, number_of_heads, num_chunks, max_sequence_length, chunk_size); + CheckEquality(transposed.data(), reordered_kv_cache.data(), batch_size, number_of_heads, num_chunks, + max_sequence_length, past_sequence_length, chunk_size); - tester.AddInput("past", past_dims, reordered_kv_cache); + tester.AddInput("past", past_dims, reordered_kv_cache); - // Rel - tester.AddOptionalInputEdge(); + // Rel + tester.AddOptionalInputEdge(); - // Past sequence length - std::vector arr_past_sequence_len(1, past_sequence_length); - tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); + // Past sequence length + std::vector arr_past_sequence_len(1, past_sequence_length); + tester.AddInput("past_sequence_length", {1}, arr_past_sequence_len); - // QKV MatMul - auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); - auto* qkv_matrix = qkv.data(); + // QKV MatMul + auto qkv = QKV(input, weight, bias, batch_size, sequence_length, hidden_size); + auto* qkv_matrix = qkv.data(); - auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, - max_sequence_length, head_size); + auto pair = MergePastKWithPresentKAndTranspose(kv_cache.data(), qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, + max_sequence_length, head_size); - auto k_merged = pair.first; - auto k_transpose = pair.second; + auto k_merged = pair.first; + auto k_transpose = pair.second; - auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, - total_sequence_length, head_size); + auto qk_transpose = QK_Transpose(qkv_matrix, k_transpose.data(), batch_size, number_of_heads, + total_sequence_length, head_size); - auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, - sequence_length, total_sequence_length, head_size); + auto softmax_qk_transpose = Softmax_QK_Transpose(qk_transpose.data(), batch_size, number_of_heads, + sequence_length, total_sequence_length, head_size); - auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + auto present = MergeReorderedKVCacheWithK(reordered_kv_cache, qkv_matrix + hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - // Validate our test logic - // We want to validate if our merged "unordered" K is the same as - // the merged "ordered" K so that the QKT we do in our test code - // is equivalent to the QKT we do in the kernel - ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); + // Validate our test logic + // We want to validate if our merged "unordered" K is the same as + // the merged "ordered" K so that the QKT we do in our test code + // is equivalent to the QKT we do in the kernel + ValidateReorderedMergedKWithK(k_merged.data(), present.data(), batch_size, number_of_heads, total_sequence_length, max_sequence_length, head_size); - MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, - number_of_heads, past_sequence_length, max_sequence_length, head_size); + MergeReorderedKVCacheWithV(present.data() + (past_present_size / 2), qkv_matrix + 2 * hidden_size, batch_size, + number_of_heads, past_sequence_length, max_sequence_length, head_size); - auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), - batch_size, number_of_heads, - sequence_length, total_sequence_length, - max_sequence_length, head_size); + auto output = Softmax_QK_Transpose_V(softmax_qk_transpose.data(), present.data() + (past_present_size / 2), + batch_size, number_of_heads, + sequence_length, total_sequence_length, + max_sequence_length, head_size); - // Output(s) - tester.AddOutput("output", input_dims, output); + // Output(s) + tester.AddOutput("output", input_dims, output); - tester.AddOutput("present", past_dims, present); + tester.AddOutput("present", past_dims, present); - // Run - Regular kernel execution path - { - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } + // Run - Regular kernel execution path + { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } - // Test alternate kernel path of loading more KV data "in flight" - { - ScopedEnvironmentVariables scoped_env_vars{ - EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; + // Test alternate kernel path of loading more KV data "in flight" + { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{{onnxruntime::contrib::attention::kDecoderMaskedAttentionLoadKVDataInFlight, "1"}}}; - std::vector> execution_providers; - execution_providers.push_back(DefaultCudaExecutionProvider()); - tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } - } + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } } } @@ -889,4 +922,4 @@ TEST(DecoderMaskedSelfAttentionTest, Test_fp16) { #endif } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc index 84bbee35eed5a..98fb62e435f31 100644 --- a/onnxruntime/test/contrib_ops/layer_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/layer_norm_op_test.cc @@ -7,6 +7,7 @@ #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "test/framework/test_utils.h" #include "test/util/include/default_providers.h" #include "test/providers/provider_test_utils.h" @@ -75,6 +76,28 @@ TEST(LayerNormTest, LayerNorm) { test.Run(); } +TEST(LayerNormTest, LayerNorm_BFloat16Input) { +// prevents test from running on non-BF16-supporting hardware +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware NOT support BFP16"; + return; + } +#endif + OpTester test("LayerNormalization"); + test.AddAttribute("epsilon", 1e-05f); + + std::vector dims{1, 2, 3}; + test.AddInput("x", dims, MakeBFloat16({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f})); + test.AddInput("gamma", {3}, MakeBFloat16({1.0f, 1.0f, 1.0f})); + test.AddOutput("output", dims, MakeBFloat16({-1.2247f, 0.0f, 1.2247f, -1.2247f, 0.0f, 1.2247f})); + // TRT, DNNL, OpenVINO and NNAPI, CoreML don't support this combination of datatypes + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {kTensorrtExecutionProvider, kDnnlExecutionProvider, kOpenVINOExecutionProvider, + kNnapiExecutionProvider, kQnnExecutionProvider, kCoreMLExecutionProvider}); +} + TEST(LayerNormTest, LayerNorm_Scale) { OpTester test("LayerNormalization"); test.AddAttribute("epsilon", 1e-05f); diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc index fefd5722054de..ea8537f243f5d 100644 --- a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -114,16 +114,21 @@ TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array channels_last_values = {-1, 1}; for (const int channels_last : channels_last_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; @@ -230,6 +235,7 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { int min_cuda_architecture = 530; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); std::array has_add_out_values = {true, false}; std::array skip_dims = {2, 4}; @@ -237,12 +243,16 @@ TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { constexpr int channels_last = 1; for (const int skip_dim : skip_dims) { for (const bool has_add_out : has_add_out_values) { - if (enable_cuda) { + if (enable_cuda || enable_rocm) { std::vector> execution_providers; if (enable_cuda && channels_last != 0) { execution_providers.push_back(DefaultCudaExecutionProvider()); } + if (enable_rocm && channels_last != 0) { + execution_providers.push_back(DefaultRocmExecutionProvider()); + } + // Don't run the test if no providers are supported if (execution_providers.empty()) { continue; diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index bf02c1741725f..e1fcf835c6043 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -42,6 +42,7 @@ #include "core/optimizer/expand_elimination.h" #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -7642,5 +7643,143 @@ TEST_F(GraphTransformationTests, GatherToSliceFusion) { } } +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Gather-2 Ops + auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(-1)}); + auto* gather_out_2 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_2}, {gather_out_2}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose-2 Ops + auto* transpose_out_2 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_2}, {transpose_out_2}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 2); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 0); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Split") { + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(static_cast(attrs.at("axis").i()) == 2); + } + } + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + +TEST_F(GraphTransformationTests, GatherSliceToSplitFusion_Invalid) { + { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* data_arg = builder.MakeInput({{54}}); + auto* reshape_arg = builder.MakeInput({{4}}); + auto* reshape_out = builder.MakeIntermediate({{2, 512, 73, 64}}); + builder.AddNode("Reshape", {data_arg, reshape_arg}, {reshape_out}); + + // Create Gather-1 Ops + auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(-2)}); + auto* gather_out_1 = builder.MakeIntermediate({{2, 512, 1, 64}}); + builder.AddNode("Gather", {reshape_out, gather_index_1}, {gather_out_1}) + .AddAttribute("axis", static_cast(2)); + + // Create Transpose 1-Ops + auto* transpose_out_1 = builder.MakeOutput(); + builder.AddNode("Transpose", {gather_out_1}, {transpose_out_1}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + + // Create Slice Ops + auto* slice_output = builder.MakeIntermediate(); + auto* starts = builder.MakeInitializer({1}, {0}); + auto* ends = builder.MakeInitializer({1}, {-2}); + auto* axes = builder.MakeInitializer({1}, {2}); + auto* steps = builder.MakeInitializer({1}, {1}); + builder.AddNode("Slice", {reshape_out, starts, ends, axes, steps}, {slice_output}); + + // Create Shape-1 Ops + auto* shape_output_1 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_1}); + + // Create Shape-2 Ops + auto* shape_output_2 = builder.MakeOutput(); + builder.AddNode("Shape", {slice_output}, {shape_output_2}); + + // Create Transpose-3 Ops + auto* transpose_out_3 = builder.MakeOutput(); + builder.AddNode("Transpose", {slice_output}, {transpose_out_3}) + .AddAttribute("perm", std::vector{0, 2, 1, 3}); + }; + + auto pre_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + return Status::OK(); + }; + + auto post_graph_checker = [&](Graph& graph) { + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Gather"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Slice"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Split"] == 0); + return Status::OK(); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger_, std::move(transformer), + TransformerLevel::Level1, 1, pre_graph_checker, post_graph_checker)); + } +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 7cfbe0a84e3e6..3874901f86387 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -128,6 +128,7 @@ namespace perftest { "\t\t The number of affinities must be equal to intra_op_num_threads - 1\n\n" "\t-D [Disable thread spinning]: disable spinning entirely for thread owned by onnxruntime intra-op thread pool.\n" "\t-Z [Force thread to stop spinning between runs]: disallow thread from spinning during runs to reduce cpu usage.\n" + "\t-n [Exit after session creation]: allow user to measure session creation time to measure impact of enabling any initialization optimizations.\n" "\t-h: help\n"); } #ifdef _WIN32 @@ -190,7 +191,7 @@ static bool ParseSessionConfigs(const std::string& configs_string, /*static*/ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int argc, ORTCHAR_T* argv[]) { int ch; - while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqz"))) != -1) { + while ((ch = getopt(argc, argv, ORT_TSTR("b:m:e:r:t:p:x:y:c:d:o:u:i:f:F:S:T:C:AMPIDZvhsqzn"))) != -1) { switch (ch) { case 'f': { std::basic_string dim_name; @@ -373,6 +374,9 @@ static bool ParseSessionConfigs(const std::string& configs_string, case 'Z': test_config.run_config.disable_spinning_between_run = true; break; + case 'n': + test_config.run_config.exit_after_session_creation = true; + break; case '?': case 'h': default: diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 36f08167c2217..43bf54963cabb 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -43,6 +43,13 @@ int real_main(int argc, char* argv[]) { } std::random_device rd; perftest::PerformanceRunner perf_runner(env, test_config, rd); + + // Exit if user enabled -n option so that user can measure session creation time + if (test_config.run_config.exit_after_session_creation) { + perf_runner.LogSessionCreationTime(); + return 0; + } + auto status = perf_runner.Run(); if (!status.IsOK()) { printf("Run failed:%s\n", status.ErrorMessage().c_str()); diff --git a/onnxruntime/test/perftest/performance_runner.cc b/onnxruntime/test/perftest/performance_runner.cc index 9f2cbcf6a21f1..37bf80c80e90b 100644 --- a/onnxruntime/test/perftest/performance_runner.cc +++ b/onnxruntime/test/perftest/performance_runner.cc @@ -115,6 +115,11 @@ void PerformanceResult::DumpToFile(const std::basic_string& path, boo } } +void PerformanceRunner::LogSessionCreationTime() { + std::chrono::duration session_create_duration = session_create_end_ - session_create_start_; + std::cout << "\nSession creation time cost: " << session_create_duration.count() << " s\n"; +} + Status PerformanceRunner::Run() { if (!Initialize()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "failed to initialize."); diff --git a/onnxruntime/test/perftest/performance_runner.h b/onnxruntime/test/perftest/performance_runner.h index da2df9c39f44c..cb1cb661550a7 100644 --- a/onnxruntime/test/perftest/performance_runner.h +++ b/onnxruntime/test/perftest/performance_runner.h @@ -46,6 +46,8 @@ class PerformanceRunner { ~PerformanceRunner(); Status Run(); + void LogSessionCreationTime(); + inline const PerformanceResult& GetResult() const { return performance_result_; } inline void SerializeResult() const { diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 5a49414a49004..74c8eb472cb3e 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -63,6 +63,7 @@ struct RunConfig { std::string intra_op_thread_affinities; bool disable_spinning = false; bool disable_spinning_between_run = false; + bool exit_after_session_creation = false; }; struct PerformanceTestConfig { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 5e746ed0c62d4..d35e5c78cfd69 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -5,6 +5,7 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" #include "core/util/math.h" #include #include @@ -786,13 +787,20 @@ TEST(MathOpTest, Sqrt_Float) { test.Run(); } -#if defined(USE_DNNL) +#if defined(USE_DNNL) || defined(USE_CUDA) TEST(MathOpTest, Sqrt_bfloat16) { #ifdef USE_DNNL if (!DnnlHasBF16Support()) { LOGS_DEFAULT(WARNING) << "Hardware does NOT support BF16"; return; } +#endif +#ifdef USE_CUDA + int min_cuda_architecture = 530; + if (!HasCudaEnvironment(min_cuda_architecture)) { + LOGS_DEFAULT(WARNING) << "Hardware does NOT support BFP16"; + return; + } #endif OpTester test_bf16("Sqrt", 13); // only version 13 support bf16 for sqrt test_bf16.AddInput("X", {2, 3}, @@ -804,6 +812,9 @@ TEST(MathOpTest, Sqrt_bfloat16) { std::vector> execution_providers; #if defined(USE_DNNL) execution_providers.push_back(DefaultDnnlExecutionProvider()); +#endif +#ifdef USE_CUDA + execution_providers.push_back(DefaultCudaExecutionProvider()); #endif test_bf16.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc index 13d4546d669e3..b6a760f7041ad 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -52,20 +52,31 @@ struct ConvOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvNhwcBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .bias = true}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.bias = true; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcGroupNoBias) { - auto op = ConvOp{.input_dims = {1, 16, 64, 64}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvOp{}; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvNhwcPadding) { - auto op = - ConvOp{.input_dims = {2, 4, 64, 64}, .kernel_shape = {3, 3}, .channels = 4, .padding = {4, 4, 4, 4}}; + auto op = ConvOp{}; + op.input_dims = {2, 4, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 4; + op.padding = {4, 4, 4, 4}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 6514feadf0ff7..786b2cb4cedc4 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -9,8 +9,8 @@ namespace test { template struct ConvTransposeOp { - const std::vector input_dims; - const std::vector kernel_shape; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; bool bias = false; @@ -60,15 +60,21 @@ struct ConvTransposeOp { }; TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcGroupNoBias) { - auto op = - ConvTransposeOp{.input_dims = {8, 8, 32, 32}, .kernel_shape = {3, 3}, .channels = 16, .group = 4}; + auto op = ConvTransposeOp{}; + op.input_dims = {8, 8, 32, 32}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.group = 4; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { - auto op = - ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 8, 80, 80}; + op.kernel_shape = {5, 5}; + op.channels = 16; + op.bias = true; if (HasCudaEnvironment(800)) { MAKE_PROVIDERS_EPS(1e-2) @@ -78,21 +84,23 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { - auto op = ConvTransposeOp{.input_dims = {1, 16, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .padding = {2, 2, 2, 2}, - .output_padding = {}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 16, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.padding = {2, 2, 2, 2}; + op.output_padding = {}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcOutPad) { - auto op = ConvTransposeOp{.input_dims = {1, 32, 8, 8}, - .kernel_shape = {3, 3}, - .channels = 32, - .strides = {2, 2}, - .output_padding = {1, 1, 1, 1}}; + auto op = ConvTransposeOp{}; + op.input_dims = {1, 32, 8, 8}; + op.kernel_shape = {3, 3}; + op.channels = 32; + op.strides = {2, 2}; + op.output_padding = {1, 1, 1, 1}; MAKE_PROVIDERS_EPS_TYPE(TypeParam) } diff --git a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h index 2c942bb790096..82b6a286409cd 100644 --- a/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h +++ b/onnxruntime/test/providers/cuda/nhwc/nhwc_cuda_helper.h @@ -16,11 +16,13 @@ #define MAKE_PROVIDERS_EPS(eps) \ std::vector> execution_providers; \ - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; \ + OrtCUDAProviderOptionsV2 nhwc{}; \ + nhwc.prefer_nhwc = true; \ execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); \ \ double error_tolerance = eps; \ - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; \ + OrtCUDAProviderOptionsV2 nchw{}; \ + nchw.prefer_nhwc = false; \ auto source_ep = CudaExecutionProviderWithOptions(&nchw); \ auto test = op.get_test(); \ test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); diff --git a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc index 52da8ba557c2d..40f69e3bd5b4f 100644 --- a/onnxruntime/test/providers/cuda/nhwc/norm_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/norm_test.cc @@ -9,7 +9,7 @@ namespace test { template struct BatchNormOp { - const std::vector input_dims; + std::vector input_dims; std::unique_ptr get_test() { // create rand inputs @@ -40,9 +40,8 @@ struct BatchNormOp { }; TYPED_TEST(CudaNhwcTypedTest, BatchNormNhwc) { - auto op = BatchNormOp{ - .input_dims = {4, 16, 64, 64}, - }; + auto op = BatchNormOp{}; + op.input_dims = {4, 16, 64, 64}; MAKE_PROVIDERS() } diff --git a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc index e0d59901da80c..426170b9588f1 100644 --- a/onnxruntime/test/providers/cuda/nhwc/pool_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/pool_test.cc @@ -9,9 +9,9 @@ namespace test { template struct PoolOp { - const std::string pooling_type; - const std::vector input_dims; - const std::vector kernel_shape; + std::string pooling_type; + std::vector input_dims; + std::vector kernel_shape; int64_t channels; int64_t group = 1; std::vector strides = {1, 1}; @@ -41,22 +41,21 @@ struct PoolOp { }; TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwc) { - auto op = PoolOp{ - .pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + MAKE_PROVIDERS() } TYPED_TEST(CudaNhwcTypedTest, MaxPoolNhwc) { - auto op = PoolOp{ - .pooling_type = "MaxPool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - }; + auto op = PoolOp{}; + op.pooling_type = "MaxPool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; MAKE_PROVIDERS() } @@ -72,21 +71,24 @@ TYPED_TEST(CudaNhwcTypedTest, GlobalMaxPoolNhwc) { test->AddOutput("Y", output_dims, output_data); std::vector> execution_providers; - OrtCUDAProviderOptionsV2 nhwc = {.prefer_nhwc = true}; + OrtCUDAProviderOptionsV2 nhwc{}; + nhwc.prefer_nhwc = true; execution_providers.push_back(CudaExecutionProviderWithOptions(&nhwc)); double error_tolerance = 1e-3; - OrtCUDAProviderOptionsV2 nchw = {.prefer_nhwc = false}; + OrtCUDAProviderOptionsV2 nchw{}; + nchw.prefer_nhwc = false; auto source_ep = CudaExecutionProviderWithOptions(&nchw); test->CompareEPs(std::move(source_ep), execution_providers, error_tolerance); } TYPED_TEST(CudaNhwcTypedTest, AveragePoolNhwcPad) { - auto op = PoolOp{.pooling_type = "AveragePool", - .input_dims = {1, 16, 64, 64}, - .kernel_shape = {3, 3}, - .channels = 16, - .padding = {2, 2, 2, 2}}; + auto op = PoolOp{}; + op.pooling_type = "AveragePool"; + op.input_dims = {1, 16, 64, 64}; + op.kernel_shape = {3, 3}; + op.channels = 16; + op.padding = {2, 2, 2, 2}; MAKE_PROVIDERS() } diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 2f3b0e84a123e..a6422407d79fd 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -1110,6 +1110,61 @@ TEST_F(QnnHTPBackendTests, LpNormalization_u16_rank4) { kOnnxDomain, true); } + +static GetTestQDQModelFn BuildQDQConvertAddTestCase(const TestInputDef& input0_def, + const TestInputDef& input1_def) { + return [input0_def, input1_def](ModelTestBuilder& builder, std::vector>& output_qparams) { + constexpr bool use_contrib_qdq = true; + + // Input0 -> Quantize(u8) -> Dequantize(u8 to float) -> input0_after_qdq + NodeArg* input0 = MakeTestInput(builder, input0_def); + QuantParams input0_u8_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_u8_qparams.scale, + input0_u8_qparams.zero_point, use_contrib_qdq); + + // input0_after_qdq -> Quantize(u16) -> Dequantize(u16 to float) + QuantParams input0_u16_qparams = GetTestInputQuantParams(input0_def); + NodeArg* input0_after_convert = AddQDQNodePair(builder, input0_after_qdq, input0_u16_qparams.scale, + input0_u16_qparams.zero_point, use_contrib_qdq); + + // Input1 -> Quantize(u16) -> Dequantize(u16 to float) -> input1_after_qdq + NodeArg* input1 = MakeTestInput(builder, input1_def); + QuantParams input1_qparams = GetTestInputQuantParams(input1_def); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + + // Add op -> op_output + auto* op_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input0_after_convert, input1_after_qdq}, {op_output}); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Test quantization type conversion (mixed precision) with Add. +// First input is converted from uint8_t to uint16_t. +TEST_F(QnnHTPBackendTests, Add_U8_U16_Convert) { + std::vector input0_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + std::vector input1_data = GetFloatDataInRange(-20.0f, 20.0f, 8); + TestInputDef input0_def({1, 2, 2, 2}, false, input0_data); + TestInputDef input1_def({1, 2, 2, 2}, false, input1_data); + + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + TestQDQModelAccuracy(BuildOpTestCase("Add", {input0_def, input1_def}, {}, {}, kOnnxDomain), + BuildQDQConvertAddTestCase(input0_def, input1_def), + provider_options, + 18, + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py new file mode 100644 index 0000000000000..2b5d1f36070e5 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizer_shape_inference.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import unittest + +import numpy as np +import onnx +import onnx.helper as oh +import onnx.numpy_helper as onh + +from onnxruntime.quantization.onnx_quantizer import ONNXQuantizer +from onnxruntime.quantization.quant_utils import QuantizationMode, QuantType + + +class TestQuantizerShapeInference(unittest.TestCase): + def test_com_microsoft(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("MatMul", ["X", "W1"], ["T1"]), + oh.make_node("FusedMatMul", ["T1", "W2"], ["T2"], domain="com.microsoft"), + oh.make_node("MatMul", ["T2", "W3"], ["T3"]), + oh.make_node("MatMul", ["T3", "W4"], ["Y"]), + ], + "name", + [oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [1, 4])], + [oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [1, 4])], + [ + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W1"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W2"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W3"), + onh.from_array(np.random.randn(4, 4).astype(np.float32), "W4"), + ], + ), + opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("com.microsoft", 1)], + ) + model_shaped = onnx.shape_inference.infer_shapes(model) + shaped_results = set(t.name for t in model_shaped.graph.value_info) + # every result after T1 depends on T2 coming from a node com.microsoft, + # shape_inference cannot go beyond this point + self.assertEqual(shaped_results, {"T1"}) + + # first try: checks it raises an exception + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + {"MatMulConstBOnly": True}, # extra_options, + # {'DefaultTensorType': 1, } + ) + + with self.assertRaises(RuntimeError) as e: + quantizer.quantize_model() + self.assertIn("Unable to find data type for weight_name=", str(e)) + + # second try: checks it works + quantizer = ONNXQuantizer( + model, + False, # per_channel + False, # reduce_range + QuantizationMode.IntegerOps, # mode + False, # static + QuantType.QInt8, # weight_type, + QuantType.QUInt8, # dynamic activation only supports uint8 + None, + [], # nodes_to_quantize, + [], # nodes_to_exclude + ["MatMul"], # op_types_to_quantize, + { + "MatMulConstBOnly": True, + "DefaultTensorType": 1, + }, + ) + + model = quantizer.quantize_model() + ops = {n.op_type for n in model.graph.node} + self.assertEqual(ops, {"Cast", "FusedMatMul", "MatMulInteger", "DynamicQuantizeLinear", "Mul"}) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py new file mode 100644 index 0000000000000..c425bf956f976 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_subgraph.py @@ -0,0 +1,64 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import tempfile +import unittest +import urllib.request + +import onnx + +from onnxruntime.quantization import quantize_dynamic + + +class TestDynamicQuantizationSubgraph(unittest.TestCase): + def test_dynamic_quantization_subgraph(self): + with tempfile.TemporaryDirectory() as tmpdir: + onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx") + quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx") + urllib.request.urlretrieve( + "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path + ) + + quantize_dynamic( + model_input=onnx_path, + model_output=quantized_onnx_path, + per_channel=True, + op_types_to_quantize=[ + "Conv", + "MatMul", + "Attention", + "LSTM", + "Gather", + "Transpose", + "EmbedLayerNormalization", + ], + extra_options={"EnableSubgraph": True}, + ) + model = onnx.load(quantized_onnx_path) + + # The initializer `shared.weight_merged_0` is attached to the top-level graph, and used in a Gather node in each subgraphs. + # We expect the quantized Gather (after which a DequantizeLinear is attached) initializer to also be attached to the top-level graph. + found_gather_quantized = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_quantized": + found_gather_quantized = True + break + self.assertTrue(found_gather_quantized) + + found_gather_scale = False + for initializer in model.graph.initializer: + if initializer.name == "shared.weight_merged_0_scale": + found_gather_scale = True + break + self.assertTrue(found_gather_scale) + + # No initializers related to the Gather node should be attached to the subgraphs. + for node in model.graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + for initializer in attr.g.initializer: + self.assertTrue("shared.weight" not in initializer.name) diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py index 40ea8cf774918..33ec1bd7728fe 100644 --- a/onnxruntime/test/python/transformers/test_generation.py +++ b/onnxruntime/test/python/transformers/test_generation.py @@ -381,22 +381,23 @@ def test_logits_processor(self): @pytest.mark.slow def test_cross_qk_overall(self): - decoder_input_ids = [ - "--chain_model", - "--collect_cross_qk", - "--output_cross_qk", - "--use_forced_decoder_ids", - "--extra_decoding_ids", - "--output_no_speech_probs", + cross_qk_input_args = [ "--use_vocab_mask", "--use_prefix_vocab_mask", + "--use_forced_decoder_ids", "--use_logits_processor", + "--collect_cross_qk", + "--extra_decoding_ids", ] - self.run_configs(decoder_input_ids) + cross_qk_output_args = [ + "--output_cross_qk", + "--output_no_speech_probs", + ] + self.run_configs(cross_qk_input_args + cross_qk_output_args) @pytest.mark.slow def test_openai_impl_whisper(self): - optional_args = ["--model_impl", "openai", "--chain_model", "--use_whisper_beamsearch"] + optional_args = ["--model_impl", "openai"] self.run_configs(optional_args) diff --git a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py index 77ce09d7e793b..7892000ae45a0 100644 --- a/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py +++ b/onnxruntime/test/python/transformers/test_whisper_timestamp_processor.py @@ -50,7 +50,7 @@ def run_timestamp(self, provider: str): ort_out = sess.run(None, ort_inputs) ort_out_tensor = torch.from_numpy(ort_out[0]) ort_transcription = processor.batch_decode( - ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True + ort_out_tensor[0][0].view(1, -1), skip_special_tokens=True, output_offsets=True, decode_with_timestamps=True ) print(ort_transcription) expected_transcription = [ @@ -58,7 +58,7 @@ def run_timestamp(self, provider: str): "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", "offsets": [ { - "text": "<|0.00|> Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.<|5.44|>", + "text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.", "timestamp": (0.0, 5.44), } ], diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index a94f7b5b707c7..40b40136af1af 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -208,12 +208,18 @@ std::unique_ptr DefaultRocmExecutionProvider(bool test_tunab } std::unique_ptr DefaultCoreMLExecutionProvider() { -// For any non - macOS system, CoreML will only be used for ort model converter -// Make it unavailable here, you can still manually append CoreML EP to session for model conversion + // To manually test CoreML model generation on a non-macOS platform, comment out the `&& defined(__APPLE__)` below. + // The test will create a model but execution of it will obviously fail. + // To test creating an ML Program, set the environment variable COREML_EP_TEST_MLPROGRAM to any value. #if defined(USE_COREML) && defined(__APPLE__) // We want to run UT on CPU only to get output value without losing precision uint32_t coreml_flags = 0; coreml_flags |= COREML_FLAG_USE_CPU_ONLY; + + if (!Env::Default().GetEnvironmentVar("COREML_EP_TEST_MLPROGRAM").empty()) { + coreml_flags |= COREML_FLAG_CREATE_MLPROGRAM; + } + return CoreMLProviderFactoryCreator::Create(coreml_flags)->CreateProvider(); #else return nullptr; diff --git a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc index 894fe3b052fb2..0b68dc65e41cd 100644 --- a/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc +++ b/orttraining/orttraining/core/optimizer/graph_transformer_utils.cc @@ -24,6 +24,7 @@ #include "core/optimizer/fast_gelu_fusion.h" #include "core/optimizer/free_dim_override_transformer.h" #include "core/optimizer/gather_fusion.h" +#include "core/optimizer/gather_slice_fusion.h" #include "core/optimizer/gelu_approximation.h" #include "core/optimizer/gelu_fusion.h" #include "core/optimizer/gemm_activation_fusion.h" @@ -140,6 +141,7 @@ std::vector> GeneratePreTrainingTransformers( transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); transformers.emplace_back(std::make_unique(compatible_eps)); + transformers.emplace_back(std::make_unique(compatible_eps)); // If a model with Q, DQ nodes is being used for the purpose of training, it must be for // Quantization Aware Training. So, replace QDQ nodes with FakeQuant. transformers.emplace_back(std::make_unique(compatible_eps)); diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index fbf1b7c2bac42..4a03465cf2ead 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -39,7 +39,7 @@ def _defined_from_envvar(name, default_value, warn=True): # NOTE: To *change* values in runtime, import onnxruntime.training.ortmodule and # assign them new values. Importing them directly do not propagate changes. ################################################################################ -ONNX_OPSET_VERSION = 15 +ONNX_OPSET_VERSION = 17 MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1" ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions") _FALLBACK_INIT_EXCEPTION = None diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 9288027f0188c..f81aef5f6b9c4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -821,3 +821,27 @@ def upsample_bicubic2d(g, input, output_size, align_corners, scale_factors): operator_s="upsample_bicubic2d", overload_name_s="vec", ) + + +@register_symbolic("layer_norm") +@parse_args("v", "is", "v", "v", "f", "none") +def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): + # normalized_shape: input shape from an expected input of size + # axis: The first normalization dimension. + # layer_norm normalizes on the last D dimensions, + # where D is the size of normalized_shape + axis = -len(normalized_shape) + + res, new_running_mean, new_running_var = g.op( + "LayerNormalization", + input, + weight, + bias, + epsilon_f=eps, + axis_i=axis, + outputs=3, # force all 3 outputs to be exported in training mode + operator_s="layer_norm", + overload_name_s="vec", + ) + + return res diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index cc533e549db92..73c32a2f51e41 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -196,18 +196,20 @@ def backward(ctx, *grad_outputs): # Run and get results backward_outputs = C.OrtValueVector() - self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) - # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not - # affect peak memory usage in a subsequent graph run. - del ctx.run_info.state - - # Fast version: all backward_outputs are converted first. - # This version only works if backward_outputs is an OrtValueVector. - transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) - - self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) - - return tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + try: + self._execution_agent.run_backward(backward_inputs, backward_outputs, ctx.run_info.state) + # Destroy the state immediately (as opposed to be at the mercy of garbage collector) so it does not + # affect peak memory usage in a subsequent graph run. + + # Fast version: all backward_outputs are converted first. + # This version only works if backward_outputs is an OrtValueVector. + transferred_backward_outputs = _utils._ortvalues_to_torch_tensor(backward_outputs, self._device) + + self._runtime_inspector.memory_ob.inspect_memory(Phase.POST_BACKWARD) + res = tuple(transferred_backward_outputs[idx] if idx != -1 else None for idx in self._gradient_map) + return res + finally: + del ctx.run_info.state return _ORTModuleFunction diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index fa72f3b134917..898c242bb3c32 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -23,7 +23,7 @@ cur_file_dir, ] -extra_compile_args = {"cxx": ["-O3"]} +extra_compile_args = {"cxx": ["-O3", "-std=c++17"]} setup( name="torch_interop_utils", ext_modules=[ diff --git a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc index cf510ea43c89f..509937bdd0c3a 100644 --- a/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/compute_optimizer_test.cc @@ -135,7 +135,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_Allowed) { } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -206,7 +206,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_LabelNameNotMat } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -277,7 +277,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_ReduceNone) { } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); @@ -344,7 +344,7 @@ TEST(ComputeOptimizerTests, InsertGatherBeforeSceLoss_NotAllowed_NoIgnoreIndex) } }; - std::vector opsets{12, 13, 14, 15}; + std::vector opsets{12, 13, 14, 15, 17}; for (auto opset : opsets) { std::unique_ptr transformer = std::make_unique(compatible_eps, std::vector{"label"}); diff --git a/orttraining/orttraining/test/optimizer/graph_transform_test.cc b/orttraining/orttraining/test/optimizer/graph_transform_test.cc index b774fec11cc8d..bab7c09839273 100644 --- a/orttraining/orttraining/test/optimizer/graph_transform_test.cc +++ b/orttraining/orttraining/test/optimizer/graph_transform_test.cc @@ -1523,7 +1523,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs) { builder.AddNode("Identity", {add2_out}, {graph_out}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), @@ -1616,7 +1616,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionThreeInputs_LastAddNotHaveScaleI builder.AddNode("Identity", {add2_out}, {graph_out}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), @@ -1710,7 +1710,7 @@ TEST_F(GraphTransformationTests, ScaledSumFusionTwoInputs) { builder.AddNode("Identity", {add1_out}, {graph_output2}); }; - const std::vector opsets{12, 13, 14, 15}; + const std::vector opsets{12, 13, 14, 15, 17}; for (auto& opset_version : opsets) { std::unique_ptr transformer = std::make_unique(); ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, opset_version, *logger_, std::move(transformer), diff --git a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc index ea05b29c8668b..a1629eb73eeb6 100644 --- a/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc +++ b/orttraining/orttraining/test/optimizer/shape_optimizer_test.cc @@ -67,7 +67,7 @@ TEST(ShapeOptimizerTests, Shape15CannotFold) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -145,7 +145,7 @@ TEST(ShapeOptimizerTests, Shape15) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -218,7 +218,7 @@ TEST(ShapeOptimizerTests, Shape15TakesGraphInput) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -289,7 +289,7 @@ TEST(ShapeOptimizerTests, Shape15GeneratesGraphOutput) { return Status::OK(); }; - std::vector opset_candidates{15}; + std::vector opset_candidates{15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> identity_input_shape; @@ -366,7 +366,7 @@ TEST(ShapeOptimizerTests, Slice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -446,7 +446,7 @@ TEST(ShapeOptimizerTests, SliceGeneratesGraphOutput) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -530,7 +530,7 @@ TEST(ShapeOptimizerTests, Gather) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> shape_input_shape; @@ -639,7 +639,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedBySlice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> dropout_input_shape; @@ -810,7 +810,7 @@ TEST(ShapeOptimizerTests, ConcreteDimUsedByGatherSlice) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> reshape_input_shape; @@ -976,7 +976,7 @@ TEST(ShapeOptimizerTests, SymbolicDimUsedByGather_ConcreteDimUsedByGather) { return Status::OK(); }; - std::vector opset_candidates{10, 11, 12, 13, 14, 15}; + std::vector opset_candidates{10, 11, 12, 13, 14, 15, 17}; for (auto opset : opset_candidates) { auto build_test_case = [&](ModelTestBuilder& builder) { std::vector> reshape_input_shape; diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 51aa1564cbfbe..365c2bb8ebe0e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -34,7 +34,7 @@ from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient from onnxruntime.training.ortmodule.options import _SkipCheck -DEFAULT_OPSET = 15 +DEFAULT_OPSET = 17 # PyTorch model definitions for tests @@ -5280,7 +5280,7 @@ def run_step(model, x): assert ort_model._torch_module._execution_manager(True)._runtime_options.onnx_opset_version == 13 -@pytest.mark.parametrize("opset_version", [12, 13, 14, 15]) +@pytest.mark.parametrize("opset_version", [12, 13, 14, 15, 17]) def test_opset_version_change(opset_version): original_env = None if "ORTMODULE_ONNX_OPSET_VERSION" in os.environ: diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py index 4f0925c5c855b..2f240406b25b9 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_onnx_ops.py @@ -79,7 +79,7 @@ def run_step(model, x): for onnx_model in [onnx_graph_inf, onnx_graph_train]: for oimp in onnx_model.opset_import: if oimp.domain == "": - self.assertEqual(oimp.version, 15) + self.assertEqual(oimp.version, 17) # Needs to match latest default ORTModule opset if op_grad_type is not None: if isinstance(op_grad_type, tuple): text = str(onnx_graph_train) diff --git a/orttraining/orttraining/test/python/qat_poc_example/README.md b/orttraining/orttraining/test/python/qat_poc_example/README.md index 6840e98bd9c86..05072b410b730 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/README.md +++ b/orttraining/orttraining/test/python/qat_poc_example/README.md @@ -48,7 +48,7 @@ We use `onnxruntime.training.onnxblock` to perform the above operations to get t > **_NOTE:_** As of this writing, ORT does not have its own `"Observers"`. Instead, we rely on the `onnxruntime.quantization` tool to quantize the model and give us an initial estimate of the quantization parameters using its calibration process. Here the calibration process is used as a substitute for the observers to present the POC. -> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag AddQDQPairToWeight=True` +> **_NOTE:_** Typically, the weights in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since weights are quantized. However, QAT requires weights and biases to be non quantized. We ensure that the weights have dedicated QDQ pair by passing in the flag `AddQDQPairToWeight=True` > **_NOTE:_** Typically, the bias term in the statically quantized onnx model is associated with a DQ node only (not the QDQ pair) since it is quantized as int32 as opposed to int8. So, we disable quantizing the bias term using the flag QuantizeBias=False` diff --git a/orttraining/orttraining/test/python/qat_poc_example/model.py b/orttraining/orttraining/test/python/qat_poc_example/model.py index 91d7ccd7294f5..601362a59e379 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/model.py +++ b/orttraining/orttraining/test/python/qat_poc_example/model.py @@ -5,7 +5,7 @@ import onnx import torch -import onnxruntime.training.onnxblock as onnxblock +from onnxruntime.training import artifacts class MNIST(torch.nn.Module): @@ -96,42 +96,26 @@ def create_training_artifacts(model_path, artifacts_dir, model_prefix): 4. The checkpoint file """ - class MNISTWithLoss(onnxblock.TrainingModel): - def __init__(self): - super().__init__() - self.loss = onnxblock.loss.CrossEntropyLoss() - - def build(self, output_name): - return self.loss(output_name) - - mnist_with_loss = MNISTWithLoss() - onnx_model, eval_model, optimizer_model = onnx.load(model_path), None, None - - # Build the training and eval graphs - logging.info("Using onnxblock to create the training artifacts.") - with onnxblock.onnx_model(onnx_model) as model_accessor: - _ = mnist_with_loss(onnx_model.graph.output[0].name) - eval_model = model_accessor.eval_model - - # Build the optimizer graph - optimizer = onnxblock.optim.AdamW() - with onnxblock.onnx_model() as accessor: - _ = optimizer(mnist_with_loss.parameters()) - optimizer_model = accessor.model + onnx_model = onnx.load(model_path) + + requires_grad = [ + param.name + for param in onnx_model.graph.initializer + if (not param.name.endswith("_scale") and not param.name.endswith("_zero_point")) + ] + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + artifact_directory=artifacts_dir, + prefix=model_prefix, + ) # Create the training artifacts - train_model_path = os.path.join(artifacts_dir, f"{model_prefix}_train.onnx") - logging.info(f"Saving the training model to {train_model_path}.") - onnx.save(onnx_model, train_model_path) - eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}_eval.onnx") - logging.info(f"Saving the eval model to {eval_model_path}.") - onnx.save(eval_model, eval_model_path) - optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}_optimizer.onnx") - logging.info(f"Saving the optimizer model to {optimizer_model_path}.") - onnx.save(optimizer_model, optimizer_model_path) - trainable_params, non_trainable_params = mnist_with_loss.parameters() - checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}_checkpoint.ckpt") - logging.info(f"Saving the checkpoint to {checkpoint_path}.") - onnxblock.save_checkpoint((trainable_params, non_trainable_params), checkpoint_path) + train_model_path = os.path.join(artifacts_dir, f"{model_prefix}training_model.onnx") + eval_model_path = os.path.join(artifacts_dir, f"{model_prefix}eval_model.onnx") + optimizer_model_path = os.path.join(artifacts_dir, f"{model_prefix}optimizer_model.onnx") + checkpoint_path = os.path.join(artifacts_dir, f"{model_prefix}checkpoint") return train_model_path, eval_model_path, optimizer_model_path, checkpoint_path diff --git a/orttraining/orttraining/test/python/qat_poc_example/qat.py b/orttraining/orttraining/test/python/qat_poc_example/qat.py index 51a15475ee911..dcc9e116fda7d 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/qat.py +++ b/orttraining/orttraining/test/python/qat_poc_example/qat.py @@ -46,7 +46,7 @@ ) logging.info("Preparing the training artifacts for QAT.") - training_model_name = "mnist_qat" + training_model_name = "mnist_qat_" artifacts_dir = os.path.join(model_dir, "training_artifacts") utils.makedir(artifacts_dir) training_artifacts = create_training_artifacts( diff --git a/orttraining/orttraining/test/python/qat_poc_example/train.py b/orttraining/orttraining/test/python/qat_poc_example/train.py index 9a429d2adc6f1..a25c071c58a48 100644 --- a/orttraining/orttraining/test/python/qat_poc_example/train.py +++ b/orttraining/orttraining/test/python/qat_poc_example/train.py @@ -26,14 +26,10 @@ def _train_epoch(model, optimizer, train_loader): model.train() cumulative_loss = 0 for data, target in train_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - train_loss = model(forward_inputs) + train_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) optimizer.step() model.lazy_reset_grad() - cumulative_loss += train_loss[0] + cumulative_loss += train_loss return cumulative_loss / len(train_loader) @@ -43,12 +39,8 @@ def _eval(model, test_loader): model.eval() cumulative_loss = 0 for data, target in test_loader: - forward_inputs = [ - data.reshape(len(data), 784).numpy(), - target.numpy().astype(np.int32), - ] - test_loss = model(forward_inputs) - cumulative_loss += test_loss[0] + test_loss = model(data.reshape(len(data), 784).numpy(), target.numpy().astype(np.int64)) + cumulative_loss += test_loss return cumulative_loss / len(test_loader) @@ -65,7 +57,7 @@ def train_model(qat_train_model, qat_eval_model, qat_optimizer_model, qat_checkp train_loader, test_loader = _get_dataloaders("data", batch_size) # Load the checkpoint state. - state = orttraining.CheckpointState(qat_checkpoint) + state = orttraining.CheckpointState.load_checkpoint(qat_checkpoint) # Create the training module. model = orttraining.Module(qat_train_model, state, qat_eval_model) diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index dcf733153bdad..8b2bc7e2ef2b3 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -196,6 +196,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, MixedPrecisionScale); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, LayerNormalizationGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float_BFloat16, SimplifiedLayerNormalizationGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16_float, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); @@ -452,6 +453,7 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f6c58445c0a5d..fc5d9b65d0f89 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -114,7 +114,8 @@ Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& ORT_RETURN_IF_ERROR(args_.y_tensor.Set(dy_dims, args_.params.data_type)); ORT_RETURN_IF_ERROR(args_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args_.params.data_type)); + args_.params.data_type, + UseTF32())); if (dB) { const TensorShape& db_shape = dB->Shape(); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index 5dc16c68f6210..d23905496c9bb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -233,11 +233,13 @@ bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const } template -Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32) { perf_results.resize(1); perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; if (args.params.data_type == CUDNN_DATA_HALF) { perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else if (args.params.data_type == CUDNN_DATA_FLOAT && !use_tf32) { + perf_results[0].mathType = CUDNN_FMA_MATH; } else { perf_results[0].mathType = CUDNN_DEFAULT_MATH; } @@ -256,7 +258,7 @@ Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const std::vector perf_results; ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) + ? OnlyDefaultAlgorithm(args_, perf_results, provider->UseTF32()) : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); for (auto& algo_perf : perf_results) { if (f(algo_perf) == Status::OK()) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h index a2d4bf3bdc006..3fdb4306bfbbb 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -75,7 +75,7 @@ class AlgoIterator { Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f); - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results, bool use_tf32); private: const ConvArgs& args_; diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index 5f7206fc121ec..d3f5a89434a48 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -182,7 +182,8 @@ Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tenso ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); } return Status::OK(); @@ -287,7 +288,8 @@ Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, cons ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, - args.params.data_type)); + args.params.data_type, + UseTF32())); if (dB) { const auto& b_shape = dB->Shape(); diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index e286236ba6447..f1d3702e3245e 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -181,6 +181,8 @@ def hipify(hipify_perl_path, src_file_path, dst_file_path): s = s.replace("rocm_device_prop_", "cuda_device_prop_") s = s.replace("rocm_device_arch_", "cuda_device_arch_") + s = s.replace("HipTuningContext", "RocmTuningContext") + # We want hipfft, which needs hipDataType etc, but only do this for files that have "fft" in their names # And we do this last, undoing or fixing hipify mistakes. if "fft" in src_file_path: diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 8567d595b7429..5b715bb29e5a1 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -1236,9 +1236,15 @@ def generate_build_tree( "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), ] - # TensorRT and OpenVINO providers currently only support - # full_protobuf option. - if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + # VitisAI and OpenVINO providers currently only support + # full_protobuf option. TensorRT provider only requires it if built with oss_parser + if ( + args.use_full_protobuf + or (args.use_tensorrt and args.use_tensorrt_oss_parser) + or args.use_openvino + or args.use_vitisai + or args.gen_doc + ): cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if args.use_tvm and args.llvm_path is not None: @@ -1520,7 +1526,8 @@ def generate_build_tree( ldflags = ["/profile", "/DYNAMICBASE"] # Address Sanitizer libs do not have a Qspectre version. So they two cannot be both enabled. if not args.enable_address_sanitizer: - cflags += ["/Qspectre"] + # Also enable a special perf patch that was made for Intel Meteor Lake mobile CPUs + cflags += ["/Qspectre", "/DONNXRUNTIME_ENABLE_INTEL_METEOR_LAKE_MOBILE_PLATFORM_PERF_PATCH"] if config == "Release": cflags += ["/O2", "/Ob2", "/DNDEBUG"] elif config == "RelWithDebInfo": @@ -1624,9 +1631,11 @@ def generate_build_tree( [ *temp_cmake_args, f"-DCMAKE_BUILD_TYPE={config}", - f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" - if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) - else "", + ( + f"-DCMAKE_PREFIX_PATH={build_dir}/{config}/installed" + if preinstalled_dir.exists() and not (args.arm64 or args.arm64ec or args.arm) + else "" + ), ], cwd=config_build_dir, cuda_home=cuda_home, @@ -1660,8 +1669,11 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe f"/p:CL_MPCount={num_parallel_jobs}", ] elif args.cmake_generator == "Xcode": - # CMake will generate correct build tool args for Xcode - cmd_args += ["--parallel", str(num_parallel_jobs)] + build_tool_args += [ + "-parallelizeTargets", + "-jobs", + str(num_parallel_jobs), + ] else: build_tool_args += [f"-j{num_parallel_jobs}"] diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index 2b181810b0788..d37266a8e96d8 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,7 +31,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101 + default: qnn-v2.19.2.240210 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index b19a8b11db265..24319184dd0b8 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -204,6 +204,7 @@ jobs: --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ --volume /data/onnx:/data/onnx \ + -e NVIDIA_TF32_OVERRIDE=0 \ $(Repository) \ /bin/bash -c " set -ex; \ diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index e75bb68a8bfeb..eaadc6ad728c0 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -15,6 +15,11 @@ parameters: - 8.6.1.6 - BIN +- name: UseTensorrtOssParser + displayName: Use TensorRT-OSS Parser + type: boolean + default: false + - name: ModelGroups type: object default: @@ -73,7 +78,7 @@ jobs: value: ort-image-$(Build.BuildId) steps: - - ${{ if eq(parameters.TrtVersion, 'BIN') }}: + - ${{ if and(eq(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: - script: 'ls -al $(trtBinsDir)' displayName: 'Show available TensorRT .tar.gz packages' @@ -83,11 +88,19 @@ jobs: - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --install_bin --tar_cuda_version=$(tarCudaVersion) --tar_cudnn_version=$(tarCudnnVersion) --trt_bins_dir=.' displayName: 'Install TensorRT from binaries and build latest ORT Image' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' - - ${{ else }}: + + # Build ORT with TensorRT built-in parser + - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, false)) }}: - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75' - displayName: 'Build latest ORT Image' + displayName: 'Build latest ORT Image with TensorRT built-in parser' workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' - + + # Build ORT with TensorRT OSS parser + - ${{ if and(ne(parameters.TrtVersion, 'BIN'), eq(parameters.UseTensorrtOssParser, true)) }}: + - script: 'python3 $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build/build_image.py -r $(Build.SourcesDirectory) -i $(image) -b $(branchName) -t $(trtVersion) -a 75 --use_tensorrt_oss_parser' + displayName: 'Build latest ORT Image with TensorRT OSS parser' + workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/build' + - ${{ if eq(parameters.MemTest, true) }}: - script: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/run_mem_test_docker.sh -d $(image) -p $(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/mem_test/ -w /code/ -l false' displayName: 'Run Memory Test' diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 0312b70d2b1d5..8fa5bdbf90931 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101 + default: qnn-v2.19.2.240210 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml index d9ab85ee80ce3..47b1e0933417e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '2.0.0' - opset_version: '15' + opset_version: '17' cuda_version: '11.8' cmake_cuda_architectures: 60;61;70;75;80;86 docker_file: Dockerfile.manylinux2_28_training_cuda11_8 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml index 422fb33eec5de..86dce7ae465fc 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml @@ -13,7 +13,7 @@ stages: parameters: build_py_parameters: --enable_training --update --build torch_version: '2.1.0' - opset_version: '15' + opset_version: '17' cuda_version: '12.2' cmake_cuda_architectures: 70;75;80;86;90 docker_file: Dockerfile.manylinux2_28_training_cuda12_2 diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 5349b1ca67ab1..6b0ae085fa4db 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -34,6 +34,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -70,5 +75,6 @@ stages: enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} + enable_windows_x64_qnn: ${{ parameters.enable_windows_x64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index b0509467e1689..9a38513d04a79 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml index c2ef565a6e9ee..f1418e75bffa2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/component-governance-component-detection-steps.yml @@ -5,10 +5,12 @@ parameters: default: 'succeeded' # could be 'ci_only', 'always', 'succeeded' steps: -- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: +- ${{ if eq(variables['System.TeamProject'], 'Lotus') }}: - task: DeleteFiles@1 inputs: - contents: $(Build.BinariesDirectory)/* + SourceFolder: '$(Build.BinariesDirectory)' + contents: | + **/* displayName: 'Clean up build directory' - task: ms.vss-governance-buildtask.governance-build-task-component-detection.ComponentGovernanceComponentDetection@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 146e3e58444c1..5ac5bda8b0964 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -40,6 +40,11 @@ parameters: type: boolean default: true +- name: enable_windows_x64_qnn + displayName: 'Whether Windows x86_64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type type: string @@ -459,3 +464,9 @@ stages: QNN_SDK: 'qnn-v2.18.0.240101_win' PYTHON_VERSION: '3.11' NUMPY_VERSION: '1.25.2' + + - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: + - template: py-win-x64-qnn.yml + parameters: + MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' + QNN_SDK: 'qnn-v2.18.0.240101_win' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index 18368e59cad52..4315eae503ebd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -120,17 +120,17 @@ jobs: $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} workingDirectory: '$(Build.BinariesDirectory)' - - task: VSBuild@1 + # building with build.py so the parallelization parameters are added to the msbuild command + - task: PythonScript@0 displayName: 'Build' inputs: - solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' - platform: x64 - configuration: RelWithDebInfo - msbuildArchitecture: $(buildArch) - maximumCpuCount: true - logProjectEvents: true - workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - createLogFile: true + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --parallel --build + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} + workingDirectory: '$(Build.BinariesDirectory)' # Esrp signing - template: win-esrp-dll.yml @@ -188,7 +188,7 @@ jobs: condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) inputs: GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - template: component-governance-component-detection-steps.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml new file mode 100644 index 0000000000000..30f21e933ee36 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -0,0 +1,177 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'Onnxruntime-QNNEP-Windows-2022-CPU' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +jobs: +- job: Win_py_x64_qnn_Wheels + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + strategy: + matrix: + Python38_x64: + PythonVersion: '3.8' + Python39_x64: + PythonVersion: '3.9' + Python310_x64: + PythonVersion: '3.10' + Python311_x64: + PythonVersion: '3.11' + Python312_x64: + PythonVersion: '3.12' + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - task: UsePythonVersion@0 + inputs: + versionSpec: $(PythonVersion) + addToPath: true + architecture: 'x64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import sys + np_version = 'numpy==1.21.6' if sys.version_info < (3, 11) else 'numpy==1.24.2' + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', np_version]) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: download-deps.yml + + - task: PythonScript@0 + displayName: 'Update deps.txt' + inputs: + scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py + arguments: --new_dir $(Build.BinariesDirectory)/deps + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Install ONNX' + inputs: + filePath: '$(Build.SourcesDirectory)/tools/ci_build/github/windows/install_third_party_deps.ps1' + workingDirectory: '$(Build.BinariesDirectory)' + arguments: -cpu_arch x64 -install_prefix $(Build.BinariesDirectory)\RelWithDebInfo\installed -build_config RelWithDebInfo + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'x64' + configuration: RelWithDebInfo + msbuildArchitecture: 'x64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 13d4589a67cdc..dc861f7f1ed79 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 6246bb83566e5..534d5c6d6135b 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: qnn-v2.18.0.240101_win + default: qnn-v2.19.2.240210_win jobs: - job: 'build' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index dd7c669c37885..e1914d5fe2f06 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -178,7 +178,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.8 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 index a6a75afb0f4c3..fed29689fbe5e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 @@ -161,7 +161,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=2.0.0 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 index d29157daef611..e1caa141ef317 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda12_2 @@ -161,7 +161,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ARG PYTHON_VERSION=3.9 ARG TORCH_VERSION=2.1.0 -ARG OPSET_VERSION=15 +ARG OPSET_VERSION=17 ARG INSTALL_DEPS_EXTRA_ARGS #Add our own dependencies diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index 04a6af962b5e6..f1ffba3b3e1c9 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -82,8 +82,9 @@ RUN if [ -z "$ONNXRUNTIME_COMMIT_ID" ] ; then echo "Building branch ${ONNXRUNTIM git reset --hard ${ONNXRUNTIME_COMMIT_ID} && git submodule update --recursive ; fi # Build ORT -ENV CUDA_MODULE_LOADING "LAZY" -RUN /bin/sh build.sh --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' +ENV CUDA_MODULE_LOADING "LAZY" +ARG PARSER_CONFIG="" +RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home /usr/local/cuda --cudnn_home /usr/lib/x86_64-linux-gnu/ --use_tensorrt --tensorrt_home /usr/lib/x86_64-linux-gnu/ --config Release --build_wheel --skip_tests --skip_submodule_sync --cmake_extra_defines '"CMAKE_CUDA_ARCHITECTURES='${CMAKE_CUDA_ARCHITECTURES}'"' # Switch to root to continue following steps of CI USER root diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 4767c74afd28f..496b57b417fbd 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -112,7 +112,7 @@ RUN pip install \ cerberus \ sympy \ h5py \ - datasets==1.9.0 \ + datasets==2.17.0 \ requests \ sacrebleu==1.5.1 \ sacremoses \ @@ -131,7 +131,7 @@ RUN pip install \ # Install migraphx RUN apt update && apt install -y migraphx -ENV ORTMODULE_ONNX_OPSET_VERSION=15 +ENV ORTMODULE_ONNX_OPSET_VERSION=17 ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev diff --git a/tools/python/run_CIs_for_external_pr.py b/tools/python/run_CIs_for_external_pr.py index 7a77839c4a4e7..df4e70b1e51fe 100644 --- a/tools/python/run_CIs_for_external_pr.py +++ b/tools/python/run_CIs_for_external_pr.py @@ -93,6 +93,8 @@ def main(): # checks "onnxruntime-python-checks-ci-pipeline", "onnxruntime-binary-size-checks-ci-pipeline", + # big models + "Big Models", # not currently required, but running ensures we're hitting all mobile platforms "Android CI Pipeline", "iOS CI Pipeline", diff --git a/winml/lib/Api/HardwareCoreEnumerator.cpp b/winml/lib/Api/HardwareCoreEnumerator.cpp index a89ac561f8860..b6b44690f4f6c 100644 --- a/winml/lib/Api/HardwareCoreEnumerator.cpp +++ b/winml/lib/Api/HardwareCoreEnumerator.cpp @@ -14,7 +14,7 @@ struct LogicalProcessorInformation { struct CoreCounter { uint32_t PhysicalCores = 0; - uint32_t SocDieCores = 0; + uint32_t Num2CacheCores = 0; }; static LogicalProcessorInformation GetLogicalProcessorInfos(LOGICAL_PROCESSOR_RELATIONSHIP relationship) { @@ -75,7 +75,7 @@ static CoreCounter GetNumberOPhysicalAndEngineeringCores() { read += currentProcessorInfo->Size; } - cores.SocDieCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); + cores.Num2CacheCores = CountSetBits(dwLevel2GroupMask & ~dwLevel3GroupMask); return cores; } @@ -83,8 +83,27 @@ uint32_t HardwareCoreEnumerator::DefaultIntraOpNumThreads() { // # of physical cores = # of P cores + # of E Cores + # of Soc Cores. // # of logical cores = # of P cores x 2 (if hyper threading is enabled) + # of E cores + # of Soc Cores. auto cores = GetNumberOPhysicalAndEngineeringCores(); - // We want to use the number of physical cores, but exclude soc cores - return cores.PhysicalCores - cores.SocDieCores; + +#if !defined(_M_ARM64) && !defined(__aarch64__) + const int kVendorID_Intel[3] = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" + int regs_leaf0[4]; + int regs_leaf7[4]; + __cpuid(regs_leaf0, 0); + __cpuid(regs_leaf7, 0x7); + + auto isIntel = (kVendorID_Intel[0] == regs_leaf0[1]) && (kVendorID_Intel[1] == regs_leaf0[2]) && + (kVendorID_Intel[2] == regs_leaf0[3]); + + auto isHybrid = (regs_leaf7[3] & (1 << 15)); + + if (isIntel && isHybrid) { + // We want to use the number of physical cores, but exclude soc cores + // On Intel Hybrid processors, numSocCores == cores.Num2CacheCores + return cores.PhysicalCores - cores.Num2CacheCores; + } +#endif + + return cores.PhysicalCores; } } // namespace WINMLP