From ffceed9d44f2f3efb9dd69fa75fea51163c91d91 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Fri, 30 Aug 2024 15:02:31 -0700 Subject: [PATCH] ORT 1.19.2 Release: Cherry Pick Round 1 (#21861) Approved cherry picks for ORT 1.19.2 release. --------- Co-authored-by: Yi Zhang Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com> Co-authored-by: Ye Wang <52801275+wangyems@users.noreply.github.com> Co-authored-by: Your Name Co-authored-by: Tianlei Wu Co-authored-by: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Co-authored-by: mindest <30493312+mindest@users.noreply.github.com> Co-authored-by: Changming Sun --- VERSION_NUMBER | 2 +- cmake/patches/abseil/absl_windows.patch | 13 + cmake/patches/cutlass/cutlass_3.5.0.patch | 59 +- docs/ContribOperators.md | 16 +- docs/python/README.rst | 5 + include/onnxruntime/core/graph/graph_nodes.h | 3 +- js/common/lib/version.ts | 2 +- js/common/package-lock.json | 4 +- js/common/package.json | 2 +- js/node/lib/version.ts | 2 +- js/node/package-lock.json | 6 +- js/node/package.json | 2 +- js/react_native/lib/version.ts | 2 +- js/react_native/package.json | 2 +- js/react_native/yarn.lock | 2 +- js/web/lib/version.ts | 2 +- js/web/package-lock.json | 6 +- js/web/package.json | 2 +- onnxruntime/__init__.py | 2 +- .../contrib_ops/cpu/bert/attention_common.h | 2 + .../contrib_ops/cpu/bert/attention_helper.h | 43 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 19 +- .../contrib_ops/cuda/bert/attention.cc | 2 +- .../contrib_ops/cuda/bert/attention_impl.cu | 25 +- .../bert/cutlass_fmha/fmha_launch_template.h | 2 + .../cutlass_fmha/memory_efficient_attention.h | 1 + .../cuda/bert/flash_attention/flash.h | 2 + .../cuda/bert/flash_attention/flash_api.cc | 7 + .../cuda/bert/flash_attention/flash_api.h | 2 + .../bert/flash_attention/flash_fwd_kernel.h | 4 +- .../cuda/bert/flash_attention/softmax.h | 4 +- .../cuda/bert/group_query_attention.cc | 4 +- .../cuda/bert/group_query_attention.h | 1 + .../cuda/bert/group_query_attention_helper.h | 1 + .../cuda/bert/group_query_attention_impl.cu | 7 +- .../cuda/bert/multihead_attention.cc | 14 +- .../cuda/bert/packed_attention_impl.cu | 1 + .../bert/packed_multihead_attention_impl.cu | 1 + .../cuda/collective/sharded_moe.cc | 2 +- .../moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu | 31 + .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 157 ++- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.h | 5 +- onnxruntime/contrib_ops/cuda/moe/moe.cc | 2 +- onnxruntime/contrib_ops/cuda/moe/moe_base.h | 7 + .../cuda/quantization/moe_quantization.cc | 118 ++- .../cuda/quantization/moe_quantization.h | 19 + .../core/graph/contrib_ops/bert_defs.cc | 4 + .../core/graph/contrib_ops/collective_defs.cc | 4 + .../core/graph/contrib_ops/contrib_defs.cc | 23 +- onnxruntime/core/mlas/inc/mlas.h | 1 + onnxruntime/core/mlas/lib/compute.cpp | 93 +- .../coreml/builders/impl/argmax_op_builder.cc | 9 +- .../coreml/builders/impl/cast_op_builder.cc | 5 - .../core/providers/cpu/math/softmax_shared.cc | 2 +- onnxruntime/core/providers/cpu/ml/ml_common.h | 2 +- onnxruntime/core/session/onnxruntime_c_api.cc | 2 +- .../tools/transformers/io_binding_helper.py | 3 +- .../test/mlas/bench/bench_computesoftmax.cpp | 4 +- .../test/mlas/unittest/test_softmax.cpp | 22 +- .../providers/coreml/coreml_basic_test.cc | 51 +- .../test/python/transformers/benchmark_gqa.py | 51 +- .../transformers/benchmark_gqa_windows.py | 18 +- .../test/python/transformers/benchmark_mha.py | 5 +- .../transformers/test_flash_attn_cuda.py | 94 +- 
.../test/python/transformers/test_gqa_cpu.py | 195 +++- .../test/python/transformers/test_mha.py | 45 +- .../transformers/test_parity_mixtral_moe.py | 361 ------- .../python/transformers/test_parity_moe.py | 922 +++++++++++++++--- .../transformers/test_sparse_attention.py | 14 +- .../testdata/coreml_argmax_cast_test.onnx | 5 +- .../test/testdata/coreml_argmax_cast_test.py | 19 +- .../coreml_argmax_unsupported_cast_test.onnx | 19 + .../ortmodule/_custom_op_symbolic_registry.py | 3 +- .../test/python/orttraining_test_dort.py | 1 - ...orttraining-py-packaging-pipeline-cuda.yml | 2 +- ...ttraining-py-packaging-pipeline-cuda12.yml | 2 +- ...py-packaging-training-cuda-stage-steps.yml | 2 +- .../templates/win-esrp-dll.yml | 32 + .../pai/rocm-ci-pipeline-env.Dockerfile | 3 +- 79 files changed, 1788 insertions(+), 847 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu delete mode 100644 onnxruntime/test/python/transformers/test_parity_mixtral_moe.py create mode 100644 onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 66e2ae6c25cd6..836ae4eda2992 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.19.1 +1.19.2 diff --git a/cmake/patches/abseil/absl_windows.patch b/cmake/patches/abseil/absl_windows.patch index 82983646527dc..c50e147aa4a7d 100644 --- a/cmake/patches/abseil/absl_windows.patch +++ b/cmake/patches/abseil/absl_windows.patch @@ -74,6 +74,19 @@ index 2d85ac74..4875d668 100644 # The decorated name was longer than the compiler limit "/wd4503", # forcing value to bool 'true' or 'false' (performance warning) +diff --git a/absl/debugging/symbolize.cc b/absl/debugging/symbolize.cc +index 638d3954..6b817075 100644 +--- a/absl/debugging/symbolize.cc ++++ b/absl/debugging/symbolize.cc +@@ -14,7 +14,7 @@ + + #include "absl/debugging/symbolize.h" + +-#ifdef _WIN32 ++#if defined(_WIN32) && !defined(NDEBUG) + #include + #if !(WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP)) || \ + WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) diff --git a/absl/debugging/symbolize_win32.inc b/absl/debugging/symbolize_win32.inc index 53a099a1..34d210d6 100644 --- a/absl/debugging/symbolize_win32.inc diff --git a/cmake/patches/cutlass/cutlass_3.5.0.patch b/cmake/patches/cutlass/cutlass_3.5.0.patch index 3b829d2f8b2cf..93b8c474af9ed 100644 --- a/cmake/patches/cutlass/cutlass_3.5.0.patch +++ b/cmake/patches/cutlass/cutlass_3.5.0.patch @@ -1,13 +1,64 @@ +diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h +index 4c80f549..34327633 100644 +--- a/examples/41_fused_multi_head_attention/kernel_forward.h ++++ b/examples/41_fused_multi_head_attention/kernel_forward.h +@@ -221,6 +221,8 @@ struct AttentionKernel { + int32_t num_batches = 0; + int32_t num_heads = 0; + ++ bool use_smooth_softmax = false; ++ + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; +@@ -897,7 +899,8 @@ struct AttentionKernel { + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, +- kSupportsBias ? 1.0f : p.scale); ++ kSupportsBias ? 
1.0f : p.scale, ++ p.use_smooth_softmax); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % +@@ -1166,7 +1169,8 @@ struct AttentionKernel { + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, +- float scaling) { ++ float scaling, ++ bool use_smooth_softmax) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` +@@ -1257,7 +1261,7 @@ struct AttentionKernel { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, +- [&](int accum_m) { mi_row = mi[accum_m]; }, ++ [&](int accum_m) { mi_row = mi[accum_m];}, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); +@@ -1294,7 +1298,7 @@ struct AttentionKernel { + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } +- s_prime[id] = total_row; ++ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row; + } + } + diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 964d2ff3..b366bc14 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -39,6 +39,7 @@ #include "cutlass/numeric_types.h" - + #include +#include - + #if defined(CUTLASS_ARCH_WMMA_ENABLED) #include @@ -230,8 +231,12 @@ struct inverse_square_root { @@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644 return reinterpret_cast(result); +#else + return half_t::convert((rsqrtf(half_t::convert(lhs)))); -+#endif ++#endif #else return half_t(1.f / std::sqrt(half_t::convert(lhs))); - #endif + #endif \ No newline at end of file diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index ed9e2a0567d2f..f0bf9c6b0ea14 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2482,6 +2482,8 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
 <dt><tt>scale</tt> : float</dt>
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
+<dt><tt>smooth_softmax</tt> : int</dt>
+<dd>Use a smooth factor in softmax.</dd>
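The new `smooth_softmax` attribute makes the softmax denominator behave as if one extra, zero-valued logit took part in the row, which is what the CPU `ComputeSmoothSoftmaxInplace` helper and the flash-attention `row_sum + expf(-row_max * softmax_scale)` change later in this patch implement on their respective paths. A minimal sketch of that math, assuming a plain float vector rather than the actual kernels (`smooth_softmax` here is only an illustrative helper):

```cpp
// Minimal sketch of "smooth" softmax: a virtual zero-valued logit joins the
// row, so the denominator gains exp(0 - max). Illustration only, not the
// ORT CPU/CUDA kernels.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

std::vector<float> smooth_softmax(const std::vector<float>& logits) {
  float m = 0.0f;  // the implicit zero logit also takes part in the max
  for (float v : logits) m = std::max(m, v);
  float denom = std::exp(0.0f - m);  // contribution of the virtual zero logit
  std::vector<float> y(logits.size());
  for (size_t i = 0; i < logits.size(); ++i) {
    y[i] = std::exp(logits[i] - m);
    denom += y[i];
  }
  for (float& v : y) v /= denom;
  return y;  // sums to < 1; the missing mass is absorbed by the virtual token
}

int main() {
  for (float p : smooth_softmax({1.0f, 2.0f, 3.0f})) std::printf("%f\n", p);
}
```

Because of the extra term the outputs sum to slightly less than one; the leftover mass acts as a sink for rows that should attend to nothing.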
 #### Inputs (7 - 9)
@@ -3022,6 +3024,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dd>Number of top experts to select from expert pool</dd>
 <dt><tt>normalize_routing_weights</tt> : int</dt>
 <dd>Whether to normalize routing weights</dd>
+<dt><tt>use_sparse_mixer</tt> : int</dt>
+<dd>Whether to use sparse mixer</dd>
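`use_sparse_mixer` (also wired through QMoE and ShardedMoE in this patch) replaces plain top-k softmax routing with sparse-mixer-style top-2 selection, implemented by the new `sparse_mixer_top2` CUDA kernel in `moe_kernel.cu`. A host-side sketch of that routing rule, assuming `k = 2` and the kernel's `jitter_eps = 0.01`; ties and degenerate inputs are not handled:

```cpp
// Sketch of sparse-mixer top-2 routing (host-side, illustration only).
// For each of the two rounds: pick the best remaining expert, mask experts
// whose gap to the winner exceeds a jitter threshold, and normalize the
// winner's weight over the unmasked experts only.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

std::vector<std::pair<int, float>> sparse_mixer_top2(const std::vector<float>& logits,
                                                     float jitter_eps = 0.01f) {
  std::vector<std::pair<int, float>> routed;
  std::vector<bool> taken(logits.size(), false);
  for (int round = 0; round < 2; ++round) {
    int best = -1;
    for (int e = 0; e < static_cast<int>(logits.size()); ++e)
      if (!taken[e] && (best < 0 || logits[e] > logits[best])) best = e;
    taken[best] = true;
    const float top = logits[best];
    float denom = 0.0f;
    for (int e = 0; e < static_cast<int>(logits.size()); ++e) {
      // An expert is masked if it was already picked in an earlier round, or
      // if its gap to the winner exceeds 2 * jitter_eps * max(|logit|, top).
      bool masked = (taken[e] && e != best) ||
                    (top - logits[e]) > 2.0f * jitter_eps * std::max(std::fabs(logits[e]), top);
      if (!masked) denom += std::exp(logits[e] - top);
    }
    routed.emplace_back(best, 1.0f / denom);  // exp(top - top) / denom
  }
  return routed;
}

int main() {
  for (auto [e, w] : sparse_mixer_top2({0.1f, 2.0f, 1.9f, -0.5f}))
    std::printf("expert %d weight %f\n", e, w);
}
```

Unlike the default path, each round normalizes only over experts inside the jitter band, so the two returned weights need not sum to one.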
 #### Inputs (5 - 8)
@@ -4337,7 +4341,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 ### **com.microsoft.QMoE**

-  Int4 MoE
+  Quantized MoE

 #### Version
@@ -4348,10 +4352,14 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dt><tt>activation_type</tt> : string</dt>
 <dd>Activation function to use. Choose from relu, gelu, silu and identity. Default is relu</dd>
+<dt><tt>expert_weight_bits</tt> : int</dt>
+<dd>Number of bits used in quantized weights. Default is 4 bits</dd>
 <dt><tt>k</tt> : int</dt>
 <dd>Number of top experts to select from expert pool</dd>
 <dt><tt>normalize_routing_weights</tt> : int</dt>
 <dd>Whether to normalize routing weights</dd>
+<dt><tt>use_sparse_mixer</tt> : int</dt>
+<dd>Whether to use sparse mixer</dd>
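With `expert_weight_bits`, QMoE accepts either 4-bit or 8-bit expert weights: two 4-bit values share one byte, which is why the original shapes end in `inter_size / 2` (or `hidden_size / 2`), while 8-bit weights are stored one per byte, giving the un-halved shapes added below. The attribute picks `MoEQuantType::UINT4` or `MoEQuantType::UINT8`, and `ComputeInternal` then dispatches `QuantizedMoEImpl` on `cutlass::uint4b_t` versus `uint8_t`. A rough sketch of the unpacking, with zero points of 8 and 128 assumed purely for illustration (they are not taken from the CUTLASS mixed-GEMM kernels):

```cpp
// Illustrative unpacking of quantized expert weights; the zero points used
// here are assumptions for the sketch, not the CUTLASS mixed-GEMM convention.
#include <cstdint>
#include <cstdio>

// expert_weight_bits == 4: two weights per byte, last dim is inter_size / 2.
inline float dequant_uint4(uint8_t packed, int nibble, float scale) {
  uint8_t q = (nibble == 0) ? (packed & 0x0F) : (packed >> 4);
  return (static_cast<int>(q) - 8) * scale;
}

// expert_weight_bits == 8: one weight per byte, last dim is inter_size.
inline float dequant_uint8(uint8_t q, float scale) {
  return (static_cast<int>(q) - 128) * scale;
}

int main() {
  std::printf("%f %f\n", dequant_uint4(0xA3, 0, 0.5f), dequant_uint8(0xA3, 0.5f));
}
```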
 #### Inputs (7 - 11)
@@ -4362,19 +4370,19 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dt><tt>router_probs</tt> : T</dt>
 <dd>2D input tensor with shape (num_rows, num_experts)</dd>
 <dt><tt>fc1_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
+<dd>3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
 <dt><tt>fc1_scales</tt> : T</dt>
 <dd>2D input tensor with shape (num_experts, inter_size)</dd>
 <dt><tt>fc1_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
 <dt><tt>fc2_experts_weights</tt> : T1</dt>
-<dd>3D input tensor with shape (num_experts, inter_size, hidden_size / 2)</dd>
+<dd>3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)</dd>
 <dt><tt>fc2_scales</tt> : T</dt>
 <dd>2D input tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc2_experts_bias</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, hidden_size)</dd>
 <dt><tt>fc3_experts_weights</tt> (optional) : T1</dt>
-<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)</dd>
+<dd>3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)</dd>
 <dt><tt>fc3_scales</tt> (optional) : T</dt>
 <dd>2D optional input tensor with shape (num_experts, inter_size)</dd>
 <dt><tt>fc3_experts_bias</tt> (optional) : T</dt>
diff --git a/docs/python/README.rst b/docs/python/README.rst index 724002b22d80e..a0c227fdfc4fb 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,11 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime & operator++() { if (current_ < end_) { while (++current_ != end_) { if (*current_ != nullptr && (!apply_filter_ || (*filter_func_)((*current_)->Index()) == false)) break; } } + return *this; } NodeIterator operator++(int) { diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index e7e2e07e20d07..cf791ec41a8ca 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.19.1'; +export const version = '1.19.2'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index 3c3be7ded0de5..07485ef36deee 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.19.1", + "version": "1.19.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.19.1", + "version": "1.19.2", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/common/package.json b/js/common/package.json index bf6a4a16c7a00..b6bb5ffbb5d42 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.19.1", + "version": "1.19.2", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index e7e2e07e20d07..cf791ec41a8ca 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.19.1'; +export const version = '1.19.2'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 6a3c4a8e0fee8..b1eccab55946b 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.19.1", + "version": "1.19.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.19.1", + "version": "1.19.2", "hasInstallScript": true, "license": "MIT", "os": [ @@ -29,7 +29,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.19.1", + "version": "1.19.2", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/node/package.json b/js/node/package.json index cd63246eb989a..86cd72c330448 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.19.1", + "version": "1.19.2", "dependencies": { "onnxruntime-common": "file:../common", "tar": "^7.0.1" diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index e7e2e07e20d07..cf791ec41a8ca 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.19.1'; +export const version = '1.19.2'; diff --git a/js/react_native/package.json b/js/react_native/package.json index c89d01d9c6b24..f91cb26c31ea4 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.19.1", + "version": "1.19.2", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index e69469ea2bd46..664e9ac846f6b 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5254,7 +5254,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.19.1" + version "1.19.2" open@^6.2.0: version "6.4.0" diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index e7e2e07e20d07..cf791ec41a8ca 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.19.1'; +export const version = '1.19.2'; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index e06a464c3c042..72a715c24ecc6 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.19.1", + "version": "1.19.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.19.1", + "version": "1.19.2", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -50,7 +50,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.19.1", + "version": "1.19.2", "license": "MIT", "devDependencies": { "typedoc": "^0.25.7" diff --git a/js/web/package.json b/js/web/package.json index 28f4353d12ac6..cc8300a33a89f 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.19.1", + "version": "1.19.2", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 0f55c4c6ce139..259fcb2f6a32a 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.19.1" +__version__ = "1.19.2" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). 
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 88127387d08ea..8bc043cf5a9ca 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -99,6 +99,7 @@ struct GroupQueryAttentionParameters { int sequence_length; // sequence length of input query, key, value int seqlen_past_kv_cache; // sequence length of past kv tensor int seqlen_present_kv_cache; // sequence length of present kv tensor + int total_sequence_length; // maximum total sequence length (past_sequence_length + sequence_length) among keys int hidden_size; int num_heads; int head_size; @@ -113,6 +114,7 @@ struct GroupQueryAttentionParameters { bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor bool do_rotary; bool rotary_interleaved; + bool use_smooth_softmax; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 29ae769ed89f1..04e120863d39e 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -16,6 +16,47 @@ using onnxruntime::concurrency::ThreadPool; namespace onnxruntime { namespace contrib { +template +void ComputeSmoothSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { + ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t j = begin; j != end; ++j) { + float* x = reinterpret_cast(score) + j * D; + float* y = x; + + float max = -std::numeric_limits::infinity(); + for (int i = 0; i < D; i++) { + if (max < x[i]) + max = x[i]; + } + + if (max < 0.0f) { + max = 0.0f; + } + + for (int i = 0; i < D; i++) { + y[i] = expf(x[i] - max); + } + + double sum = 0.0; + + for (int i = 0; i < D; i++) { + sum += x[i]; + } + + sum += exp(static_cast(-max)); + + for (int i = 0; i < D; i++) { + y[i] = x[i] / (float)sum; + } + } + }); +} + +template <> +inline void ComputeSmoothSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) { + MlasComputeSoftmax(score, score, N, D, false, true, tp); +} + template void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { ThreadPool::TryParallelFor(tp, N, D * 2.0, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { @@ -58,7 +99,7 @@ void ComputeAttentionSoftmaxInplace(T* score, int N, int D, ThreadPool* tp) { template <> inline void ComputeAttentionSoftmaxInplace(float* score, int N, int D, ThreadPool* tp) { - MlasComputeSoftmax(score, score, N, D, false, tp); + MlasComputeSoftmax(score, score, N, D, false, false, tp); } template diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 137612a4bf902..70f8564a2cbf2 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -30,6 +30,8 @@ class GQAAttentionBase { do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; + local_window_size_ = has_local ? 
static_cast(info.GetAttrOrDefault("local_window_size", -1)) : -1; } @@ -40,6 +42,8 @@ class GQAAttentionBase { bool rotary_interleaved_; int local_window_size_; + bool use_smooth_softmax_; + template Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH const T* K, // K data with shape BxN_kvxSxH @@ -195,10 +199,19 @@ class GQAAttentionBase { for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } - ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, - local_window_size_ + 1, nullptr); + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, + local_window_size_ + 1, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, + local_window_size_ + 1, nullptr); + } } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + if (use_smooth_softmax_) { + ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + } else { + ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + } } // set causal [seq_causal_length, total_seqlen) to 0.f diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 5c0989bced70c..0d87cd51a8d4b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -123,7 +123,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention) { using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index f9eabe27d97e4..95a18621c05d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -297,7 +297,7 @@ Status FlashAttention( ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( device_prop, stream, data.q, data.k, data.v, data.output, reinterpret_cast(data.scratch), parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, - parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, + parameters.sequence_length, parameters.total_sequence_length, scale, parameters.is_unidirectional, is_bf16, false, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH)); @@ -345,17 +345,18 @@ Status EfficientAttention( p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index - ? nullptr - : const_cast(reinterpret_cast( - data.mask_index + parameters.batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index - ? 
nullptr - : const_cast(reinterpret_cast( - data.mask_index + 2 * parameters.batch_size + 1)); + p.use_smooth_softmax = false; + + if (nullptr == data.mask_index) { + p.seqlen_k_ptr = nullptr; + p.seqstart_q_ptr = nullptr; + p.seqstart_k_ptr = nullptr; + } else { + p.seqlen_k_ptr = const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size; + p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1; + } + p.query = data.q; p.key = data.k; p.value = data.v; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a5de20e44be1a..222c641883a90 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -214,6 +214,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) { p.v_strideB = params.num_heads * params.v_head_size * params.max_sequence_length; p.bias_strideB = params.is_attn_bias_batched ? static_cast(p.bias_strideH) * params.num_heads : 0; } + + p.use_smooth_softmax = params.use_smooth_softmax; } auto kernel_fn = attention_kernel_batched_impl; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index 08a562a12b844..81e70dab4e683 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -27,6 +27,7 @@ struct MemoryEfficientAttentionParams { bool causal; // The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models. bool is_attn_bias_batched; + bool use_smooth_softmax; float scale; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index 0463d3795b446..bcd87c1ab6251 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -121,6 +121,8 @@ struct Flash_fwd_params : public Qkv_params { bool is_rotary_interleaved = false; + bool smooth_softmax = false; + int num_splits = 0; // For split-KV version void* __restrict__ alibi_slopes_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 967c04c52b182..f875d31f5ca7a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -37,6 +37,7 @@ void set_params_fprop(Flash_fwd_params& params, float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, bool kv_bsnh = true, int window_size_left = -1, int window_size_right = -1) { @@ -47,6 +48,7 @@ void set_params_fprop(Flash_fwd_params& params, params.o_ptr = out; params.is_bf16 = is_bf16; + params.smooth_softmax = use_smooth_softmax; // All stride are in elements, not bytes. 
if (kv_bsnh) { @@ -267,6 +269,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded @@ -293,6 +296,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, softmax_scale, is_causal, is_bf16, + use_smooth_softmax, kv_bsnh, local_window_size, is_causal ? 0 : -1); @@ -365,6 +369,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, softmax_scale, is_causal, is_bf16, + false, true, -1, is_causal ? 0 : -1); @@ -424,6 +429,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, const float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, bool past_bsnh, // otherwise bnsh int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads @@ -456,6 +462,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, softmax_scale, is_causal, is_bf16, + use_smooth_softmax, past_bsnh, local_window_size, is_causal ? 0 : -1); diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 4c59561449851..baad0a938d377 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -52,6 +52,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded @@ -105,6 +106,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, const float softmax_scale, bool is_causal, bool is_bf16, + bool use_smooth_softmax, bool past_bsnh, // otherwise bnsh int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index 1c8a93674a80b..b2aa3668a5be1 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -346,7 +346,7 @@ inline __device__ void compute_attn_1rowblock(const Params& params, const int bi // Epilogue - Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = flash::convert_type(acc_o); @@ -902,7 +902,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons // Epilogue - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.smooth_softmax); Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index ba678b740d376..7e0095cb39bd9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ 
b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -159,7 +159,7 @@ struct Softmax { }; template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale) { + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) { SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); @@ -167,7 +167,7 @@ struct Softmax { static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = row_sum(mi); + float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi); float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum); float scale = inv_sum; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 797f9b0a1ea47..58d1d7f0e4af4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -51,6 +51,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); + use_smooth_softmax_ = info.GetAttrOrDefault("smooth_softmax", 0) == 1; kernel_options_ = this->GetAttentionKernelOptions(); @@ -98,6 +99,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; + parameters.use_smooth_softmax = use_smooth_softmax_; parameters.zeros_count = kZerosCount; parameters.zero_ptr = zeros_.get(); // parameters.left_padding = left_padding_; @@ -132,7 +134,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { // split kv buffer using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.sequence_length, parameters.num_heads, + parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 4ff5b0a59f021..872fe9fe05ad2 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -28,6 +28,7 @@ class GroupQueryAttention final : public CudaKernel { bool is_past_bsnh_; bool do_rotary_; bool rotary_interleaved_; + bool use_smooth_softmax_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 91418b17e6dbc..39efdfd66bcc6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -251,6 +251,7 @@ Status CheckInputs(const 
Tensor* query, output_parameters->sequence_length = sequence_length; // sequence length of Q output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->total_sequence_length = total_sequence_length; // total sequence length output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; output_parameters->head_size = head_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 3099b52cce13e..356f723902da7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -678,9 +678,9 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, - parameters.is_packed_qkv)); + scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, + reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), + parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); // if (parameters.left_padding && parameters.is_prompt) { // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); @@ -844,6 +844,7 @@ Status EfficientAttention( : nullptr; p.stream = stream; p.has_custom_right_padding = true; + p.use_smooth_softmax = parameters.use_smooth_softmax; run_memory_efficient_attention(p); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 2835192abd298..8ca3bea9ffd84 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -45,8 +45,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - ORT_ENFORCE(!is_unidirectional_, - "MHA support CUDA kernel does not Unidirectional. 
Consider using Attention or GQA instead."); kernel_options_ = this->GetAttentionKernelOptions(); @@ -168,7 +166,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { if (use_flash_attention) { using namespace std; auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads, + parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, parameters.head_size, device_prop.multiProcessorCount); parameters.num_splits = static_cast(num_splits); softmax_lse_accum_bytes = slse_accum_bytes; @@ -187,13 +185,13 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { bool use_fused_cross_attention = !use_flash_attention && !disable_fused_cross_attention_ && + !is_unidirectional_ && nullptr == key_padding_mask && nullptr == relative_position_bias && nullptr == past_key && nullptr == present_key && (parameters.qkv_format == Q_K_V_BSNH || (parameters.qkv_format == Q_KV_BSNH_BSN2H && bias == nullptr)) && parameters.hidden_size == parameters.v_hidden_size && - has_fused_cross_attention_kernel(sm, parameters.head_size, - parameters.kv_sequence_length); + has_fused_cross_attention_kernel(sm, parameters.head_size, parameters.kv_sequence_length); if (use_fused_cross_attention) { if (fused_fp16_cross_attention_kernel_ == nullptr) { std::call_once(fused_cross_init_once_flag_, [&]() { @@ -212,6 +210,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && + !is_unidirectional_ && nullptr == relative_position_bias && (parameters.qkv_format == Q_K_V_BSNH || parameters.qkv_format == QKV_BSN3H) && nullptr == past_key && nullptr == present_key && @@ -219,13 +218,12 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.hidden_size == parameters.v_hidden_size && parameters.sequence_length == parameters.kv_sequence_length && // self attention only for fused runner FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); + enable_trt_flash_attention_, is_unidirectional_); if (use_fused_runner) { // Here we assume that num_heads and head_size does not change for a MultiHeadAttention node. if (nullptr == fused_fp16_runner_.get()) { - constexpr bool is_unidirectional = false; std::call_once(fused_fp16_runner_created_, [&]() { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional, + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, enable_trt_flash_attention_, parameters.scale); }); } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 2521cd49b5482..e4a5afd528a9a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; + p.use_smooth_softmax = false; p.scale = parameters.scale == 0.0f ? 
1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; p.seqlen_k_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index e5a4c54f48903..3c25f9146edfd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -693,6 +693,7 @@ Status FusedAttentionCutlass( p.qk_head_size = parameters.head_size; p.v_head_size = parameters.v_head_size; p.causal = false; + p.use_smooth_softmax = false; p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) : parameters.scale; p.seqlen_k_ptr = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 013b7e1779773..1a4a63de38790 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -79,7 +79,7 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_); + normalize_routing_weights_, use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu new file mode 100644 index 0000000000000..b0a72a1d2506a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_fp16_uint8.cu @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif +namespace ort_fastertransformer { +template class MoeGemmRunner; +} // namespace ort_fastertransformer + diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index 5f26de4810c42..a6ea9f4b61271 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -127,7 +127,7 @@ __launch_bounds__(TPB) __global__ const int block_row = blockIdx.x; const bool should_process_row = finished ? 
!finished[block_row] : true; - const int thread_read_offset = blockIdx.x * num_experts; + const int thread_row_offset = blockIdx.x * num_experts; float output_row_sum = 0.f; for (int k_idx = 0; k_idx < k; ++k_idx) { thread_kvp.key = 0; @@ -135,7 +135,7 @@ __launch_bounds__(TPB) __global__ cub_kvp inp_kvp; for (int expert = threadIdx.x; expert < num_experts; expert += TPB) { - const int idx = thread_read_offset + expert; + const int idx = thread_row_offset + expert; inp_kvp.key = expert; inp_kvp.value = inputs_after_softmax[idx]; @@ -169,6 +169,107 @@ __launch_bounds__(TPB) __global__ } #endif +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 +template +__launch_bounds__(TPB) __global__ void sparse_mixer_top2(const T *, T *, int *, int *, const float) { + // Does not support pre-Kepler architectures + ; +} +#else + +template +__launch_bounds__(TPB) __global__ + void sparse_mixer_top2(const T *inputs, T *output, int *indices, int *source_rows, const float jitter_eps) { + static constexpr int K = 2; + + using cub_kvp = cub::KeyValuePair; + using KVBlockReduce = cub::BlockReduce; + + __shared__ float result_kvp_value[K]; + __shared__ typename KVBlockReduce::TempStorage kvTmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + const int thread_row_offset = blockIdx.x * NUM_EXPERTS; + + float factor[K]; + bool logits_mask[K]; + +#pragma unroll + for (int k_idx = 0; k_idx < K; ++k_idx) { + thread_kvp.key = 0; + thread_kvp.value = T(-1.f); + + cub_kvp inp_kvp; +#pragma unroll + for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { + const int idx = thread_row_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) { + const int prior_winning_expert = indices[K * block_row + prior_k]; + + if (prior_winning_expert == expert) { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = KVBlockReduce(kvTmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) { + const int idx = K * block_row + k_idx; + result_kvp_value[k_idx] = (float)result_kvp.value; + indices[idx] = result_kvp.key; + source_rows[idx] = k_idx * num_rows + block_row; + } + __syncthreads(); + +#pragma unroll + for (int expert = threadIdx.x; expert < NUM_EXPERTS; expert += TPB) { + const int idx = thread_row_offset + expert; + factor[k_idx] = max(abs((float)inputs[idx]), result_kvp_value[k_idx]); + logits_mask[k_idx] = (result_kvp_value[k_idx] - (float)inputs[idx]) > (2 * jitter_eps * factor[k_idx]); + if (k_idx == 1 && expert == indices[K * block_row]) { + logits_mask[1] = true; + } + } + } + +#pragma unroll + for (int k_idx = 0; k_idx < K; ++k_idx) { + float row_sum(0); + +#pragma unroll + for (int ii = threadIdx.x; ii < NUM_EXPERTS; ii += TPB) { + const int idx = thread_row_offset + ii; + row_sum += logits_mask[k_idx] ? 0 : exp((static_cast(inputs[idx]) - result_kvp_value[k_idx])); + } + +#pragma unroll + for (int mask = NUM_EXPERTS / 2; mask > 0; mask /= 2) { + row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, NUM_EXPERTS); + } + + const float normalizing_factor = 1.f / row_sum; + + const int idx = K * block_row + k_idx; + if (threadIdx.x == indices[idx]) { + const int input_idx = thread_row_offset + threadIdx.x; + output[idx] = logits_mask[k_idx] ? 
0 + : exp((static_cast(inputs[input_idx]) - result_kvp_value[k_idx])) * + normalizing_factor; + } + } +} +#endif + // ====================== TopK softmax things =============================== /* @@ -406,9 +507,30 @@ void topk_gating_softmax_launcher_helper(const T *input, const bool *finished, T template void topk_gating_softmax_kernelLauncher(const T *input, const bool *finished, T *output, T *softmax_temp_output, int *indices, int *source_row, int num_rows, int num_experts, int k, - bool normalize_routing_weights, cudaStream_t stream) { + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream) { static constexpr int WARPS_PER_TB = 4; + if (use_sparse_mixer) { + static constexpr int TPB = WARP_SIZE * WARPS_PER_TB; + static constexpr float jitter_eps = 0.01f; + + switch (num_experts) { + case 8: { + sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); + break; + } + case 16: { + sparse_mixer_top2<<>>(input, output, indices, source_row, jitter_eps); + break; + } + + default: { + ORT_THROW("Sparse mixer only supports 8 and 16 experts"); + } + } + return; + } + switch (num_experts) { case 2: { topk_gating_softmax_launcher_helper(input, finished, output, indices, source_row, num_rows, @@ -542,9 +664,9 @@ __global__ void dispatch_activations_kernel(int64_t *total_rows_before_expert, i template CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, - bool normalize_routing_weights) + bool normalize_routing_weights, bool use_sparse_mixer) : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), - normalize_routing_weights_(normalize_routing_weights) { + normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -729,7 +851,8 @@ void CutlassMoeFCRunner::run_moe_fc( configure_ws_ptrs(workspace_ptr, static_cast(num_rows), static_cast(hidden_size), static_cast(inter_size), static_cast(num_experts), static_cast(k)); topk_gating_softmax_kernelLauncher(gating_output, finished, expert_scales, softmax_out_, expert_for_source_row, - source_rows_, num_rows, num_experts, k, normalize_routing_weights_, stream); + source_rows_, num_rows, num_experts, k, normalize_routing_weights_, + use_sparse_mixer_, stream); const int sorter_ws_size_bytes = static_cast(pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows))); sorter_.run(reinterpret_cast(fc1_result_), sorter_ws_size_bytes, expert_for_source_row, permuted_experts_, @@ -748,7 +871,8 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, expanded_active_expert_rows); + // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, + // expanded_active_expert_rows); moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -868,9 +992,9 @@ void CutlassMoeFCRunner::get_total_rows_info(int64_t expe // experts in the end. // Note that the expanded_dest_row_to_expanded_source_row map referred to here has indices in the range (0, -// k*rows_in_input - 1). However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... (k-1)*rows_in_input -// all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we simply take the modulus -// of the expanded index. +// k*rows_in_input - 1). 
However, it is set up so that index 0, rows_in_input, 2*rows_in_input ... +// (k-1)*rows_in_input all map to row 0 in the original matrix. Thus, to know where to read in the source matrix, we +// simply take the modulus of the expanded index. template __global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *permuted_output, @@ -878,9 +1002,9 @@ __global__ void initialize_moe_routing_kernel(const T *unpermuted_input, T *perm int *expanded_source_row_to_expanded_dest_row, int num_rows, int active_rows, int cols) { // Reverse permutation map. - // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need the - // reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 - // thread block will be responsible for all k summations. + // I do this so that later, we can use the source -> dest map to do the k-way reduction and unpermuting. I need + // the reverse map for that reduction to allow each threadblock to do 1 k-way reduce without atomics later in + // MoE. 1 thread block will be responsible for all k summations. const int expanded_dest_row = blockIdx.x; const int expanded_source_row = expanded_dest_row_to_expanded_source_row[expanded_dest_row]; if (threadIdx.x == 0) { @@ -1014,14 +1138,15 @@ void finalize_moe_routing_kernelLauncher(const T *expanded_permuted_rows, T *red // ========================= TopK Softmax specializations =========================== template void topk_gating_softmax_kernelLauncher(const float *, const bool *, float *, float *, int *, int *, int, int, - int, bool, cudaStream_t); + int, bool, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half *, const bool *, half *, half *, int *, int *, int, int, - int, bool, cudaStream_t); + int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner; // ===================== Specializations for init routing ========================= template void initialize_moe_routing_kernelLauncher(const float *, float *, const int *, int *, int, int, int, int, @@ -1043,4 +1168,4 @@ template void finalize_moe_routing_kernelLauncher(const float *, float *, const template void finalize_moe_routing_kernelLauncher(const half *, half *, const half *, const half *, const half *, const half *, const int *, const int *, int, int, int, cudaStream_t); -} // namespace ort_fastertransformer +} // namespace ort_fastertransformer \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index 18a26e6a43382..c457b608decbf 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -109,7 +109,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); @@ -161,6 +161,7 @@ class CutlassMoeFCRunner { bool has_fc3_; bool normalize_routing_weights_; + bool use_sparse_mixer_; // Cuda events contrib::cuda::AutoDestoryCudaEvent cuda_event_; @@ -175,7 +176,7 @@ class 
CutlassMoeFCRunner { template class CutlassMoeFCRunner::value>> { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights); + CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { return 0; diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index 6aa75840e6dc0..c5352d931ce2c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -49,7 +49,7 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { const int sm = device_prop.major * 10 + device_prop.minor; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_); + normalize_routing_weights_, use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 4a407fa1b2159..6b65557444a66 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -22,6 +22,7 @@ enum class MoEParallelType { enum class MoEQuantType { None = 0, UINT4 = 1, + UINT8 = 2, }; struct MoEParameters { @@ -225,9 +226,15 @@ class MoEBase { } normalize_routing_weights_ = op_kernel_info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; + + use_sparse_mixer_ = op_kernel_info.GetAttrOrDefault("use_sparse_mixer", 0) == 1; + if (use_sparse_mixer_) { + ORT_ENFORCE(k_ == 2, "Sparse mixer only supports k=2"); + } } bool normalize_routing_weights_; + bool use_sparse_mixer_; int64_t k_; ort_fastertransformer::ActivationType activation_type_; }; diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 571cc59dec75c..4dd5a079d1a29 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -37,61 +37,54 @@ template <> struct ToCudaTypeWrapper { using MappedType = cutlass::uint4b_t; }; + } // anonymous namespace -QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) {} +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { + ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); + ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, + "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); +} -Status QMoE::ComputeInternal(OpKernelContext* context) const { - const Tensor* input = context->Input(0); - const Tensor* router_probs = context->Input(1); - const Tensor* fc1_experts_weights = context->Input(2); - const Tensor* fc1_scales = context->Input(3); - const Tensor* fc1_experts_bias_optional = context->Input(4); - const Tensor* fc2_experts_weights = context->Input(5); - const Tensor* fc2_scales = context->Input(6); - const Tensor* fc2_experts_bias_optional = context->Input(7); - const Tensor* fc3_experts_weights_optional = context->Input(8); - const Tensor* fc3_scales_optional = context->Input(9); - const Tensor* fc3_experts_bias_optional = context->Input(10); +template +Status QMoE::QuantizedMoEImpl(OpKernelContext* context, + MoEParameters& moe_params, + const Tensor* 
input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional, + const Tensor* fc1_scales, + const Tensor* fc2_scales, + const Tensor* fc3_scales_optional, + const cudaDeviceProp& device_prop) const { + auto stream = context->GetComputeStream(); -#if defined(__GNUC__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters. -#endif + const int sm = device_prop.major * 10 + device_prop.minor; - MoEParameters moe_params; - MoEQuantType quant_type = MoEQuantType::UINT4; - ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, - fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, - fc3_experts_weights_optional, fc3_experts_bias_optional)); - ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, - moe_params.hidden_size, moe_params.inter_size)); + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - // Support int4 only at the moment. We can add uint8 if needed. - static constexpr bool use_quint4x2 = true; using T = MLFloat16; using CudaT = typename ToCudaType::MappedType; - using CudaWeightT = typename ToCudaTypeWrapper::MappedType; - - auto stream = context->GetComputeStream(); - - auto& device_prop = GetDeviceProp(); - const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), - static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), static_cast(k_)); + static_cast(moe_params.inter_size), static_cast(moe_params.num_experts), + static_cast(k_)); size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT); size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT); size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int); size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int); - AllocatorPtr allocator; - ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr(allocator, ws_size, false, stream); IAllocatorUniquePtr fc2_output = IAllocator::MakeUniquePtr(allocator, fc2_output_size, false, stream); IAllocatorUniquePtr expert_scales = @@ -140,13 +133,56 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { reinterpret_cast(expert_for_source_row.get()), static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), static_cast(k_), Stream(context)); + return Status::OK(); +} + +Status QMoE::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* router_probs = context->Input(1); + const Tensor* fc1_experts_weights = context->Input(2); + const Tensor* fc1_scales = context->Input(3); + const Tensor* fc1_experts_bias_optional = context->Input(4); + const Tensor* fc2_experts_weights = 
context->Input(5); + const Tensor* fc2_scales = context->Input(6); + const Tensor* fc2_experts_bias_optional = context->Input(7); + const Tensor* fc3_experts_weights_optional = context->Input(8); + const Tensor* fc3_scales_optional = context->Input(9); + const Tensor* fc3_experts_bias_optional = context->Input(10); + + MoEQuantType quant_type = expert_weight_bits_ == 4 ? MoEQuantType::UINT4 : MoEQuantType::UINT8; + MoEParameters moe_params; + ORT_RETURN_IF_ERROR(CheckInputs(moe_params, quant_type, input, router_probs, fc1_experts_weights, + fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, + fc3_experts_weights_optional, fc3_experts_bias_optional)); + ORT_RETURN_IF_ERROR(CheckInputScales(fc1_scales, fc2_scales, fc3_scales_optional, moe_params.num_experts, + moe_params.hidden_size, moe_params.inter_size)); + #if defined(__GNUC__) -#pragma GCC diagnostic pop +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // Mute "maybe used uninitialized" warning for MoEParameters. #endif - return Status::OK(); + if (quant_type == MoEQuantType::UINT4) { + using CudaWeightT = typename ToCudaTypeWrapper::MappedType; + return QuantizedMoEImpl(context, moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, + fc2_experts_bias_optional, fc3_experts_weights_optional, + fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional, + GetDeviceProp()); + } else { + using CudaWeightT = typename ToCudaTypeWrapper::MappedType; + return QuantizedMoEImpl(context, moe_params, input, router_probs, + fc1_experts_weights, fc1_experts_bias_optional, fc2_experts_weights, + fc2_experts_bias_optional, fc3_experts_weights_optional, + fc3_experts_bias_optional, fc1_scales, fc2_scales, fc3_scales_optional, + GetDeviceProp()); + } + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif } } // namespace cuda } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h index 7b68d2d082de8..c0164576d7c7f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -18,6 +18,25 @@ class QMoE final : public CudaKernel, public MoEBase { public: explicit QMoE(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* ctx) const override; + + private: + template + Status QuantizedMoEImpl(OpKernelContext* context, + MoEParameters& moe_params, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional, + const Tensor* fc1_scales, + const Tensor* fc2_scales, + const Tensor* fc3_scales_optional, + const cudaDeviceProp& device_prop) const; + + int64_t expert_weight_bits_; }; } // namespace cuda diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7272a949f7218..27bbd108a02c2 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1075,6 +1075,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. 
Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("smooth_softmax", + "Use a smooth factor in softmax.", + AttributeProto::INT, + static_cast(-1)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" diff --git a/onnxruntime/core/graph/contrib_ops/collective_defs.cc b/onnxruntime/core/graph/contrib_ops/collective_defs.cc index a0ca2e45f153a..7b4f3611f7cdf 100644 --- a/onnxruntime/core/graph/contrib_ops/collective_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/collective_defs.cc @@ -95,6 +95,10 @@ void RegisterCollectiveOps() { "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", + "Whether to use sparse mixer", + AttributeProto::INT, + static_cast(0)) .Attr("local_experts_start_index", "The start index of local experts", AttributeProto::INT, diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 2d51658953282..6c6b13cf14264 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1395,6 +1395,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") @@ -1410,7 +1411,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, ONNX_MS_OPERATOR_SET_SCHEMA( QMoE, 1, OpSchema() - .SetDoc("Int4 MoE") + .SetDoc("Quantized MoE") .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, @@ -1423,18 +1424,31 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) + .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) + .Attr("expert_weight_bits", + "Number of bits used in quantized weights. 
Default is 4 bits", + AttributeProto::INT, + static_cast(4)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " "(batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") - .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size / 2)", "T1") + .Input(2, + "fc1_experts_weights", + "3D input tensor with shape (num_experts, hidden_size, inter_size) " + "or (num_experts, hidden_size, inter_size / 2)", + "T1") .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") .Input(4, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) - .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size / 2)", "T1") + .Input(5, + "fc2_experts_weights", + "3D input tensor with shape (num_experts, inter_size, hidden_size) " + "or (num_experts, inter_size, hidden_size / 2)", + "T1") .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") .Input(7, "fc2_experts_bias", @@ -1443,7 +1457,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(8, "fc3_experts_weights", - "3D optional input tensor with shape (num_experts, hidden_size, inter_size / 2)", + "3D optional input tensor with shape (num_experts, hidden_size, inter_size) " + "or (num_experts, hidden_size, inter_size / 2)", "T1", OpSchema::Optional) .Input(9, diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index e46105324a7fb..bea4b91ebaa79 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1013,6 +1013,7 @@ MlasComputeSoftmax( size_t N, size_t D, bool LogSoftmax, + bool SmoothSoftmax, MLAS_THREADPOOL* ThreadPool ); diff --git a/onnxruntime/core/mlas/lib/compute.cpp b/onnxruntime/core/mlas/lib/compute.cpp index f4c1e3da69289..73df23e64ca1f 100644 --- a/onnxruntime/core/mlas/lib/compute.cpp +++ b/onnxruntime/core/mlas/lib/compute.cpp @@ -71,6 +71,7 @@ MLAS_INTERNAL_DATA const float MlasMinimumF32Value = std::numeric_limits: struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; + bool SmoothSoftmax; const float* Input; float* Output; size_t N; @@ -81,7 +82,7 @@ MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasComputeExpVector( MLAS_FLOAT32X4 Vector - ) +) /*++ Routine Description: @@ -186,7 +187,7 @@ MlasComputeExpF32Kernel( const float* Input, float* Output, size_t N - ) +) /*++ Routine Description: @@ -208,7 +209,6 @@ Return Value: --*/ { while (N > 0) { - MLAS_FLOAT32X4 Vector; if (N >= 4) { @@ -228,7 +228,6 @@ Return Value: Vector = MlasComputeExpVector(Vector); if (N >= 4) { - MlasStoreFloat32x4(Output, Vector); Input += 4; @@ -236,7 +235,6 @@ Return Value: N -= 4; } else { - MlasStoreLaneFloat32x4<0>(Output, Vector); Input += 1; @@ -252,7 +250,7 @@ MlasComputeExp( const float* Input, float* Output, size_t N - ) +) /*++ Routine Description: @@ -287,7 +285,7 @@ MLAS_FLOAT32X4 MlasComputeSumExpVector( MLAS_FLOAT32X4 Vector, MLAS_FLOAT32X4 NegativeMaximumVector - ) +) /*++ Routine Description: @@ -379,7 +377,7 @@ MlasComputeSumExpF32Kernel( float* Output, size_t N, const float* NegativeMaximum - ) +) /*++ Routine Description: @@ -411,7 +409,6 @@ Return Value: float Accumulator = 0.0f; if (N >= 4) { - MLAS_FLOAT32X4 AccumulatorVector = MlasZeroFloat32x4(); #if !defined(MLAS_SSE2_INTRINSICS) @@ -426,7 +423,6 @@ Return Value: // while (N >= 8) { - MLAS_FLOAT32X4 
Vector0 = MlasLoadFloat32x4(Input); MLAS_FLOAT32X4 Vector1 = MlasLoadFloat32x4(Input + 4); @@ -448,7 +444,6 @@ Return Value: #endif while (N >= 4) { - MLAS_FLOAT32X4 Vector = MlasLoadFloat32x4(Input); Vector = MlasComputeSumExpVector(Vector, NegativeMaximumVector); @@ -467,7 +462,6 @@ Return Value: } while (N > 0) { - #if defined(MLAS_SSE2_INTRINSICS) // N.B. SSE2 lacks a broadcast load instruction, so avoid a shuffle and // use zeroes for the upper elements. @@ -498,7 +492,7 @@ MLASCALL MlasReduceMaximumF32Kernel( const float* Input, size_t N - ) +) /*++ Routine Description: @@ -521,17 +515,14 @@ Return Value: float Maximum = MlasMinimumF32Value; if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(Maximum); if (N >= 16) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; while (N >= 16) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); MaximumVector1 = MlasMaximumFloat32x4(MaximumVector1, MlasLoadFloat32x4(Input + 4)); MaximumVector2 = MlasMaximumFloat32x4(MaximumVector2, MlasLoadFloat32x4(Input + 8)); @@ -547,7 +538,6 @@ Return Value: } while (N >= 4) { - MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, MlasLoadFloat32x4(Input)); Input += 4; @@ -558,7 +548,6 @@ Return Value: } while (N > 0) { - Maximum = std::max(Maximum, *Input); Input += 1; @@ -575,18 +564,16 @@ MlasReduceMinimumMaximumF32Kernel( float* Min, float* Max, size_t N - ) +) { float tmp_min = std::numeric_limits::max(); float tmp_max = std::numeric_limits::lowest(); if (N >= 4) { - MLAS_FLOAT32X4 MaximumVector0 = MlasBroadcastFloat32x4(tmp_max); MLAS_FLOAT32X4 MinimumVector0 = MlasBroadcastFloat32x4(tmp_min); if (N >= 16) { - MLAS_FLOAT32X4 MaximumVector1 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector2 = MaximumVector0; MLAS_FLOAT32X4 MaximumVector3 = MaximumVector0; @@ -596,7 +583,6 @@ MlasReduceMinimumMaximumF32Kernel( MLAS_FLOAT32X4 MinimumVector3 = MinimumVector0; while (N >= 16) { - MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(Input); MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(Input + 4); MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(Input + 8); @@ -626,7 +612,6 @@ MlasReduceMinimumMaximumF32Kernel( } while (N >= 4) { - MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(Input); MaximumVector0 = MlasMaximumFloat32x4(MaximumVector0, InputVector0); @@ -641,7 +626,6 @@ MlasReduceMinimumMaximumF32Kernel( } while (N > 0) { - tmp_max = std::max(tmp_max, *Input); tmp_min = std::min(tmp_min, *Input); @@ -659,7 +643,7 @@ MlasComputeSoftmaxOutputF32Kernel( float* Output, size_t N, const float* Parameters - ) +) /*++ Routine Description: @@ -686,7 +670,6 @@ Return Value: const MLAS_FLOAT32X4 ScaleVector = MlasBroadcastFloat32x4(Scale); while (N >= 16) { - MLAS_FLOAT32X4 Vector0 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output)); MLAS_FLOAT32X4 Vector1 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 4)); MLAS_FLOAT32X4 Vector2 = MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output + 8)); @@ -702,7 +685,6 @@ Return Value: } while (N >= 4) { - MlasStoreFloat32x4(Output, MlasMultiplyFloat32x4(ScaleVector, MlasLoadFloat32x4(Output))); Output += 4; @@ -710,7 +692,6 @@ Return Value: } while (N > 0) { - *Output *= Scale; Output += 1; @@ -725,7 +706,7 @@ MlasComputeLogSoftmaxOutputF32Kernel( float* Output, size_t N, const float* Parameters - ) +) /*++ Routine Description: @@ -757,7 +738,6 @@ Return Value: const MLAS_FLOAT32X4 LogarithmVector = 
MlasBroadcastFloat32x4(Logarithm); while (N >= 16) { - MLAS_FLOAT32X4 Vector0 = MlasLoadFloat32x4(Input); MLAS_FLOAT32X4 Vector1 = MlasLoadFloat32x4(Input + 4); MLAS_FLOAT32X4 Vector2 = MlasLoadFloat32x4(Input + 8); @@ -784,7 +764,6 @@ Return Value: } while (N >= 4) { - MLAS_FLOAT32X4 Vector = MlasLoadFloat32x4(Input); Vector = MlasAddFloat32x4(Vector, NegativeMaximumVector); Vector = MlasSubtractFloat32x4(Vector, LogarithmVector); @@ -796,7 +775,6 @@ Return Value: } while (N > 0) { - *Output = *Input + NegativeMaximum - Logarithm; Input += 1; @@ -809,7 +787,7 @@ void MlasComputeSoftmaxThreaded( void* Context, ptrdiff_t Index - ) +) /*++ Routine Description: @@ -846,6 +824,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; + const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -857,7 +836,6 @@ Return Value: #endif while (CountN > 0) { - #if defined(MLAS_SSE2_INTRINSICS) // // Prefetch the next row of the input buffer. @@ -878,24 +856,30 @@ Return Value: float Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif float NegativeMaximum = -Maximum; + if (SmoothSoftmax && NegativeMaximum > 0.0f) { + NegativeMaximum = 0.0f; + } - if (LogSoftmax) { - - // - // Compute the sum of the exponential functions for the row. - // - + // + // Compute the exponential function for each element of the row (save to Temp if provided) and + // compute the sum of these exponential functions. + // + float* Temp = LogSoftmax ? nullptr : Output; #if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, nullptr, D, &NegativeMaximum); + float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #else - float Accumulation = MlasComputeSumExpF32Kernel(Input, nullptr, D, &NegativeMaximum); + float Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #endif + if (SmoothSoftmax) { + Accumulation += expf(NegativeMaximum); + } + + if (LogSoftmax) { // // Compute the log softmax output. // - - float Parameters[] = { NegativeMaximum, std::log(Accumulation)}; + float Parameters[] = {NegativeMaximum, std::log(Accumulation)}; #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); @@ -904,23 +888,10 @@ Return Value: #endif } else { - - // - // Compute the exponential function for each element of the row and - // compute the sum of these exponential functions. - // - -#if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Output, D, &NegativeMaximum); -#else - float Accumulation = MlasComputeSumExpF32Kernel(Input, Output, D, &NegativeMaximum); -#endif - // // Normalize the softmax output. // - - float Parameters[] = { 1.0f / Accumulation }; + float Parameters[] = {1.0f / Accumulation}; #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); @@ -943,8 +914,9 @@ MlasComputeSoftmax( size_t N, size_t D, bool LogSoftmax, + bool SmoothSoftmax, MLAS_THREADPOOL* ThreadPool - ) +) /*++ Routine Description: @@ -966,6 +938,8 @@ Routine Description: LogSoftmax - Supplies true if this is a log softmax operation, else false if this is a softmax operation. + SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. 
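
For reference, the SmoothSoftmax path added above amounts to clamping the per-row maximum at zero and adding one extra exp(-max) term to the softmax denominator (the `Accumulation += expf(NegativeMaximum)` line). A minimal NumPy sketch of that math follows; the name smooth_softmax_rows is illustrative only and is not part of MLAS.

import numpy as np

def smooth_softmax_rows(x):
    # Clamp the row maximum at zero, mirroring "if (SmoothSoftmax && NegativeMaximum > 0.0f)".
    m = np.maximum(x.max(axis=-1, keepdims=True), 0.0)
    # Row-wise exponentials, with the extra exp(-max) term added to the denominator
    # just as the kernel adds it to Accumulation before normalizing.
    w = np.exp(x - m)
    return w / (w.sum(axis=-1, keepdims=True) + np.exp(-m))
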
+ ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. @@ -982,6 +956,7 @@ Return Value: // WorkBlock.LogSoftmax = LogSoftmax; + WorkBlock.SmoothSoftmax = SmoothSoftmax; WorkBlock.Input = Input; WorkBlock.Output = Output; WorkBlock.N = N; diff --git a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc index e9a8176c8349b..bc8b2d1a3505d 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc @@ -38,13 +38,14 @@ Status ArgMaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // 2. Otherwise, we add Argmax layer normally if (node.GetOutputEdgesCount() == 1) { auto it = node.OutputEdgesBegin(); - const auto* succ_node(graph_viewer.GetNode(it->GetNode().Index())); + const auto* next_node_in_partition = graph_viewer.GetNode(it->GetNode().Index()); // If Argmax's successive node is a Cast from int64 to int32 output - // The 'cast to' type is checked in operator supported related, omit the check here - if (succ_node->OpType() == "Cast") { + // The 'cast to' type is checked when determining operator support (see CastOpBuilder::IsOpSupportedImpl()) + // so we omit the check here + if (next_node_in_partition != nullptr && next_node_in_partition->OpType() == "Cast") { // Skip the cast's input/argmax's output *layer->mutable_input()->Add() = node.InputDefs()[0]->Name(); - *layer->mutable_output()->Add() = succ_node->OutputDefs()[0]->Name(); + *layer->mutable_output()->Add() = next_node_in_partition->OutputDefs()[0]->Name(); model_builder.AddLayer(std::move(layer)); return Status::OK(); } diff --git a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc index 70053c2c606a0..fc8879abbefb0 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc @@ -36,11 +36,6 @@ bool CastOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPara return false; } - if (node.GetInputEdgesCount() > 1) { - LOGS(logger, VERBOSE) << "Multiple nodes producing Cast's input."; - return false; - } - const auto& prec_node = node.InputEdgesBegin()->GetNode(); /*Cast node is only aimed for supporting argmax and we are only handling the case where an argmax diff --git a/onnxruntime/core/providers/cpu/math/softmax_shared.cc b/onnxruntime/core/providers/cpu/math/softmax_shared.cc index cae20b42725b8..2817dda9d0085 100644 --- a/onnxruntime/core/providers/cpu/math/softmax_shared.cc +++ b/onnxruntime/core/providers/cpu/math/softmax_shared.cc @@ -99,7 +99,7 @@ common::Status SoftmaxCPU(size_t N, float* Ydata, bool logarithmic, onnxruntime::concurrency::ThreadPool* thread_pool) { - MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, thread_pool); + MlasComputeSoftmax(Xdata, Ydata, N, D, logarithmic, false, thread_pool); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/ml/ml_common.h b/onnxruntime/core/providers/cpu/ml/ml_common.h index ed108eade05ab..2f4ebeabe043e 100644 --- a/onnxruntime/core/providers/cpu/ml/ml_common.h +++ b/onnxruntime/core/providers/cpu/ml/ml_common.h @@ -441,7 +441,7 @@ void batched_update_scores_inplace(gsl::span scores, int64_t num_batches_in, } if (use_mlas) { - MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, 
threadpool); + MlasComputeSoftmax(s, s, num_batches, onnxruntime::narrow(batch_size), false, false, threadpool); } else { while (s < s_end) { gsl::span scores_for_batch(s, s + batch_size); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index a43e2e766c687..ae1b1cb1e261e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2763,7 +2763,7 @@ static_assert(offsetof(OrtApi, SessionOptionsAppendExecutionProvider_OpenVINO_V2 static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeof(void*) == 279, "Size of version 18 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.19.1", +static_assert(std::string_view(ORT_VERSION) == "1.19.2", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. If there were any APIs added to ort_api_1_to_19 above: diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 4f46242a4f402..2375104ac96f5 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -304,7 +304,7 @@ def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): tensor.data_ptr(), ) - def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = False): + def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True): """Bind input tensors and run inference""" for name, tensor in feed_dict.items(): assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() @@ -317,7 +317,6 @@ def infer(self, feed_dict: Dict[str, torch.Tensor], run_options: RunOptions = No else: self.bind_input_and_buffer_sharing(name, tensor) - # Synchronization are not needed in most cases unless different streams are used or inputs/outputs are in CPU. 
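
A brief usage sketch for the flipped default in io_binding_helper.py: infer() now synchronizes bound inputs and outputs unless told otherwise, and callers that keep everything on a single CUDA stream can opt back out. The snippet assumes `session` is an already-constructed CudaSession with buffers allocated and `feeds` is a dict of contiguous torch tensors; construction and binding are unchanged and omitted here.

outputs = session.infer(feeds)                      # synchronizes inputs/outputs (new default)
outputs = session.infer(feeds, synchronize=False)   # previous behavior, for single-stream callers
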
if synchronize: self.io_binding.synchronize_inputs() self.ort_session.run_with_iobinding(self.io_binding, run_options) diff --git a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp index 6181be873f73e..65822eb294d7d 100644 --- a/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp +++ b/onnxruntime/test/mlas/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); } free(ptr.underlying_buffer); diff --git a/onnxruntime/test/mlas/unittest/test_softmax.cpp b/onnxruntime/test/mlas/unittest/test_softmax.cpp index 4c5e11bbe9566..fb4ebbee77faf 100644 --- a/onnxruntime/test/mlas/unittest/test_softmax.cpp +++ b/onnxruntime/test/mlas/unittest/test_softmax.cpp @@ -23,13 +23,15 @@ class MlasSoftmaxTest : public MlasTestBase { Input[nd] = distribution(generator); } - Test(Input, Output, OutputReference, N, D, false); - Test(Input, Output, OutputReference, N, D, true); + Test(Input, Output, OutputReference, N, D, false, true); + Test(Input, Output, OutputReference, N, D, true, true); + Test(Input, Output, OutputReference, N, D, false, false); + Test(Input, Output, OutputReference, N, D, true, false); } - void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, threadpool_); - ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax); + void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; constexpr float RelativeTolerance = 1e-6f; @@ -42,7 +44,7 @@ class MlasSoftmaxTest : public MlasTestBase { } } - void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax) { + void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { for (size_t n = 0; n < N; n++) { float MaximumValue = std::numeric_limits::lowest(); @@ -50,6 +52,10 @@ class MlasSoftmaxTest : public MlasTestBase { MaximumValue = (std::max)(MaximumValue, Input[d]); } + if (SmoothSoftmax && MaximumValue < 0.0f) { + MaximumValue = 0.0f; + } + double Sum = 0.0; for (size_t d = 0; d < D; d++) { @@ -58,6 +64,10 @@ class MlasSoftmaxTest : public MlasTestBase { Output[d] = float(e); } + if (SmoothSoftmax) { + Sum += expf(-MaximumValue); + } + if (LogSoftmax) { float Scale = float(std::log(Sum)); diff --git a/onnxruntime/test/providers/coreml/coreml_basic_test.cc b/onnxruntime/test/providers/coreml/coreml_basic_test.cc index 0f068ba48d3d8..daa24db134114 100644 --- a/onnxruntime/test/providers/coreml/coreml_basic_test.cc +++ b/onnxruntime/test/providers/coreml/coreml_basic_test.cc @@ -2,6 +2,8 @@ // Licensed under the MIT License. 
#include "core/common/logging/logging.h" +#include "core/graph/graph.h" +#include "core/graph/graph_viewer.h" #include "core/providers/coreml/coreml_execution_provider.h" #include "core/providers/coreml/coreml_provider_factory.h" #include "core/session/inference_session.h" @@ -92,7 +94,7 @@ TEST(CoreMLExecutionProviderTest, FunctionTest) { feeds.insert(std::make_pair("Y", ml_value_y)); feeds.insert(std::make_pair("Z", ml_value_z)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.FunctionTest", + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), feeds); #else @@ -118,9 +120,50 @@ TEST(CoreMLExecutionProviderTest, ArgMaxCastTest) { NameMLValMap feeds; feeds.insert(std::make_pair("X", ml_value_x)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.ArgMaxCastTest", + EPVerificationParams verification_params{}; + verification_params.ep_node_assignment = ExpectedEPNodeAssignment::All; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), - feeds); + feeds, + verification_params); +#else + TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::All); +#endif +} + +TEST(CoreMLExecutionProviderTest, ArgMaxUnsupportedCastTest) { + const ORTCHAR_T* model_file_name = ORT_TSTR("testdata/coreml_argmax_unsupported_cast_test.onnx"); + +#if defined(__APPLE__) + std::vector dims_mul_x = {3, 2, 2}; + std::vector values_mul_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + OrtValue ml_value_x; + AllocatorPtr allocator = std::make_shared(); + CreateMLValue(allocator, dims_mul_x, values_mul_x, &ml_value_x); + + NameMLValMap feeds; + feeds.insert(std::make_pair("X", ml_value_x)); + + const std::function graph_verifier = [](const Graph& graph) { + GraphViewer graph_viewer{graph}; + const auto& node_indices_in_order = graph_viewer.GetNodesInTopologicalOrder(); + ASSERT_EQ(node_indices_in_order.size(), size_t{2}); + // second node should be an unsupported Cast + const auto* cast_node = graph.GetNode(node_indices_in_order[1]); + ASSERT_NE(cast_node, nullptr); + ASSERT_EQ(cast_node->OpType(), "Cast"); + ASSERT_EQ(cast_node->GetExecutionProviderType(), kCpuExecutionProvider); + }; + + EPVerificationParams verification_params{}; + verification_params.ep_node_assignment = ExpectedEPNodeAssignment::Some; + verification_params.graph_verifier = &graph_verifier; + + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), + MakeCoreMLExecutionProvider(), + feeds, + verification_params); #else TestModelLoad(model_file_name, MakeCoreMLExecutionProvider(), ExpectedEPNodeAssignment::Some); #endif @@ -184,7 +227,7 @@ TEST(CoreMLExecutionProviderTest, TestOrtFormatModel) { NameMLValMap feeds; feeds.insert(std::make_pair("Input3", ml_value)); - RunAndVerifyOutputsWithEP(model_file_name, "CoreMLExecutionProviderTest.TestOrtFormatModel", + RunAndVerifyOutputsWithEP(model_file_name, CurrentTestName(), MakeCoreMLExecutionProvider(), feeds); #else diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index 5e028519b9f34..53d015a029083 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -37,6 +37,7 @@ def plot_prompt_performance( head_size: int, max_seq_len: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, ): import triton @@ -55,6 +56,7 @@ def 
plot_prompt_performance( "kv_num_heads": kv_num_heads, "head_size": head_size, "local_window_size": local_window_size, + "use_smooth_softmax": use_smooth_softmax, }, ) ] @@ -68,6 +70,7 @@ def benchmark( kv_num_heads: int, head_size: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, device="cuda", ): warmup = 15 @@ -82,6 +85,7 @@ def benchmark( kv_num_heads=kv_num_heads, head_size=head_size, local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, + use_smooth_softmax=use_smooth_softmax, device=device, is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], ) @@ -103,6 +107,7 @@ def plot_token_performance( head_size: int, max_seq_len: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, ): import triton @@ -121,6 +126,7 @@ def plot_token_performance( "kv_num_heads": kv_num_heads, "head_size": head_size, "local_window_size": local_window_size, + "use_smooth_softmax": use_smooth_softmax, }, ) ] @@ -134,6 +140,7 @@ def benchmark( kv_num_heads: int, head_size: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, device="cuda", ): warmup = 15 @@ -150,6 +157,7 @@ def benchmark( local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, do_rotary=True, # Most models use rotary positional embeddings is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], + use_smooth_softmax=use_smooth_softmax, device=device, ) @@ -186,26 +194,29 @@ def run_performance_test(sm: int): for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: for batch_size in [1, 4]: - plot_prompt_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - max_seq_len=min(threshold, max_seq_len), - local_window_size=local_window_size, - model_name=model_name, - ) - plot_token_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - max_seq_len=min(threshold, max_seq_len), - local_window_size=local_window_size, - model_name=model_name, - ) + for use_smooth_softmax in [False, True]: + plot_prompt_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + use_smooth_softmax=use_smooth_softmax, + model_name=model_name, + ) + plot_token_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + use_smooth_softmax=use_smooth_softmax, + model_name=model_name, + ) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py index b781ccf03f138..79cc8e41bf343 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py @@ -19,6 +19,7 @@ def save_results(results, filename): "Max Sequence Length", "Sequence Length", "Past Sequence Length", + "Smooth Softmax", "Model Name", ], ) @@ -36,6 +37,7 @@ def benchmark( sequence_length: int = 1, past_sequence_length: int = 0, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, model_name: str = "Llama3-8B", ): warmup = 15 @@ -50,6 +52,7 
@@ def benchmark( kv_num_heads=kv_num_heads, head_size=head_size, local_window_size=local_window_size if local_window_size else -1, + use_smooth_softmax=use_smooth_softmax, do_rotary=True, # Most models use rotary positional embeddings is_packed_qkv=model_name in ["Phi-3-mini-128k", "Phi-3-small-128k"], device="cuda", @@ -93,6 +96,8 @@ def run_performance_tests(args): # Reduce max sequence length when GPU memory is not enough. threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768 + smooth_softmax = args.use_smooth_softmax + all_metrics = [] for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: prompt_metrics_model = [] @@ -131,6 +136,7 @@ def run_performance_tests(args): sequence_length=sequence_length, max_seq_len=min(threshold, max_seq_len), local_window_size=local_window_size, + use_smooth_softmax=smooth_softmax, model_name=model_name, ) metrics = [*metrics, batch_size, max_seq_len, sequence_length, 0, model_name] @@ -169,9 +175,10 @@ def run_performance_tests(args): past_sequence_length=past_sequence_length, max_seq_len=min(threshold, max_seq_len), local_window_size=local_window_size, + use_smooth_softmax=smooth_softmax, model_name=model_name, ) - metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, model_name] + metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, smooth_softmax, model_name] token_metrics_model.append(metrics) all_metrics.append(metrics) # Calculate average inference interval and throughput for each model @@ -209,6 +216,15 @@ def run_performance_tests(args): default="flash_attention", help="GQA Kernel to use for benchmarking. Options: flash_attention, memory_efficient", ) + + parser.add_argument( + "--use_smooth_softmax", + required=False, + action="store_true", + help="test smooth softmax", + ) + parser.set_defaults(use_smooth_softmax=False) + args = parser.parse_args() if args.kernel == "memory_efficient": diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py index 0c52ee690af82..108f58e29c9f1 100644 --- a/onnxruntime/test/python/transformers/benchmark_mha.py +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -535,8 +535,8 @@ def __init__(self, config: MultiHeadAttentionConfig, session_options=None, use_t self.ort_session = create_session(config, session_options, use_tf32=use_tf32) self.feed_dict = config.random_inputs() - def infer(self): - return self.ort_session.infer(self.feed_dict) + def infer(self, run_options=None, synchronize=True): + return self.ort_session.infer(self.feed_dict, run_options=run_options, synchronize=synchronize) def measure_latency(cuda_session: CudaSession, input_dict): @@ -1119,7 +1119,6 @@ def _parse_arguments(): assert args.causal, "--has_past need --causal specified" if args.use_gpu: - assert args.torch or not args.causal, "no causal cuda kernel in MHA op" assert torch.cuda.is_available() if not args.torch: assert "CUDAExecutionProvider" in get_available_providers() diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index 84bf30b65a742..13bf51f74389a 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -22,6 +22,7 @@ from onnx import TensorProto, helper from packaging import version from parameterized import parameterized +from test_gqa_cpu import smooth_softmax_ref from 
onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -222,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -246,6 +248,7 @@ def create_group_query_attention_graph_prompt( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -408,6 +411,7 @@ def create_group_query_attention_graph_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -434,6 +438,7 @@ def create_group_query_attention_graph_past( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -783,6 +788,7 @@ def gqa_prompt_func( past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -792,6 +798,7 @@ def gqa_prompt_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None @@ -888,6 +895,7 @@ def gqa_past_func( share_buffer=True, window_size=-1, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_past( config, @@ -897,6 +905,7 @@ def gqa_past_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() @@ -1033,6 +1042,7 @@ def attention_ref( window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, + use_smooth_softmax=False, ): """ Arguments: @@ -1079,7 +1089,12 @@ def attention_ref( q.device, ) scores.masked_fill_(local_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) + + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) + else: + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) @@ -1099,7 +1114,14 @@ def attention_ref( def attention_qkvpacked_ref( - qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False + qkv, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, + use_smooth_softmax=False, ): return attention_ref( qkv[:, :, 0], @@ -1112,6 +1134,7 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, reorder_ops=reorder_ops, + use_smooth_softmax=use_smooth_softmax, ) @@ -1192,6 +1215,7 @@ def parity_check_gqa_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ 
-1306,7 +1330,16 @@ def parity_check_gqa_prompt( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1330,6 +1363,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1346,6 +1380,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1374,6 +1409,7 @@ def parity_check_gqa_prompt_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1465,7 +1501,16 @@ def parity_check_gqa_prompt_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + new_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1489,6 +1534,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1505,6 +1551,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1512,7 +1559,8 @@ def parity_check_gqa_prompt_no_buff( err_msg = ( f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}" + f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}," + f" use_smooth_softmax={use_smooth_softmax}" ) # Make sure past-present buffer updating correctly numpy.testing.assert_allclose( @@ -1533,6 +1581,7 @@ def parity_check_gqa_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1643,7 +1692,16 @@ def parity_check_gqa_past( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1667,6 +1725,7 @@ def parity_check_gqa_past( True, 
left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1683,6 +1742,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1711,6 +1771,7 @@ def parity_check_gqa_past_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1827,7 +1888,16 @@ def parity_check_gqa_past_no_buff( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1851,6 +1921,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1867,6 +1938,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -2137,6 +2209,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) parity_check_gqa_prompt_no_buff( config, @@ -2146,6 +2219,7 @@ def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleave rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=True, ) @parameterized.expand(gqa_no_past_flash_attention_test_cases()) @@ -2162,6 +2236,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=True, ) parity_check_gqa_prompt_no_buff( config, @@ -2170,6 +2245,7 @@ def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_inte rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) @parameterized.expand(gqa_past_memory_efficient_test_cases()) @@ -2187,6 +2263,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=True, ) parity_check_gqa_past_no_buff( config, @@ -2196,6 +2273,7 @@ def test_gqa_past_memory_efficient(self, _, config, rotary, rotary_interleaved, rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) @parameterized.expand(gqa_past_flash_attention_test_cases()) @@ -2214,6 +2292,7 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=False, ) parity_check_gqa_past_no_buff( config, @@ -2224,6 +2303,7 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, 
rotary_interle rotary=rotary, rotary_interleaved=rotary_interleaved, packed=packed, + use_smooth_softmax=True, ) diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index b6b8aee15852f..eeba0baccf15b 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -145,6 +145,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -169,6 +170,7 @@ def create_group_query_attention_graph_prompt( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -331,6 +333,7 @@ def create_group_query_attention_graph_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -357,6 +360,7 @@ def create_group_query_attention_graph_past( local_window_size=local_window_size, do_rotary=rotary, rotary_interleaved=rotary_interleaved, + smooth_softmax=1 if use_smooth_softmax else 0, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -667,6 +671,7 @@ def gqa_prompt_func( past_kv_format=Formats.BSNH, share_buffer=True, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_prompt( config, @@ -676,6 +681,7 @@ def gqa_prompt_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None @@ -773,6 +779,7 @@ def gqa_past_func( share_buffer=True, window_size=-1, rotary_interleaved=False, + use_smooth_softmax=False, ): onnx_model_str = create_group_query_attention_graph_past( config, @@ -782,6 +789,7 @@ def gqa_past_func( rotary=cos is not None, rotary_interleaved=rotary_interleaved, packed=new_k is None, + use_smooth_softmax=use_smooth_softmax, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() @@ -906,6 +914,13 @@ def construct_local_mask( ) +def smooth_softmax_ref(x): + x_max = x.amax(axis=-1, keepdim=True) + x_max = torch.maximum(x_max, torch.zeros_like(x_max)) + w = torch.exp(x - x_max) + return w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-x_max)) + + def attention_ref( q, k, @@ -918,6 +933,7 @@ def attention_ref( window_size=(-1, -1), # -1 means infinite window size upcast=True, reorder_ops=False, + use_smooth_softmax=False, ): """ Arguments: @@ -935,6 +951,7 @@ def attention_ref( reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) without changing the math. This is to estimate the numerical error from operation reordering. 
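
One way to read smooth_softmax_ref above: it is an ordinary softmax computed over the scores with a single zero logit appended, keeping only the original columns. A small torch check of that identity, assuming smooth_softmax_ref from test_gqa_cpu.py is in scope (this scaffolding is not part of the patch):

import torch

scores = torch.randn(2, 4, 3, 5)
# Append a zero logit, take softmax along the last axis, and drop the extra column again.
expected = torch.softmax(
    torch.cat([scores, torch.zeros_like(scores[..., :1])], dim=-1), dim=-1
)[..., :-1]
assert torch.allclose(smooth_softmax_ref(scores), expected, atol=1e-6)
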
+ use_smooth_softmax: whether use smooth softmax or not Output: output: (batch_size, seqlen_q, nheads, head_dim) attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout @@ -964,10 +981,16 @@ def attention_ref( q.device, ) scores.masked_fill_(local_mask, float("-inf")) - attention = torch.softmax(scores, dim=-1) + + if use_smooth_softmax: + attention = smooth_softmax_ref(scores) + else: + attention = torch.softmax(scores, dim=-1) + # Some rows might be completely masked out so we fill them with zero instead of NaN if window_size[0] >= 0 or window_size[1] >= 0: attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: @@ -984,7 +1007,14 @@ def attention_ref( def attention_qkvpacked_ref( - qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False + qkv, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, + use_smooth_softmax=False, ): return attention_ref( qkv[:, :, 0], @@ -997,6 +1027,7 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, reorder_ops=reorder_ops, + use_smooth_softmax=use_smooth_softmax, ) @@ -1008,6 +1039,7 @@ def parity_check_gqa_prompt( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1108,7 +1140,16 @@ def parity_check_gqa_prompt( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1132,6 +1173,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1148,6 +1190,7 @@ def parity_check_gqa_prompt( past_format, True, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1172,6 +1215,8 @@ def parity_check_gqa_prompt( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1201,6 +1246,7 @@ def parity_check_gqa_prompt_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1275,7 +1321,16 @@ def parity_check_gqa_prompt_no_buff( k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + new_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1299,6 +1354,7 @@ 
def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_prompt_func( @@ -1315,6 +1371,7 @@ def parity_check_gqa_prompt_no_buff( past_format, False, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) @@ -1339,6 +1396,8 @@ def parity_check_gqa_prompt_no_buff( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1368,6 +1427,7 @@ def parity_check_gqa_past( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1473,7 +1533,16 @@ def parity_check_gqa_past( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1497,6 +1566,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1513,6 +1583,7 @@ def parity_check_gqa_past( True, left_window_size, rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1539,6 +1610,8 @@ def parity_check_gqa_past( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, " B:", config.batch_size, " S:", @@ -1566,6 +1639,7 @@ def parity_check_gqa_past_no_buff( rotary=False, rotary_interleaved=False, packed=False, + use_smooth_softmax=False, rtol=1e-3, atol=1e-3, ): @@ -1677,7 +1751,16 @@ def parity_check_gqa_past_no_buff( v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, + k_cache_rep, + v_cache_rep, + None, + key_padding_mask, + 0.0, + None, + causal=True, + window_size=window_size, + use_smooth_softmax=use_smooth_softmax, ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1701,6 +1784,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) else: out, present_k, present_v = gqa_past_func( @@ -1717,6 +1801,7 @@ def parity_check_gqa_past_no_buff( False, window_size=left_window_size, rotary_interleaved=rotary_interleaved, + use_smooth_softmax=use_smooth_softmax, ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -1737,6 +1822,8 @@ def parity_check_gqa_past_no_buff( rotary, " rotary_interleaved:", rotary_interleaved, + " smooth_softmax:", + use_smooth_softmax, "past kv format:", 
"BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1787,26 +1874,29 @@ def test_gqa_no_past(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_prompt( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=past_kv_format, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + for use_smooth_softmax in [False, True]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) def test_gqa_past(self): print("-------- TEST GQA PAST (TOKEN GEN) ---------") @@ -1838,31 +1928,34 @@ def test_gqa_past(self): for local in [False, True]: for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: for packed in [False, True]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - past_kv_format = Formats.BNSH - all_close = parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) - all_close = parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - self.assertTrue(all_close) + for use_smooth_softmax in [False, True]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + use_smooth_softmax=use_smooth_softmax, + ) + self.assertTrue(all_close) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_mha.py b/onnxruntime/test/python/transformers/test_mha.py index 5948f8b1ccfc1..158fc0417afbd 100644 --- a/onnxruntime/test/python/transformers/test_mha.py +++ b/onnxruntime/test/python/transformers/test_mha.py @@ -68,6 +68,22 @@ def get_bias_support(format: InputFormats): raise RuntimeError(f"Unknown format: {format}") +def get_causal_support(format: InputFormats): + if format == InputFormats.Q_K_V_BSNH_BSNH_BSNH: + return [True, False] + + if format == InputFormats.Q_K_V_BSNH_BNSH_BNSH: + return [True, False] + + if 
format == InputFormats.Q_KV_BSNH_BSN2H: + return [True, False] + + if format == InputFormats.QKV_BSN3H: + return [True, False] + + raise RuntimeError(f"Unknown format: {format}") + + def attention_reference( head_size: int, query: torch.Tensor, @@ -194,7 +210,7 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): for num_heads in heads: for head_size in head_sizes: for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for mask_format in mask_formats: for has_bias in get_bias_support(format): config = MultiHeadAttentionConfig( @@ -224,8 +240,8 @@ def no_kv_cache_test_cases(provider: str, comprehensive: bool): num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] mask_format = mask_formats[i % len(mask_formats)] - for causal in [True, False]: - for format in formats: + for format in formats: + for causal in get_causal_support(format): for has_bias in get_bias_support(format): config = MultiHeadAttentionConfig( batch_size=batch_size, @@ -271,7 +287,7 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): for num_heads in heads: for head_size in head_sizes: for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for has_past_input in [True, False]: for mask_format in mask_formats: for has_bias in get_bias_support(format): @@ -305,8 +321,8 @@ def kv_cache_test_cases(provider: str, comprehensive: bool): num_heads = heads[i % len(heads)] head_size = head_sizes[i % len(head_sizes)] mask_format = mask_formats[i % len(mask_formats)] - for causal in [True, False]: - for format in formats: + for format in formats: + for causal in get_causal_support(format): for has_past_input in [True, False]: for has_bias in get_bias_support(format): sequence_length = 1 if has_past_input else past_sequence_length @@ -346,7 +362,7 @@ def no_kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device, dtype, formats = get_provider_support_info(provider, False) for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for num_heads in heads: for head_size in head_sizes: configs = [] # list of configurations to run in parallel @@ -386,7 +402,7 @@ def kv_cache_multi_thread_test_cases(provider: str, comprehensive: bool): device, dtype, formats = get_provider_support_info(provider, True) for format in formats: - for causal in [True, False]: + for causal in get_causal_support(format): for num_heads in heads: for head_size in head_sizes: configs = [] @@ -443,12 +459,8 @@ def parity_check_mha( rtol=1e-3, atol=1e-3, ): - # CUDA kernel does not support causal so skip such test cases. - if config.causal and config.provider == "CUDAExecutionProvider": - return - ort_mha = OrtMultiHeadAttention(config, use_tf32=False) - ort_outputs = ort_mha.infer() + ort_outputs = ort_mha.infer(synchronize=True) out = ort_outputs["output"] out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) @@ -545,9 +557,6 @@ def parity_check_mha_multi_threading( ): # Use the first config to create a session, which is shared by all configs to run in parallel. config = test_inputs[0]["config"] - # For now, MHA CUDA kernel does not support causal so skip such test cases. - if config.causal and config.provider == "CUDAExecutionProvider": - return None # Some kernel does not support certain input format. 
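# get_causal_support above turns causal=True into a per-format capability, so every input format
# is now exercised with a causal mask and the CUDA-specific causal skips in the parity checks were
# removed. A minimal sketch of such a mask, assuming the common bottom-right-aligned convention in
# which query position i may attend key positions up to i + (seqlen_k - seqlen_q); the exact
# convention of the reference implementation is defined elsewhere in this file:
import torch

def causal_bool_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # True marks key positions a query may attend to
    return torch.tril(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool), diagonal=seqlen_k - seqlen_q)

print(causal_bool_mask(2, 4).int())
# tensor([[1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)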
if attention_kernel not in [ @@ -693,6 +702,10 @@ def run_mha_cpu(self): def run_mha_cuda_multi_threading(self, attention_kernel): for configs in multi_thread_test_cases("CUDAExecutionProvider", comprehensive_mode): + if configs and configs[0].causal and (SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0): + # TRT fused causal is disabled by default so skip the test of causal for multi-threading. + continue + test_inputs = [] for config in configs: ort_inputs = config.random_inputs() diff --git a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py b/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py deleted file mode 100644 index 00704626028a0..0000000000000 --- a/onnxruntime/test/python/transformers/test_parity_mixtral_moe.py +++ /dev/null @@ -1,361 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright 2020 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# -------------------------------------------------------------------------- -import unittest -from collections import OrderedDict - -import numpy -import torch -import torch.nn.functional as F -from onnx import TensorProto, helper -from torch import nn - -import onnxruntime - -torch.manual_seed(42) -numpy.random.seed(42) - -ORT_DTYPE = TensorProto.FLOAT -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 3e-2 - - -def value_string_of(numpy_array): - arr = numpy_array.flatten() - lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] - return "{\n " + "f,\n ".join(lines) + "f}" - - -def print_tensor(name, numpy_array): - print(f"const std::vector {name} = {value_string_of(numpy_array)};") - - -def create_moe_onnx_graph( - num_rows, - num_experts, - hidden_size, - inter_size, - fc1_experts_weights, - fc2_experts_weights, - fc3_experts_weights, - topk, -): - nodes = [ - helper.make_node( - "MoE", - [ - "input", - "router_probs", - "fc1_experts_weights", - "", - "fc2_experts_weights", - "", - "fc3_experts_weights", - ], - ["output"], - "MoE_0", - k=topk, - normalize_routing_weights=1, - activation_type="silu", - domain="com.microsoft", - ), - ] - - fc1_shape = [num_experts, hidden_size, inter_size] - fc2_shape = [num_experts, inter_size, hidden_size] - fc3_shape = [num_experts, hidden_size, inter_size] - - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - - initializers = [ - helper.make_tensor( - "fc1_experts_weights", - ORT_DTYPE, - fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), - raw=False, - ), - helper.make_tensor( - "fc2_experts_weights", - ORT_DTYPE, - fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), - raw=False, - ), - helper.make_tensor( - "fc3_experts_weights", - ORT_DTYPE, - fc3_shape, - fc3_experts_weights.to(torch_type).flatten().tolist(), - raw=False, - ), - ] - - graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), - ] - - graph_inputs.append( - helper.make_tensor_value_info( - "router_probs", - ORT_DTYPE, - [num_rows, 
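# The skip above is a bitwise membership test: attention_kernel can combine several kernel flags,
# and `SdpaKernel.TRT_CAUSAL_ATTENTION & attention_kernel != 0` is true exactly when the TRT
# causal kernel is part of that combination. A minimal stand-in with hypothetical flag values
# (the real SdpaKernel enum lives in the shared test utilities, not here):
from enum import IntFlag

class FakeSdpaKernel(IntFlag):  # hypothetical flags, for illustration only
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
    TRT_CAUSAL_ATTENTION = 4

selected = FakeSdpaKernel.FLASH_ATTENTION | FakeSdpaKernel.TRT_CAUSAL_ATTENTION
assert FakeSdpaKernel.TRT_CAUSAL_ATTENTION & selected != 0  # TRT causal attention is selected
assert FakeSdpaKernel.EFFICIENT_ATTENTION & selected == 0   # efficient attention is not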
num_experts], - ) - ) - - graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), - ] - - graph = helper.make_graph( - nodes, - "MoE_Graph", - graph_inputs, - graph_outputs, - initializers, - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -class ClassInstantier(OrderedDict): - def __getitem__(self, key): - content = super().__getitem__(key) - cls, kwargs = content if isinstance(content, tuple) else (content, {}) - return cls(**kwargs) - - -ACT2CLS = { - "silu": nn.SiLU, -} -ACT2FN = ClassInstantier(ACT2CLS) - - -class MixtralConfig: - def __init__( - self, - hidden_size=4096, - intermediate_size=14336, - num_hidden_layers=32, - num_attention_heads=32, - num_key_value_heads=8, - hidden_act="silu", - initializer_range=0.02, - rms_norm_eps=1e-5, - use_cache=True, - rope_theta=1e6, - attention_dropout=0.0, - num_experts_per_tok=2, - num_local_experts=8, - output_router_logits=False, - router_aux_loss_coef=0.001, - ): - self.hidden_size = hidden_size - self.intermediate_size = intermediate_size - self.num_hidden_layers = num_hidden_layers - self.num_attention_heads = num_attention_heads - if num_key_value_heads is None: - num_key_value_heads = num_attention_heads - self.num_key_value_heads = num_key_value_heads - self.hidden_act = hidden_act - self.initializer_range = initializer_range - self.rms_norm_eps = rms_norm_eps - self.use_cache = use_cache - self.rope_theta = rope_theta - self.attention_dropout = attention_dropout - self.num_experts_per_tok = num_experts_per_tok - self.num_local_experts = num_local_experts - self.output_router_logits = output_router_logits - self.router_aux_loss_coef = router_aux_loss_coef - - -class MixtralBlockSparseTop2MLP(nn.Module): - def __init__(self, config: MixtralConfig): - super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size - - self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - - self.act_fn = ACT2FN[config.hidden_act] - - def forward(self, hidden_states): - current_hidden_states_1 = self.act_fn(self.w1(hidden_states)) - current_hidden_states_3 = self.w3(hidden_states) - current_hidden_states = current_hidden_states_1 * current_hidden_states_3 - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states - - -class MixtralSparseMoeBlock(nn.Module): - """ - This implementation is - strictly equivalent to standard MoE with full capacity (no - dropped tokens). It's faster since it formulates MoE operations - in terms of block-sparse operations to accommodate imbalanced - assignments of tokens to experts, whereas standard MoE either - (1) drop tokens at the cost of reduced performance or (2) set - capacity factor to number of experts and thus waste computation - and memory on padding. 
- """ - - def __init__(self, config, batch_size, sequence_length): - super().__init__() - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok - - # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - - self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - - w1_list = [] - w2_list = [] - w3_list = [] - for i in range(self.num_experts): - w1_list.append(self.experts[i].w1.weight) - w2_list.append(self.experts[i].w2.weight) - w3_list.append(self.experts[i].w3.weight) - - self.moe_experts_weight1 = torch.stack(w1_list, dim=0) - self.moe_experts_weight2 = torch.stack(w2_list, dim=0) - self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - - self.batch_size = batch_size - self.sequence_length = sequence_length - self.moe_onnx_graph = create_moe_onnx_graph( - self.batch_size * self.sequence_length, - self.num_experts, - self.hidden_dim, - self.ffn_dim, - self.moe_experts_weight1, - self.moe_experts_weight2, - self.moe_experts_weight3, - self.top_k, - ) - - self.ort_sess = self.create_ort_session() - - def create_ort_session(self): - from onnxruntime import InferenceSession, SessionOptions - - sess_options = SessionOptions() - - cuda_providers = ["CUDAExecutionProvider"] - if cuda_providers[0] not in onnxruntime.get_available_providers(): - return None - - sess_options.log_severity_level = 2 - ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) - - return ort_session - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - """ """ - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - - routing_weights /= routing_weights.sum(dim=-1, keepdim=True) - # we cast back to the input dtype - routing_weights = routing_weights.to(hidden_states.dtype) - - final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device - ) - - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - if top_x.shape[0] == 0: - continue - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states # , router_logits - - def ort_forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) - # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), - } - - ort_output = None - if self.ort_sess is not None: - ort_output = self.ort_sess.run(None, ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits - - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) - - return None - - def parity_check(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) - torch_output = self.forward(hidden_state) - ort_output = self.ort_forward(hidden_state) - if ort_output is not None: - assert torch.allclose(torch_output, ort_output, rtol=1e-04, atol=1e-04) - print( - "batch_size:", - self.batch_size, - " sequence_length:", - self.sequence_length, - " max_diff:", - (torch_output - ort_output).abs().max(), - " parity: OK", - ) - - -class TestMixtralMoE(unittest.TestCase): - def test_mixtral_moe_parity(self): - for batch_size in [1, 16]: - for sequence_length in [128, 1024]: - # use a small sizes to speed up the test - config = MixtralConfig(hidden_size=256, intermediate_size=1024) - mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) - mixtral_moe.parity_check() - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_parity_moe.py index be288d8b6e360..1e7940e38335f 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_parity_moe.py @@ -8,28 +8,26 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. 
-# ------------------------------------------------------------------------- - -import platform -import time +# -------------------------------------------------------------------------- import unittest +from collections import OrderedDict import numpy -import pytest import torch -import torch.nn as nn import torch.nn.functional as F from onnx import TensorProto, helper +from parameterized import parameterized +from torch import nn import onnxruntime torch.manual_seed(42) numpy.random.seed(42) - -ORT_DTYPE = TensorProto.FLOAT16 +USE_QUANT = False +ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 3e-2 +THRESHOLD = 5e-1 if USE_QUANT else 1e-2 def value_string_of(numpy_array): @@ -42,8 +40,30 @@ def print_tensor(name, numpy_array): print(f"const std::vector {name} = {value_string_of(numpy_array)};") +def quant_dequant(weights, quant_mode: bool = True): + # use the test version `_symmetric_...` to get the non-interleaved weights + type = torch.quint4x2 if quant_mode else torch.int8 + # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() + # Comment out this line for passing the lintrunner check in the CI. + # import tensorrt_llm + + quant_weights, processed_q_weight, torch_weight_scales = ( + torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix(weights.T.cpu().contiguous(), type) + ) + + # Unpack the int4s int int8s + if quant_mode: + upper = quant_weights >> 4 + lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends + quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) + + quant_weights = quant_weights.to(dtype=weights.dtype) + result = torch.multiply(quant_weights, torch_weight_scales.unsqueeze(0)).T.contiguous() + return torch_weight_scales.to(torch.float16), processed_q_weight, result.to(device=weights.device) + + def create_moe_onnx_graph( - num_rows, + sequence_length, num_experts, hidden_size, inter_size, @@ -115,19 +135,265 @@ def create_moe_onnx_graph( ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [num_rows, hidden_size]), + helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [sequence_length, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_mixtral_moe_onnx_graph( + sequence_length, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + topk, +): + nodes = [ + helper.make_node( + "MoE", + [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ], + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="silu", + domain="com.microsoft", + ), + ] + + fc1_shape = [num_experts, hidden_size, inter_size] + fc2_shape = [num_experts, inter_size, hidden_size] + fc3_shape = [num_experts, hidden_size, inter_size] + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE, + fc1_shape, + 
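# The quant_dequant helper above recovers two signed 4-bit values from each packed int8 byte with
# arithmetic shifts: the high nibble via q >> 4 and the low nibble via (q << 4) >> 4, where the
# right shift sign-extends. A minimal round-trip check of that trick (note the Phi MoE block below
# uses the int8 path, quant_dequant(..., False); the nibble unpacking applies to the quint4x2 path):
import torch

def pack_int4_pair(lower: torch.Tensor, upper: torch.Tensor) -> torch.Tensor:
    # both inputs hold signed values in [-8, 7]; pack lower into bits 0-3 and upper into bits 4-7
    byte = ((upper.to(torch.int16) & 0x0F) << 4) | (lower.to(torch.int16) & 0x0F)
    return byte.to(torch.uint8).view(torch.int8)

lower = torch.tensor([-8, -1, 0, 7], dtype=torch.int8)
upper = torch.tensor([3, -4, 7, -8], dtype=torch.int8)
packed = pack_int4_pair(lower, upper)

unpacked_upper = packed >> 4           # arithmetic shift keeps the sign of the high nibble
unpacked_lower = (packed << 4) >> 4    # shift left, then sign-extending shift right
assert torch.equal(unpacked_lower, lower) and torch.equal(unpacked_upper, upper)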
fc1_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE, + fc2_shape, + fc2_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc3_experts_weights", + ORT_DTYPE, + fc3_shape, + fc3_experts_weights.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + ORT_DTYPE, + [sequence_length, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_phi_moe_onnx_graph( + sequence_length, + num_experts, + hidden_size, + inter_size, + fc1_experts_weights, + fc2_experts_weights, + fc3_experts_weights, + fc1_scales, + fc2_scales, + fc3_scales, + topk, +): + use_quant = USE_QUANT + if use_quant: + assert fc1_experts_weights.dtype == torch.int8 + assert fc2_experts_weights.dtype == torch.int8 + assert fc3_experts_weights.dtype == torch.int8 + assert fc1_scales is not None + assert fc2_scales is not None + assert fc3_scales is not None + assert fc1_scales.dtype == torch.float16 + assert fc2_scales.dtype == torch.float16 + assert fc3_scales.dtype == torch.float16 + + nodes = [ + helper.make_node( + "MoE" if not use_quant else "QMoE", + ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + if not use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + ), + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=0, + use_sparse_mixer=1, + activation_type="silu", + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)]) + + fc1_shape = [num_experts, hidden_size, inter_size] + fc2_shape = [num_experts, inter_size, hidden_size] + fc3_shape = [num_experts, hidden_size, inter_size] + + torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 + if use_quant: + numpy_type = numpy.uint8 + + initializers = [ + helper.make_tensor( + "fc1_experts_weights", + ORT_DTYPE if not use_quant else TensorProto.UINT8, + fc1_shape, + fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + raw=False, + ), + helper.make_tensor( + "fc2_experts_weights", + ORT_DTYPE if not use_quant else TensorProto.UINT8, + fc2_shape, + fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + raw=False, + ), + helper.make_tensor( + "fc3_experts_weights", + ORT_DTYPE if not use_quant else TensorProto.UINT8, + fc3_shape, + fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + raw=False, + ), + ] + + if use_quant: + fc1_scale_shape = [num_experts, inter_size] + fc2_scale_shape = [num_experts, hidden_size] + fc3_scale_shape = [num_experts, inter_size] + initializers.extend( + [ + helper.make_tensor( + "fc1_scales", + ORT_DTYPE, + fc1_scale_shape, + fc1_scales.to(torch_type).flatten().tolist(), + raw=False, + ), + 
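# In the QMoE path above each expert's weights are stored as UINT8 with one FP16 scale per output
# column: fc1 weights are [num_experts, hidden_size, inter_size] and fc1_scales are
# [num_experts, inter_size]. Ignoring the kernel-specific interleaving applied by the TensorRT-LLM
# preprocessing, the symmetric dequantization is a broadcast multiply, sketched here with toy sizes:
import torch

num_experts, hidden_size, inter_size = 2, 4, 8           # toy sizes, illustration only
w_q = torch.randint(-128, 128, (num_experts, hidden_size, inter_size), dtype=torch.int8)
scales = torch.rand(num_experts, inter_size).to(torch.float16)

w_dequant = w_q.to(torch.float16) * scales.unsqueeze(1)  # scales broadcast over hidden_size
assert w_dequant.shape == (num_experts, hidden_size, inter_size)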
helper.make_tensor( + "fc2_scales", + ORT_DTYPE, + fc2_scale_shape, + fc2_scales.to(torch_type).flatten().tolist(), + raw=False, + ), + helper.make_tensor( + "fc3_scales", + ORT_DTYPE, + fc3_scale_shape, + fc3_scales.to(torch_type).flatten().tolist(), + raw=False, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", ORT_DTYPE, - [num_rows, num_experts], + [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [num_rows, hidden_size]), + helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -142,13 +408,52 @@ def create_moe_onnx_graph( return model.SerializeToString() -def get_activation_fn(activation): - if activation == "relu": - return nn.ReLU - elif activation == "gelu": - return nn.GELU - else: - raise NotImplementedError +class ClassInstantier(OrderedDict): + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + "silu": nn.SiLU, + "gelu": nn.GELU, +} +ACT2FN = ClassInstantier(ACT2CLS) + + +class MixtralConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + hidden_act="silu", + num_experts_per_tok=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + + +class PhiMoEConfig: + def __init__( + self, + hidden_size=4096, + intermediate_size=14336, + hidden_act="silu", + num_experts_per_tok=2, + num_local_experts=8, + router_jitter_noise=0.01, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.router_jitter_noise = router_jitter_noise class MoEGate(nn.Module): @@ -184,14 +489,9 @@ def __init__( hidden_features=None, out_features=None, act_layer=nn.GELU, - drop=0.0, bias=True, - chunk_size=-1, ): super().__init__() - # assert bias is False, "Current bias is not supported" - assert drop == 0.0, "Current drop is not supported" - assert chunk_size == -1, "Current chunk is not supported" self.weight1 = nn.Parameter(torch.rand(num_experts, in_features, hidden_features)) self.weight2 = nn.Parameter(torch.rand(num_experts, hidden_features, out_features)) @@ -217,50 +517,39 @@ def bmm(self, x, weight, indices_s): return x -class MoE(nn.Module): - def __init__( - self, - batch_size, - num_rows, - num_experts, - in_features, - hidden_features=None, - out_features=None, - eval_capacity=-1, - activation="gelu", - ): +class MoEBlockSparseTop2MLP(nn.Module): + def __init__(self, config): super().__init__() - self.num_experts = num_experts - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.eval_capacity = eval_capacity # -1 means we route all tokens + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size - self.gate = MoEGate(num_experts=num_experts, in_features=in_features) - self.moe_experts = MoERuntimeExperts( - num_experts=num_experts, - in_features=in_features, - hidden_features=hidden_features, - out_features=out_features, - act_layer=get_activation_fn(activation), - bias=True, - ) + self.w1 
= nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.moe_onnx_graph = create_moe_onnx_graph( - batch_size * num_rows, - num_experts, - in_features, - hidden_features, - self.moe_experts.weight1.transpose(1, 2), - self.moe_experts.bias1, - self.moe_experts.weight2.transpose(1, 2), - self.moe_experts.bias2, - ) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralBlockSparseTop2MLP(MoEBlockSparseTop2MLP): + def __init__(self, config: MixtralConfig): + super().__init__(config) - self.ort_sess = self.create_ort_session() - self.torch_input = torch.randn(batch_size, num_rows, in_features) +class PhiMoEBlockSparseTop2MLP(MoEBlockSparseTop2MLP): + def __init__(self, config: PhiMoEConfig): + super().__init__(config) - def create_ort_session(self): + +class SparseMoeBlockORTHelper(nn.Module): + def __init__(self): + super().__init__() + + def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions sess_options = SessionOptions() @@ -270,10 +559,42 @@ def create_ort_session(self): return None sess_options.log_severity_level = 2 - ort_session = InferenceSession(self.moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) return ort_session + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + pass + + def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + ort_inputs = { + "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), + "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), + } + + ort_output = None + if self.ort_sess is not None: + if not iobinding: + ort_output = self.ort_sess.run(None, ort_inputs) + return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits + else: + self.ort_run_with_iobinding(ort_inputs) + return None + + # print_tensor("input", ort_inputs["input"]) + # print_tensor("router_probs", ort_inputs["router_probs"]) + # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) + # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) + # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) + # print_tensor("output", ort_output[0]) + + return None + def ort_run_with_iobinding(self, ort_inputs, repeat=1000): iobinding = self.ort_sess.io_binding() device_id = torch.cuda.current_device() @@ -286,6 +607,7 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): shape=ort_inputs["input"].shape, buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), ) + iobinding.bind_input( name="router_probs", device_type="cuda", @@ -308,6 +630,14 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): ).data_ptr(), ) + # warm up + for _ in range(5): + iobinding.synchronize_inputs() + 
self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + import time + s = time.time() for _ in range(repeat): iobinding.synchronize_inputs() @@ -316,117 +646,389 @@ def ort_run_with_iobinding(self, ort_inputs, repeat=1000): e = time.time() print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") - def torch_forward(self): - x = self.torch_input + def parity_check(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + torch_output = self.forward(hidden_state) + ort_output = self.ort_forward(hidden_state) + if ort_output is not None: + assert torch.allclose(torch_output, ort_output.to(torch.float32), rtol=THRESHOLD, atol=THRESHOLD) + print( + "name:", + self.__class__.__name__, + " batch_size:", + self.batch_size, + " sequence_length:", + self.sequence_length, + " max_diff:", + (torch_output - ort_output).abs().max(), + " parity: OK", + ) + + def benchmark_ort(self): + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + self.ort_forward(hidden_state, iobinding=True) + + +class SwitchMoE(SparseMoeBlockORTHelper): + def __init__( + self, + batch_size, + sequence_length, + num_experts, + in_features, + hidden_features=None, + out_features=None, + eval_capacity=-1, + activation="gelu", + ): + super().__init__() + self.batch_size = batch_size + self.sequence_length = sequence_length + self.num_experts = num_experts + self.hidden_dim = in_features + self.ffn_dim = hidden_features + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.eval_capacity = eval_capacity # -1 means we route all tokens + + self.gate = MoEGate(num_experts=num_experts, in_features=in_features) + self.moe_experts = MoERuntimeExperts( + num_experts=num_experts, + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + act_layer=ACT2CLS[activation], + bias=True, + ) + + self.moe_onnx_graph = create_moe_onnx_graph( + batch_size * sequence_length, + num_experts, + in_features, + hidden_features, + self.moe_experts.weight1.transpose(1, 2), + self.moe_experts.bias1, + self.moe_experts.weight2.transpose(1, 2), + self.moe_experts.bias2, + ) - b, t, c = x.shape - x = x.reshape(-1, c) - logits = self.gate(x) + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + self.torch_input = torch.randn(batch_size, sequence_length, in_features) + + def forward(self, hidden_states): + b, t, c = hidden_states.shape + hidden_states = hidden_states.reshape(-1, c) + logits = self.gate(hidden_states) gates = torch.nn.functional.softmax(logits, dim=1) ret = torch.max(gates, dim=1) indices_s = ret.indices # dim: [bs], the index of the expert with highest softmax value scores = ret.values.unsqueeze(-1).unsqueeze(-1) # S - x = self.moe_experts(x, indices_s) + hidden_states = self.moe_experts(hidden_states, indices_s) - x = x * scores - x = x.reshape(b * t, c) + hidden_states = hidden_states * scores + hidden_states = hidden_states.reshape(b, t, c) - return x, torch.sum(x) + return hidden_states - def onnx_forward(self, iobinding=False): - x = self.torch_input - _, _, c = x.shape - y = x.reshape(-1, c) - logits = self.gate(y) +class MixtralSparseMoeBlock(SparseMoeBlockORTHelper): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). 
It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ - ort_inputs = { - "input": numpy.ascontiguousarray(y.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(logits.detach().numpy().astype(NP_TYPE)), - } + def __init__(self, config, batch_size, sequence_length): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + w1_list = [] + w2_list = [] + w3_list = [] + for i in range(self.num_experts): + w1_list.append(self.experts[i].w1.weight) + w2_list.append(self.experts[i].w2.weight) + w3_list.append(self.experts[i].w3.weight) + + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(w2_list, dim=0) + self.moe_experts_weight3 = torch.stack(w3_list, dim=0) + + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_mixtral_moe_onnx_graph( + self.batch_size * self.sequence_length, + self.num_experts, + self.hidden_dim, + self.ffn_dim, + self.moe_experts_weight1, + self.moe_experts_weight2, + self.moe_experts_weight3, + self.top_k, + ) - ort_output = None - if self.ort_sess is not None: - if not iobinding: - ort_output = self.ort_sess.run(None, ort_inputs) - else: - self.ort_run_with_iobinding(ort_inputs) - return None + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts.weight1.detach().numpy()) - # print_tensor("fc1_experts_bias", self.moe_experts.bias1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts.weight2.detach().numpy()) - # print_tensor("fc2_experts_bias", self.moe_experts.bias2.detach().numpy()) - # print_tensor("output", ort_output[0]) + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) - return ort_output + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) - def parity_check(self): - torch_out = self.torch_forward() - ort_out = self.onnx_forward() - if ort_out is not None: - # print("max diff", numpy.max(numpy.abs(torch_out[0].detach().numpy() - ort_out[0]))) - assert numpy.allclose(torch_out[0].detach().numpy(), ort_out[0], rtol=THRESHOLD, atol=THRESHOLD) - - def benchmark(self): - self.onnx_forward(iobinding=True) - - -class TestMoE(unittest.TestCase): - def test_moe_small(self): - if platform.system() == "Windows": - pytest.skip("Skip on Windows") - rt = MoE( - batch_size=2, - num_rows=8, - num_experts=4, - in_features=16, - hidden_features=32, - out_features=16, + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast 
back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states # , router_logits + + +def masked_sampling_omp_inference(scores, top_k, jitter_eps, training): + assert top_k == 2 + assert not training + + mask_logits_threshold, selected_experts = torch.topk(scores, 2) + + mask_logits_threshold_1 = mask_logits_threshold[:, 0].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_1) + logits_mask = ((mask_logits_threshold_1 - scores) / factor) > (2 * jitter_eps) + + multiplier_1 = torch.softmax(scores.masked_fill(logits_mask, float("-inf")), dim=-1).gather( + dim=-1, index=selected_experts[:, 0].unsqueeze(-1) + ) + + ################ second expert gating ################ + + mask_logits_threshold_2 = mask_logits_threshold[:, 1].unsqueeze(-1) + + factor = scores.abs().clamp(min=mask_logits_threshold_2) + logits_mask = ((mask_logits_threshold_2 - scores) / factor) > (2 * jitter_eps) + + multiplier_2 = torch.softmax( + torch.scatter(scores, -1, selected_experts[:, 0].unsqueeze(-1), float("-inf")).masked_fill( + logits_mask, float("-inf") + ), + dim=-1, + ).gather(dim=-1, index=selected_experts[:, 1].unsqueeze(-1)) + + multiplier = torch.concat((multiplier_1, multiplier_2), dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accommodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
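# masked_sampling_omp_inference above is the sparse-mixer style top-2 gating used by the Phi MoE
# block (the graph sets use_sparse_mixer=1 and normalize_routing_weights=0): before each softmax,
# experts whose relative gap to the reference logit exceeds 2 * jitter_eps are masked to -inf, and
# the top-1 expert is scattered to -inf before the second multiplier is computed.
# A single-token walk-through of the first multiplier, with made-up numbers:
import torch

scores = torch.tensor([[2.0, 1.99, -1.0, 0.5]])      # router logits for one token, 4 experts
jitter_eps = 0.01

top_vals, top_idx = torch.topk(scores, 2)             # experts 0 and 1
threshold_1 = top_vals[:, 0].unsqueeze(-1)            # 2.0
gap = (threshold_1 - scores) / torch.maximum(scores.abs(), threshold_1)
mask_1 = gap > 2 * jitter_eps                         # experts 2 and 3 fall outside the band
multiplier_1 = torch.softmax(scores.masked_fill(mask_1, float("-inf")), dim=-1).gather(
    dim=-1, index=top_idx[:, 0].unsqueeze(-1)
)
print(mask_1)        # tensor([[False, False,  True,  True]])
print(multiplier_1)  # ~0.5025: softmax taken over the two surviving experts only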
+ """ + + def __init__(self, config, batch_size, sequence_length): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + self.router_jitter_noise = config.router_jitter_noise + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + + self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + w1_list = [] + w2_list = [] + w3_list = [] + w1_scale_list = [] + w2_scale_list = [] + w3_scale_list = [] + if not USE_QUANT: + for i in range(self.num_experts): + w1_list.append(self.experts[i].w1.weight) + w2_list.append(self.experts[i].w2.weight) + w3_list.append(self.experts[i].w3.weight) + else: + for i in range(self.num_experts): + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + + self.experts[i].w1.weight.data = w1_qdq + self.experts[i].w2.weight.data = w2_qdq + self.experts[i].w3.weight.data = w3_qdq + + w1_list.append(pre_qweight1) + w2_list.append(pre_qweight2) + w3_list.append(pre_qweight3) + w1_scale_list.append(w1_scale) + w2_scale_list.append(w2_scale) + w3_scale_list.append(w3_scale) + + self.moe_experts_weight1 = torch.stack(w1_list, dim=0) + self.moe_experts_weight2 = torch.stack(w2_list, dim=0) + self.moe_experts_weight3 = torch.stack(w3_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + self.moe_onnx_graph = create_phi_moe_onnx_graph( + self.batch_size * self.sequence_length, + self.num_experts, + self.hidden_dim, + self.ffn_dim, + self.moe_experts_weight1, + self.moe_experts_weight2, + self.moe_experts_weight3, + moe_experts_weight_scale1, + moe_experts_weight_scale2, + moe_experts_weight_scale3, + self.top_k, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + + routing_weights, selected_experts = masked_sampling_omp_inference( + router_logits, + top_k=self.top_k, + jitter_eps=self.router_jitter_noise, + training=False, + ) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct 
hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + + return final_hidden_states # , router_logits + + +def small_test_cases(): + for batch_size in [1, 4, 16]: + for sequence_length in [128, 512, 1024]: + yield batch_size, sequence_length + + +class TestSwitchMoE(unittest.TestCase): + @parameterized.expand(small_test_cases()) + def test_switch_moe_parity(self, batch_size, sequence_length): + # if platform.system() == "Windows": + # pytest.skip("Skip on Windows") + switch_moe = SwitchMoE( + batch_size=batch_size, + sequence_length=sequence_length, + num_experts=8, + in_features=256, + hidden_features=1024, + out_features=256, ) - rt.parity_check() - - @pytest.mark.slow - def test_moe_large(self): - for batch_size in [1, 8]: - for num_rows in [16, 64]: - for num_experts in [16, 64]: - for in_features in [256]: - for hidden_features in [512]: - print( - f"batch_size={batch_size}, num_rows={num_rows}, num_experts={num_experts}, in_features={in_features}, hidden_features={hidden_features}" - ) - rt = MoE( - batch_size=batch_size, - num_rows=num_rows, - num_experts=num_experts, - in_features=in_features, - hidden_features=hidden_features, - out_features=in_features, - ) - rt.parity_check() - - @pytest.mark.slow - def test_moe_benchmark(self): - for batch_size in [32, 64]: - for num_rows in [128, 512]: - for num_experts in [64, 128]: - for in_features in [256, 512]: - for hidden_features in [1024, 2048]: - print( - f"batch_size={batch_size}, num_rows={num_rows}, num_experts={num_experts}, in_features={in_features}, hidden_features={hidden_features}" - ) - rt = MoE( - batch_size=batch_size, - num_rows=num_rows, - num_experts=num_experts, - in_features=in_features, - hidden_features=hidden_features, - out_features=in_features, - ) - rt.benchmark() + switch_moe.parity_check() + # switch_moe.benchmark_ort() + + +class TestMixtralMoE(unittest.TestCase): + @parameterized.expand(small_test_cases()) + def test_mixtral_moe_parity(self, batch_size, sequence_length): + config = MixtralConfig(hidden_size=256, intermediate_size=1024) + mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.parity_check() + # mixtral_moe.benchmark_ort() + + +class TestPhiMoE(unittest.TestCase): + @parameterized.expand(small_test_cases()) + def test_phi3_moe_parity(self, batch_size, sequence_length): + config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) + phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) + phi3_moe.parity_check() + # phi3_moe.benchmark_ort() if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index f18bcdba65579..0f3947db18e4b 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -14,6 +14,7 @@ from benchmark_mha import InputFormats from 
onnx import TensorProto, helper from parameterized import parameterized +from test_gqa_cpu import smooth_softmax_ref from torch import Tensor from onnxruntime import InferenceSession, SessionOptions, get_available_providers @@ -43,6 +44,7 @@ def __init__( is_packed_qkv: bool = False, max_cache_sequence_length=None, max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, ): self.operator = operator self.batch_size = batch_size @@ -73,6 +75,8 @@ def __init__( self.share_buffer = share_buffer self.is_packed_qkv = is_packed_qkv + self.use_smooth_softmax = use_smooth_softmax + def shape_dict(self): shapes = { "query": ( @@ -166,6 +170,7 @@ def __init__( is_packed_qkv=False, max_cache_sequence_length=None, max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, ): super().__init__( "GroupQueryAttention", @@ -185,6 +190,7 @@ def __init__( is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, + use_smooth_softmax=use_smooth_softmax, ) # local_window_size is for ORT only, not for Torch implementation. self.local_window_size = local_window_size @@ -529,6 +535,7 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): local_window_size=config.local_window_size, do_rotary=1 if config.do_rotary else 0, rotary_interleaved=config.rotary_interleaved, + smooth_softmax=1 if config.use_smooth_softmax else 0, domain="com.microsoft", ), ] @@ -612,7 +619,12 @@ def group_query_attention_reference( attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale if mask is not None: attn = attn.masked_fill((1 - mask).bool(), float("-inf")) - attn = attn.softmax(-1) + + if config.use_smooth_softmax: + attn = smooth_softmax_ref(attn) + else: + attn = attn.softmax(-1) + attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) result = attn_output.transpose(1, 2).contiguous() diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx index db806f296aff3..931bd30dbe62f 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx @@ -1,4 +1,5 @@ -:Ä + +:Ä F Xargmax_output_int64argmax"ArgMax* axis * @@ -15,4 +16,4 @@ F    -B \ No newline at end of file +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.py b/onnxruntime/test/testdata/coreml_argmax_cast_test.py index acf24ac379065..6cc25311131a0 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.py +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.py @@ -1,16 +1,18 @@ import onnx from onnx import TensorProto, helper -# CoreML EP currently handles a special case for supporting ArgMax op -# Please see in /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and -# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc -# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type +# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32. +# Please see /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and +# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc. 
diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx
index db806f296aff3..931bd30dbe62f 100644
--- a/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx
+++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.onnx
@@ -1,4 +1,5 @@
-:Ä
+
+:Ä
 F
 Xargmax_output_int64argmax"ArgMax*
 axis *
@@ -15,4 +16,4 @@
 F
 
 
 
-B
\ No newline at end of file
+B
\ No newline at end of file
diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.py b/onnxruntime/test/testdata/coreml_argmax_cast_test.py
index acf24ac379065..6cc25311131a0 100644
--- a/onnxruntime/test/testdata/coreml_argmax_cast_test.py
+++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.py
@@ -1,16 +1,18 @@
 import onnx
 from onnx import TensorProto, helper
 
-# CoreML EP currently handles a special case for supporting ArgMax op
-# Please see in /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
-# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc
-# We have this separated test script to generate graph for the case: An ArgMax followed by a Cast to int32 type
+# CoreML EP currently handles a special case for supporting ArgMax followed by a Cast to int32.
+# Please see /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and
+# /onnxruntime/core/providers/coreml/builders/impl/cast_op_builder.cc.
+# This script generates graphs for these cases:
+# - An ArgMax followed by a supported Cast to int32 type
+# - An ArgMax followed by an unsupported Cast to a type other than int32
 
 
-def GenerateModel(model_name):  # noqa: N802
+def GenerateModel(model_name, cast_to_dtype):  # noqa: N802
     nodes = [
         helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1),
-        helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6),  # cast to int32 type
+        helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=cast_to_dtype),
     ]
 
     graph = helper.make_graph(
@@ -20,7 +22,7 @@ def GenerateModel(model_name):  # noqa: N802
             helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]),
         ],
         [  # output
-            helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]),
+            helper.make_tensor_value_info("Y", cast_to_dtype, [3, 1, 2]),
         ],
     )
 
@@ -29,4 +31,5 @@ def GenerateModel(model_name):  # noqa: N802
 
 
 if __name__ == "__main__":
-    GenerateModel("coreml_argmax_cast_test.onnx")
+    GenerateModel("coreml_argmax_cast_test.onnx", TensorProto.INT32)
+    GenerateModel("coreml_argmax_unsupported_cast_test.onnx", TensorProto.UINT32)
diff --git a/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx b/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx
new file mode 100644
index 0000000000000..d5aea9110cbfa
--- /dev/null
+++ b/onnxruntime/test/testdata/coreml_argmax_unsupported_cast_test.onnx
@@ -0,0 +1,19 @@
+
+:Ä
+F
+Xargmax_output_int64argmax"ArgMax*
+axis *
+keepdims 
+/
+argmax_output_int64Ycast"Cast*
+to  CoreML_ArgMax_Cast_TestZ
+X
+
+
+
+b
+Y
+ 
+
+
+B
\ No newline at end of file
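A quick way to sanity-check the two generated models above is to run them on the CPU provider and compare against numpy. The snippet below is a hypothetical helper, not part of the test suite; it assumes the models were generated into the current directory and that numpy and onnxruntime are installed.

import numpy as np
import onnxruntime as ort

def check_argmax_cast_model(model_path, expected_dtype):
    sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
    x = np.random.rand(3, 2, 2).astype(np.float32)
    (y,) = sess.run(None, {"X": x})
    # The graph is ArgMax(axis=1, keepdims=1) followed by a Cast, so the output should
    # equal numpy's argmax over axis 1 (with the axis kept) cast to the requested type.
    expected = np.expand_dims(np.argmax(x, axis=1), axis=1).astype(expected_dtype)
    assert y.dtype == expected_dtype and np.array_equal(y, expected)

check_argmax_cast_model("coreml_argmax_cast_test.onnx", np.int32)
check_argmax_cast_model("coreml_argmax_unsupported_cast_test.onnx", np.uint32)

On a CPU-only build both models run the same way; the int32 versus non-int32 distinction only matters for which nodes the CoreML EP is willing to take, which is what the argmax/cast builder changes earlier in this patch address.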
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index c48968efbb262..004e3540c62d6 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -10,7 +10,6 @@
 import torch.onnx.symbolic_helper as sym_help
 from packaging import version
 from packaging.version import Version
-from torch.onnx import register_custom_op_symbolic
 from torch.onnx.symbolic_helper import parse_args
 
 from onnxruntime.training.utils import pytorch_type_to_onnx_dtype
@@ -70,6 +69,8 @@ def register(cls, name, domain, fn):
 
     @classmethod
     def register_all(cls, onnx_opset_version):
+        from torch.onnx import register_custom_op_symbolic
+
         for name, fn in cls._SYMBOLICS.items():
             # Symbolic name is in format: domain::name
             register_custom_op_symbolic(
diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py
index 573ec85d76013..e57b615de07bb 100644
--- a/orttraining/orttraining/test/python/orttraining_test_dort.py
+++ b/orttraining/orttraining/test/python/orttraining_test_dort.py
@@ -5,7 +5,6 @@
 
 import torch
 import torch._dynamo
-import torch.onnx._internal.exporter
 from torch import nn
 from torch.nn import functional as F
 from torch.onnx import ExportOptions
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml
index 6a772ebc1e1db..be3f67ba450b4 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda.yml
@@ -18,7 +18,7 @@ stages:
       torch_version: '2.0.0'
       opset_version: '17'
       cuda_version: '11.8'
-      cmake_cuda_architectures: 70;75;80;86
+      cmake_cuda_architectures: 60;61;70;75;80;86
       docker_file: Dockerfile.manylinux2_28_training_cuda11_8
       agent_pool: Onnxruntime-Linux-GPU
       upload_wheel: 'yes'
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml
index b356d8027d0c5..74d299c728911 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cuda12.yml
@@ -8,7 +8,7 @@ stages:
       torch_version: '2.1.0'
       opset_version: '17'
       cuda_version: '12.2'
-      cmake_cuda_architectures: 80;86;90
+      cmake_cuda_architectures: 70;75;80;86;90
       docker_file: Dockerfile.manylinux2_28_training_cuda12_2
       agent_pool: Onnxruntime-Linux-GPU
       upload_wheel: 'yes'
diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml
index 9b65ddbfdf3df..fc163d17e44a9 100644
--- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-training-cuda-stage-steps.yml
@@ -165,7 +165,7 @@ stages:
         set -e -x
         whlfilename=$(ls $(Build.ArtifactStagingDirectory)/Release/dist/*.whl | head -n 1) ; \
         echo $whlfilename ; du -sh $whlfilename ; \
-        (( $(wc -c < "$whlfilename") - 300*1024*1024 < 0 )) || ( echo 'Wheel size bigger than 300M'; exit 1)
+        (( $(wc -c < "$whlfilename") - 400*1024*1024 < 0 )) || ( echo 'Wheel size bigger than 400M'; exit 1)
     displayName: 'Check wheel size'
     continueOnError: true
diff --git a/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml b/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml
index 933abad11595e..c495e11014b30 100644
--- a/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml
+++ b/tools/ci_build/github/azure-pipelines/templates/win-esrp-dll.yml
@@ -26,6 +26,38 @@ steps:
     AuthAKVName: 'buildkeyvault'
     AuthCertName: '53d54d02-SSL-AutoRotate'
     AuthSignCertName: '53d54d02-978d-4305-8572-583cf6711c4f'
+    signConfigType: inlineSignParams
+    inlineOperation: |
+      [
+        {
+          "keyCode": "CP-230012",
+          "operationSetCode": "SigntoolSign",
+          "parameters": [
+          {
+            "parameterName": "OpusName",
+            "parameterValue": "Microsoft"
+          },
+          {
+            "parameterName": "OpusInfo",
+            "parameterValue": "http://www.microsoft.com"
+          },
+          {
+            "parameterName": "PageHash",
+            "parameterValue": "/NPH"
+          },
+          {
+            "parameterName": "FileDigest",
+            "parameterValue": "/fd sha256"
+          },
+          {
+            "parameterName": "TimeStamp",
+            "parameterValue": "/tr \"http://rfc3161.gtm.corp.microsoft.com/TSS/HttpTspServer\" /td sha256"
+          }
+          ],
+          "toolName": "signtool.exe",
+          "toolVersion": "6.2.9304.0"
+        }
+      ]
     FolderPath: ${{ parameters.FolderPath }}
     Pattern: ${{ parameters.Pattern }}
diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile
index 9272f6e627a13..4e55ce29f46ff 100644
--- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile
+++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile
@@ -125,7 +125,8 @@ RUN pip install \
     sentencepiece \
    wget \
    dill==0.3.4 \
-    pytorch_lightning==1.6.0 \
+    pytorch_lightning==2.3.3 \
+    tensorboard \
    pytest-xdist \
    pytest-rerunfailures \
    ml_dtypes==0.3.0 \