From 59b7b6bb7cbb7bcc86dab590f1b4d5ed50d53dec Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 13 Sep 2024 16:52:49 +0000 Subject: [PATCH 1/9] Remove training from web ci pipeline (#22082) ### Description Remove training from web ci pipeline ### Motivation and Context --- .../templates/linux-wasm-ci.yml | 21 ------------------- .../azure-pipelines/templates/win-web-ci.yml | 6 +----- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index a56eb37faef84..2ab432e94fcbd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -31,10 +31,6 @@ parameters: type: boolean default: false -- name: BuildTraining - type: boolean - default: true - - name: WithCache type: boolean default: false @@ -116,19 +112,6 @@ jobs: DisplayName: 'Build and test (browser) (simd + threads)' WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildTraining, true) }}: - - template: build-linux-wasm-step.yml - parameters: - Today: $(Today) - ${{ if eq(parameters.BuildStaticLib, true)}}: - AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} | static - ${{ else }}: - AdditionalKey: wasm_training | ${{ parameters.BuildConfig }} - CacheDir: $(ORT_CACHE_DIR)/wasm_training - Arguments: '$(CommonBuildArgs) --build_dir $(Build.BinariesDirectory)/wasm_training --enable_training_apis --target onnxruntime_webassembly --skip_tests' - DisplayName: 'Build (training + simd + threads)' - WithCache: ${{ parameters.WithCache }} - - ${{ if eq(parameters.BuildJsep, true) }}: - template: build-linux-wasm-step.yml parameters: @@ -150,10 +133,6 @@ jobs: cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.wasm $(Build.ArtifactStagingDirectory) cp $(Build.BinariesDirectory)/wasm_inferencing_jsep/${{ parameters.BuildConfig }}/ort-wasm-simd-threaded.jsep.mjs $(Build.ArtifactStagingDirectory) fi - if [ -d $(Build.BinariesDirectory)/wasm_training ]; then - cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.wasm $(Build.ArtifactStagingDirectory) - cp $(Build.BinariesDirectory)/wasm_training/${{ parameters.BuildConfig }}/ort-training-wasm-simd-threaded.mjs $(Build.ArtifactStagingDirectory) - fi displayName: 'Create Artifacts' - ${{ if eq(parameters.SkipPublish, false) }}: - task: PublishPipelineArtifact@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index c1fde93d8e640..0e8a7eb94379b 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -214,11 +214,7 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'E2E package consuming test' condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) - - script: | - npm run test:training:e2e - workingDirectory: '$(Build.SourcesDirectory)\js\web' - displayName: 'E2E training package test' - condition: and(succeeded(), eq('${{ parameters.BuildConfig }}', 'Release')) + - task: CopyFiles@2 inputs: sourceFolder: $(Build.SourcesDirectory)\js\common From 7e2c722459a7a7015a238379acc8705d9ce5b8dc Mon Sep 17 00:00:00 2001 From: aciddelgado <139922440+aciddelgado@users.noreply.github.com> Date: Fri, 13 Sep 2024 13:21:11 
-0700 Subject: [PATCH 2/9] Add Continuous Decoding support in GQA (#21523) ### Description This PR will add support for Continuous Decoding for batch_size = 1 input. From now on, GQA can take arbitrary length input using seqlens_k as total_sequence_length - 1 and the sequence length of qkv as new_sequence_length. **This change will not affect the default behavior of GQA** ### Motivation and Context Prior to this change it was impossible to support sequence_length > 1 inputs when past context was given. This use case is essential to making continuous decoding work, which is one of our current efforts in ORT-GenAI. --- docs/ContribOperators.md | 6 +- .../contrib_ops/cpu/bert/attention_common.h | 3 +- .../contrib_ops/cpu/bert/attention_helper.h | 11 +- .../contrib_ops/cpu/bert/gqa_attention_base.h | 177 +++++------ .../cpu/bert/group_query_attention.cc | 31 +- .../cpu/bert/group_query_attention_helper.h | 36 ++- .../cpu/sparse/sparse_attention_base.h | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 1 - .../cuda/bert/group_query_attention.cc | 5 +- .../cuda/bert/group_query_attention_helper.h | 298 ------------------ .../cuda/bert/group_query_attention_impl.cu | 149 ++++++--- .../cuda/bert/group_query_attention_impl.h | 4 +- .../rocm/bert/group_query_attention.cu | 10 +- .../core/graph/contrib_ops/bert_defs.cc | 8 +- .../transformers/test_flash_attn_cuda.py | 171 +++++++++- .../test/python/transformers/test_gqa_cpu.py | 79 ++++- .../transformers/test_sparse_attention.py | 7 +- 17 files changed, 498 insertions(+), 502 deletions(-) delete mode 100644 onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index aadf4ebe2f488..09a7e47fc9913 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2521,6 +2521,8 @@ This version of the operator has been available since version 1 of the 'com.micr Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. + Supports continuous decoding for batch_size == 1 for CPU and CUDA. + #### Version @@ -2561,9 +2563,9 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dt><tt>past_value</tt> (optional) : T</dt>
 <dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
 <dt><tt>seqlens_k</tt> : M</dt>
-<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
+<dd>1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).</dd>
 <dt><tt>total_sequence_length</tt> : M</dt>
-<dd>Scalar tensor of total sequence length (past + new).</dd>
+<dd>Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for checking inputs and determining prompt vs token generation case.</dd>
 <dt><tt>cos_cache</tt> (optional) : T</dt>
 <dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
 <dt><tt>sin_cache</tt> (optional) : T</dt>
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index 45acb90ba68b0..e0fa581c8071d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -114,7 +114,8 @@ struct GroupQueryAttentionParameters { int local_window_size; bool kv_share_buffer; bool is_packed_qkv; - bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt; // indicates whether this is first decoding step bool do_rotary; bool rotary_interleaved; bool use_smooth_softmax; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index e6c948acb0d6c..4d435f71cc195 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -236,19 +236,16 @@ T* ConcatStateChunkGQA(const T* past, size_t past_buff_chunk_length, size_t past_chunk_length, size_t new_chunk_length, - bool is_prompt, bool past_present_share_buffer, std::ptrdiff_t i) { T* start = present + i * present_buff_chunk_length; T* p = start; - if (!is_prompt) { - if (!past_present_share_buffer) { - const T* src_past = past + i * past_buff_chunk_length; - memcpy(p, src_past, past_chunk_length * sizeof(T)); - } - p += past_chunk_length; + if (!past_present_share_buffer && past_chunk_length > 0) { + const T* src_past = past + i * past_buff_chunk_length; + memcpy(p, src_past, past_chunk_length * sizeof(T)); } + p += past_chunk_length; memcpy(p, chunk, new_chunk_length * sizeof(T)); return start; diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index 2bf0aa0915c2d..bfec9aef56727 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -59,6 +59,7 @@ class GQAAttentionBase { GroupQueryAttentionParameters& parameters, // attention parameters AllocatorPtr allocator, // allocator for temporary tensors OpKernelContext* context) const { + const bool is_prompt = parameters.is_first_prompt; const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int head_size = parameters.head_size; @@ -88,14 +89,14 @@ class GQAAttentionBase { const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, - present_key_data, past_present_share_buffer, packed_qkv, tp); + present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? 
Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, - tp); + is_prompt, tp); return Status::OK(); } @@ -105,35 +106,35 @@ class GQAAttentionBase { // attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T) // attention_probs(B, N, S, T) = Softmax(attention_probs) template - void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT - const T* Q, // Q data. Its size is BxNxSxH - const T* K, // k data. Its size is BxNxLxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size of self-attention - int sequence_length, // sequence length of self-attention (S) - int past_buffer_sequence_length, // sequence length of past state - int present_buffer_sequence_length, // sequence length of present state - int head_size, // head size of self-attention - const T* past_key, // past key only - T* present_key, // present key only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed - ThreadPool* tp) const { // thread pool - const bool is_prompt = sequence_length != 1; + void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT + const T* Q, // Q data. Its size is BxNxSxH + const T* K, // k data. Its size is BxNxLxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size of self-attention + const size_t sequence_length, // sequence length of self-attention (S) + const size_t past_buffer_sequence_length, // sequence length of past state + const size_t present_buffer_sequence_length, // sequence length of present state + const size_t head_size, // head size of self-attention + const T* past_key, // past key only + T* present_key, // present key only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt + ThreadPool* tp) const { // thread pool const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H - const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t q_input_chunk_length = sequence_length * head_size; // S x H + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } - const int loop_len = batch_size * num_heads_; + const size_t loop_len = batch_size * num_heads_; const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; @@ -156,12 +157,11 @@ class GQAAttentionBase { ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i) / num_heads_; - const int head_index = static_cast(i) % num_heads_; - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; T* output = attention_probs + output_offset; @@ -174,7 +174,7 @@ class GQAAttentionBase { } if (nullptr != present_key) { k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); } @@ -189,16 +189,17 @@ class GQAAttentionBase { } else { q = Q + q_input_chunk_length * i; } + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, - head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length, - nullptr); + static_cast(head_size), k, static_cast(head_size), 0.0f /*bata*/, output, + static_cast(present_buffer_sequence_length), nullptr); // compute Softmax T* output_softmax = output; - for (int seq = 0; seq < sequence_length; seq++) { - int seq_causal_length = sequence_length == 1 ? 
total_seqlen : seq + 1; - if (local_window_size_ > 0 && seq_causal_length > local_window_size_ + 1) { - for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { + for (size_t seq = 0; seq < sequence_length; seq++) { + size_t seq_causal_length = past_seqlen + seq + 1; + if (local_window_size_ > 0 && seq_causal_length > static_cast(local_window_size_) + 1) { + for (size_t total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } if (softcap_ > 0.f) { @@ -214,17 +215,17 @@ class GQAAttentionBase { } } else { if (softcap_ > 0.f) { - ComputeAttentionSoftcapInplace(output_softmax, seq_causal_length, softcap_); + ComputeAttentionSoftcapInplace(output_softmax, static_cast(seq_causal_length), softcap_); } if (use_smooth_softmax_) { - ComputeSmoothSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeSmoothSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } else { - ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); + ComputeAttentionSoftmaxInplace(output_softmax, 1, static_cast(seq_causal_length), nullptr); } } // set causal [seq_causal_length, total_seqlen) to 0.f - for (int total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { + for (size_t total_seq_id = seq_causal_length; total_seq_id < total_seqlen; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } @@ -235,34 +236,36 @@ class GQAAttentionBase { } template - void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH - const T* attention_probs, // Attention probs with size BxNxSxT - const T* V, // V value with size BxN_kvxSxH - const int32_t* seqlens_k, // past sequence lengths tensor - int batch_size, // batch size - int sequence_length, // sequence length - int past_buffer_sequence_length, // sequence length in past state - int present_buffer_sequence_length, // sequence length in past state - int head_size, // head size of Q, K, V - int hidden_size, // hidden size of Output - const T* past_value, // past value only - T* present_value, // present value only - bool past_present_share_buffer, // whether present key and value share the same buffer - bool packed_qkv, // whether Q, K, V are packed + void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH + const T* attention_probs, // Attention probs with size BxNxSxT + const T* V, // V value with size BxN_kvxSxH + const int32_t* seqlens_k, // total - 1 sequence lengths tensor + const size_t batch_size, // batch size + const size_t sequence_length, // sequence length + const size_t past_buffer_sequence_length, // sequence length in past state + const size_t present_buffer_sequence_length, // sequence length in past state + const size_t head_size, // head size of Q, K, V + const size_t hidden_size, // hidden size of Output + const T* past_value, // past value only + T* present_value, // present value only + const bool past_present_share_buffer, // whether present key and value share the same buffer + const bool packed_qkv, // whether Q, K, V are packed + const bool is_prompt, // whether it is prompt ThreadPool* tp) const { - const bool is_prompt = sequence_length != 1; const ptrdiff_t packed_batch_stride = packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : SafeInt(0); - const int kv_num_heads_factor = num_heads_ / kv_num_heads_; - const int kv_input_chunk_length = sequence_length * head_size; // L x H - const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H - const size_t present_buff_chunk_length = static_cast(present_buffer_sequence_length) * head_size; // T x H + const size_t kv_num_heads_factor = num_heads_ / kv_num_heads_; + const size_t kv_input_chunk_length = sequence_length * head_size; // L x H + const size_t past_buff_chunk_length = past_buffer_sequence_length * head_size; // L x H + const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H if (!past_present_share_buffer) { memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T)); } + const size_t loop_len = batch_size * num_heads_; + // The cost of Gemm TensorOpCost unit_cost; unit_cost.compute_cycles = @@ -282,37 +285,35 @@ class GQAAttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor( - tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i / num_heads_); - const int head_index = static_cast(i % num_heads_); - const int past_seqlen = - sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; - - const T* v; - if (packed_qkv) { - v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); - } else { - v = V + kv_input_chunk_length * (i / kv_num_heads_factor); - } - if (nullptr != present_value) { - v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, - i / kv_num_heads_factor); - } + ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const size_t batch_index = i / num_heads_; + const size_t head_index = i % num_heads_; + const size_t total_seqlen = static_cast(seqlens_k[batch_index]) + 1; + const size_t past_seqlen = is_prompt ? 
0 : total_seqlen - sequence_length; // Assume no padding sequence length + const size_t past_chunk_length = past_seqlen * head_size; + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_value) { + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, past_present_share_buffer, + i / kv_num_heads_factor); + } - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, - 1.f, /*alpha*/ - attention_probs + attention_probs_offset, present_buffer_sequence_length, v, - head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); - } - }); + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/ + attention_probs + attention_probs_offset, + static_cast(present_buffer_sequence_length), v, static_cast(head_size), + 0.0f /*beta*/, output_current, static_cast(hidden_size), nullptr); + } + }); } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index 87675255f5ba4..2a38e4a1ac636 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -45,7 +45,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { const Tensor* past_key = context->Input(3); const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); - const Tensor* total_seqlen = context->Input(6); + const Tensor* total_seqlen_tensor = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); @@ -61,7 +61,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { num_heads_, kv_num_heads_, seqlens_k, - total_seqlen, + total_seqlen_tensor, scale_, softcap_)); @@ -103,6 +103,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { } if (do_rotary_) { + // Initialize rotary parameters rotary_embedding_helper::RotaryParameters rotary_params = {}; rotary_params.batch_size = batch_size; rotary_params.sequence_length = sequence_length; @@ -114,17 +115,29 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { rotary_params.seq_stride = head_size; rotary_params.head_stride = sequence_length * rotary_params.seq_stride; rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride; - rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0; + rotary_params.position_ids_format = !parameters.is_first_prompt ? 1 : 0; rotary_params.transposed = true; auto* tp = context->GetOperatorThreadPool(); - std::vector pos_ids(sequence_length == 1 ? batch_size : 1); - if (sequence_length == 1) { + // Generate position ids + const int pos_ids_size = parameters.is_first_prompt ? 
1 : batch_size * sequence_length; + std::vector pos_ids(pos_ids_size); + if (parameters.is_first_prompt) { + pos_ids[0] = static_cast(0); + } else { + // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1. for (int b = 0; b < batch_size; b++) { - pos_ids[b] = static_cast(seqlens_k->Data()[b]); + const int total_seqlen = seqlens_k->Data()[b] + 1; + const int past_seqlen = total_seqlen - sequence_length; + for (int s = 0; s < sequence_length; s++) { + if (past_seqlen + s < total_seqlen) { + pos_ids[b * sequence_length + s] = static_cast(past_seqlen) + s; + } else { + pos_ids[b * sequence_length + s] = static_cast(1); + } + } } - } else { - pos_ids[0] = static_cast(0); } + // Initialize separate buffers for rotary embeddings const T* q_input; const T* k_input; T* q_rotary; @@ -149,6 +162,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { Q = RotaryQ; K = RotaryK; } + // Run rotary embedding for Q and K ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, q_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), q_rotary, rotary_interleaved_)); @@ -161,6 +175,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(RunRotaryEmbedding(tp, rotary_params, k_input, pos_ids.data(), cos_cache->Data(), sin_cache->Data(), k_rotary, rotary_interleaved_)); + // Pack V into rotary QKV buffer if (packed_qkv) { const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size; T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index 3342052260ff9..0bdee151d2173 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -168,14 +168,13 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { + const auto& seqlens_k_dim = seqlens_k->Shape().GetDims(); + if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "seqlens_k must be shape (batch_size)."); } - // Set present sequence length and kv_share_buffer from input total_seqlen tensor + // Set present sequence length from input total_seqlen tensor if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "total_sequence_length tensor must be of one element."); @@ -195,11 +194,11 @@ Status CheckInputs(const Tensor* query, } if (cos_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); + "cos_cache dimension 0 shall not be less than total_sequence_length."); } if (sin_dims[0] < total_sequence_length) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); + "sin_cache dimension 0 shall not be less than total_sequence_length."); } if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -219,7 +218,26 @@ Status CheckInputs(const Tensor* query, "Input 
'cos_cache' and 'sin_cache' shall be both present or both absent."); } - bool is_prompt = sequence_length != 1; + bool is_subsequent_prompt = false; + if (sequence_length > 1 && sequence_length != total_sequence_length) { + if (batch_size != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "batch_size must be 1 when sequence_length > 1 and past context is given."); + } + is_subsequent_prompt = true; + } + + bool is_first_prompt; + if (is_subsequent_prompt) { + is_first_prompt = false; // irrelevant for interactive decoding + } else { + // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt + is_first_prompt = (sequence_length == total_sequence_length); + if (!is_first_prompt && sequence_length != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sequence_length shall be 1 when it is not prompt."); + } + } if (parameters != nullptr) { GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); @@ -227,6 +245,7 @@ Status CheckInputs(const Tensor* query, output_parameters->sequence_length = sequence_length; // sequence length of Q output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors + output_parameters->total_sequence_length = total_sequence_length; // total sequence length output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; output_parameters->head_size = head_size; @@ -235,7 +254,8 @@ Status CheckInputs(const Tensor* query, output_parameters->rotary_dim = rotary_dim; output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; - output_parameters->is_prompt = is_prompt; + output_parameters->is_subsequent_prompt = is_subsequent_prompt; + output_parameters->is_first_prompt = is_first_prompt; output_parameters->scale = scale; output_parameters->softcap = softcap; output_parameters->qkv_format = qkv_format; diff --git a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h index cf66bd8407126..37172074e5d86 100644 --- a/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h @@ -184,7 +184,7 @@ class SparseAttentionBase { // Concatenate past_k + k -> present_k // TODO: avoid copying mutiple times for a group. k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); // Compute Q*K' + AttentionMask @@ -365,7 +365,7 @@ class SparseAttentionBase { // Concatenate past_v + v -> present_v v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + is_prompt ? 
0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer, i / kv_num_heads_factor); DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size); diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index a10d2548fa7b8..7f1c3786858c8 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -42,7 +42,6 @@ struct RightPaddingBatchHook { auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE; - // Advance to current batch - in case of different sequence lengths if (p.seqlen_k_ptr) { p.num_keys = p.seqlen_k_ptr[batch_id]; } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index d0ae812bb4fa2..6eff584cec5da 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -5,7 +5,7 @@ #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" #include "contrib_ops/cuda/bert/group_query_attention.h" -#include "contrib_ops/cuda/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" @@ -95,7 +95,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { kv_num_heads_, seqlens_k, total_seqlen, - is_past_bsnh_, scale_, softcap_, device_prop.maxThreadsPerBlock)); @@ -253,7 +252,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.out_accum = reinterpret_cast(out_accum_buffer.get()); } if (seqlens_k_buffer != nullptr) { - data.seqlens_k_total = reinterpret_cast(seqlens_k_buffer.get()); + data.seqlens_k_buff = reinterpret_cast(seqlens_k_buffer.get()); } // Memory Efficient Buffers if (k_buffer != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h deleted file mode 100644 index e65827e4ccdd5..0000000000000 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#pragma once - -#include "core/common/common.h" -#include "core/providers/common.h" -#include "contrib_ops/cpu/bert/attention_common.h" - -namespace onnxruntime { -namespace contrib { -namespace group_query_attention_helper { - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap) { - // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr - // no packing for q/k/v: - // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) - // key (K) : (B, S, D_kv) or nullptr - // value (V) : (B, S, D_kv) or nullptr - AttentionQkvFormat qkv_format = Q_K_V_BSNH; - AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - const bool is_packed_qkv = key == nullptr; - const auto& query_dims = query->Shape().GetDims(); - - if (query_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", - query_dims.size()); - } - - int batch_size = static_cast(query_dims[0]); - int sequence_length = static_cast(query_dims[1]); - int q_hidden_size = static_cast(query_dims[2]); - int head_size = 0; - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - int kv_hidden_size = 0; - // Check key and value when not packed - if (!is_packed_qkv) { - head_size = static_cast(q_hidden_size) / num_heads; - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - if (value == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - const auto& key_dims = key->Shape().GetDims(); - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } else if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != key_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 1 (sequence length)"); - } - kv_hidden_size = static_cast(key_dims[2]); - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } else if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch size)"); - } else if (query_dims[1] != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 1 (sequence length)"); - } else if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - } else { - // Check packed qkv - head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); - if (head_size % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size must be a multiple of 8. 
Got head_size % 8 == ", - head_size % 8); - } - if (value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); - } - q_hidden_size = head_size * num_heads; - kv_hidden_size = head_size * kv_num_heads; - } - - // Check past-present KV - int32_t past_sequence_length = 0; - if (past_key != nullptr && past_value != nullptr) { - const auto& past_key_dims = past_key->Shape().GetDims(); - const auto& past_value_dims = past_value->Shape().GetDims(); - - if (past_key_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' is expected to have 4 dimensions, got ", - past_key_dims.size()); - } - if (past_value_dims.size() != 4) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' is expected to have 4 dimensions, got ", - past_value_dims.size()); - } - - if (past_key_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 0 should be batch_size, got ", - past_key_dims[0]); - } - if (past_value_dims[0] != batch_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 0 should be batch_size, got ", - past_value_dims[0]); - } - - // BNSH - if (!is_past_bsnh) { - if (past_key_dims[2] != past_value_dims[2]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 2 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[1] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[2]); - // BSNH - } else { - if (past_key_dims[1] != past_value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "BNSH Input 'past_key' and 'past_value' should have same dimension 1 (max sequence" - "length or past sequence length), got ", - past_key_dims[1]); - } - if (past_key_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' shall have kv_num_heads"); - } - if (past_value_dims[2] != kv_num_heads) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' shall have kv_num_heads"); - } - // We assume all sequence in past kv are right-padded to max or past sequence length - past_sequence_length = static_cast(past_key_dims[1]); - } - - if (past_key_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' dimension 3 should be same as head_size, got ", - past_key_dims[3]); - } - if (past_value_dims[3] != head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_value' dimension 3 should be same as head_size, got ", - past_value_dims[3]); - } - } else if (past_key != nullptr || past_value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'past_key' and 'past_value' shall be both present or both absent."); - } - - // Check seqlens_k tensor (holding past seqlen for token gen) - const auto& seqlens_dim = seqlens_k->Shape().GetDims(); - if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { - return 
ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "seqlens_k must be shape (batch_size)."); - } - - // Set present sequence length and kv_share_buffer from input total_seqlen tensor - if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "total_sequence_length tensor must be of one element."); - } - int total_sequence_length = *((*total_seqlen).template Data()); - int present_sequence_length = std::max(total_sequence_length, past_sequence_length); - - int rotary_dim = 0; - if (cos_cache != nullptr && sin_cache != nullptr) { - const auto& cos_dims = cos_cache->Shape().GetDims(); - const auto& sin_dims = sin_cache->Shape().GetDims(); - - if (head_size % 16 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_size shall be a multiple of 16. Got head_size % 16 == ", - head_size % 16); - } - if (cos_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 0 should be not be less than total_sequence_length."); - } - if (sin_dims[0] < total_sequence_length) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 0 should be not be less than total_sequence_length."); - } - if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (sin_dims[1] > (head_size / 16) * 8 || sin_dims[1] % 8 != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); - } - if (cos_dims[1] != sin_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "cos_cache and sin_cache dimension 1 must be the same."); - } - rotary_dim = static_cast(cos_dims[1] * 2); - } else if (cos_cache != nullptr || sin_cache != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); - } - - bool is_prompt = (sequence_length == total_sequence_length); - if (!is_prompt && sequence_length != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "sequence_length shall be 1 when it is not prompt."); - } - - if (parameters != nullptr) { - GroupQueryAttentionParameters* output_parameters = reinterpret_cast(parameters); - output_parameters->batch_size = batch_size; - output_parameters->sequence_length = sequence_length; // sequence length of Q - output_parameters->seqlen_past_kv_cache = past_sequence_length; // max sequence length of past kv tensors - output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors - output_parameters->total_sequence_length = total_sequence_length; // total sequence length - output_parameters->hidden_size = q_hidden_size; - output_parameters->num_heads = num_heads; - output_parameters->head_size = head_size; - output_parameters->kv_hidden_size = kv_hidden_size; - output_parameters->kv_num_heads = kv_num_heads; - output_parameters->rotary_dim = rotary_dim; - output_parameters->is_packed_qkv = is_packed_qkv; - output_parameters->is_prompt = is_prompt; - output_parameters->scale = scale; - output_parameters->softcap = softcap; - output_parameters->qkv_format = qkv_format; - output_parameters->past_kv_format = past_kv_format; - } - - return Status::OK(); -} - -Status CheckInputs(const Tensor* query, - const Tensor* key, - const Tensor* value, - const Tensor* past_key, - const Tensor* 
past_value, - const Tensor* cos_cache, - const Tensor* sin_cache, - void* parameters, - int num_heads, - int kv_num_heads, - const Tensor* seqlens_k, - const Tensor* total_seqlen, - bool is_past_bsnh, - float scale, - float softcap, - int max_threads_per_block) { - if (max_threads_per_block > 0 && num_heads > max_threads_per_block) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); - } - - return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale, softcap); -} - -} // namespace group_query_attention_helper -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index be94f26ec298f..8bf9848245ec7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -71,6 +71,8 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, + // const int* seqlens_q, const bool is_bsnh) { // refers to past; otherwise bnsh const int h = threadIdx.x; const int n = threadIdx.y; @@ -88,7 +90,9 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -96,7 +100,7 @@ __global__ void ConcatNewToPastKV(const int new_seqlen, const int past_head_stride = is_bsnh ? H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { // Note: new KV always BSNH const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; @@ -116,6 +120,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const T* new_kv, T* present_kv, const int* seqlens_k, + const bool past_only, const bool is_bsnh) { int i = threadIdx.x + (blockDim.x * blockIdx.x); if (i < H * num_heads) { @@ -132,7 +137,9 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, // past_kv: BPNH or BNPH // new_kv: BLNH // present_kv: BTNH or BNTH, where T = P + L - const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b]; + + // prompt, token, and interactive decoding cases + const int past_seqlen = seqlens_k == nullptr ? 0 : seqlens_k[b] + 1 - new_seqlen; int out_offset = b * present_batch_stride + s * row_stride + n * present_head_stride + h; if (s < past_seqlen) { @@ -140,7 +147,7 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen, const int past_head_stride = is_bsnh ? 
H : past_buffer_seqlen * H; const int in_offset = b * past_batch_stride + s * row_stride + n * past_head_stride + h; present_kv[out_offset] = past_kv[in_offset]; - } else if (s < past_seqlen + new_seqlen) { + } else if (!past_only && s < past_seqlen + new_seqlen) { const int new_batch_stride = new_seqlen * num_heads * H; const int new_row_stride = num_heads * H; const int new_head_stride = H; @@ -160,13 +167,12 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter const int max_threads_per_block, const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; - const int* seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); - + const int* seqlens_k = parameters.is_first_prompt ? nullptr : reinterpret_cast(data.seqlens_k); AttentionQkvFormat past_kv_format = parameters.past_kv_format; assert(past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -180,6 +186,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKV<<>>(kv_sequence_length, past_sequence_length, @@ -187,6 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } else { int steps = (H * kv_num_heads + 255) / 256; @@ -200,6 +208,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_key), reinterpret_cast(data.present_key), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); ConcatNewToPastKVLarge<<>>(kv_sequence_length, past_sequence_length, @@ -209,6 +218,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameter reinterpret_cast(new_value), reinterpret_cast(data.present_value), seqlens_k, + past_only, past_kv_format == AttentionQkvFormat::Q_K_V_BSNH); } return CUDA_CALL(cudaGetLastError()); @@ -219,7 +229,7 @@ template __global__ void ConcatKVInPlace(const int max_seqlen, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { @@ -234,7 +244,7 @@ __global__ void ConcatKVInPlace(const int max_seqlen, const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? 
INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -253,7 +263,7 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int kv_num_heads, T* kv_buff, const T* new_kv, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, const bool is_past_kv_bnsh_format, const bool is_new_kv_bnsh_format) { // refers to kv buff; otherwise bnsh @@ -264,9 +274,10 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen, const int s = blockIdx.y; const int b = blockIdx.z; const int new_seqlen = gridDim.y; + const int past_seq_len = (total_seqlens_k != nullptr) ? (total_seqlens_k[b] - new_seqlen) - : (past_seqlens_k == nullptr ? 0 : past_seqlens_k[b]); + : (seqlens_k == nullptr ? 0 : (seqlens_k[b] + 1 - new_seqlen)); int out_offset = is_past_kv_bnsh_format ? INDEX_4D(kv_num_heads, max_seqlen, H, b, n, s + past_seq_len, h) @@ -286,15 +297,15 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const T* new_key, const T* new_value, T* present_key, T* present_value, - bool is_past_kv_bnsh_format, - bool is_new_kv_bnsh_format, + const bool is_past_kv_bnsh_format, + const bool is_new_kv_bnsh_format, cudaStream_t stream, const int max_threads_per_block) { static_assert(sizeof(T) == 2); @@ -307,14 +318,14 @@ Status LaunchConcatKVInPlace(int batch_size, ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); ConcatKVInPlace<<>>(max_sequence_length, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -327,7 +338,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_key), reinterpret_cast(new_key), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -336,7 +347,7 @@ Status LaunchConcatKVInPlace(int batch_size, kv_num_heads, reinterpret_cast(present_value), reinterpret_cast(new_value), - past_seqlens_k, + seqlens_k, total_seqlens_k, is_past_kv_bnsh_format, is_new_kv_bnsh_format); @@ -354,7 +365,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, cudaStream_t stream, const int max_threads_per_block) { const int max_sequence_length = parameters.seqlen_present_kv_cache; - const int* past_seqlens_k = parameters.is_prompt ? nullptr : reinterpret_cast(data.seqlens_k); + const int* seqlens_k = (parameters.is_first_prompt && !parameters.is_subsequent_prompt) ? 
nullptr + : reinterpret_cast(data.seqlens_k); assert(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BSNH || parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); @@ -364,8 +376,8 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters, parameters.kv_num_heads, parameters.head_size, max_sequence_length, - past_seqlens_k, - nullptr, // total_seqlens_k is not available + seqlens_k, + nullptr, // total_seqlens_k would be wrong to use here parameters.sequence_length, reinterpret_cast(new_key), reinterpret_cast(new_value), @@ -495,23 +507,33 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, seqlens_k_buff[threadIdx.x] = seqlens_k[threadIdx.x] + add_seqlen; } -// Convert Past to Total sequence length tensor -Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, - int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int /*threads_per_block*/) { - if (parameters.is_prompt) { - return Status::OK(); - } - const int batch_size = parameters.batch_size; - const int add_seqlen = is_total ? parameters.sequence_length : 0; - +// Calculate total sequence length from seqlens_k +Status LaunchGetSeqlensTotal(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int batch_size, cudaStream_t stream, + const int /*threads_per_block*/) { const dim3 grid(1, 1, 1); // TODO(aciddelgado): unlikely but could have a bigger batch_size than max_threads const dim3 block(batch_size, 1, 1); + PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, 1); + return CUDA_CALL(cudaGetLastError()); +} - // TODO(aciddelgado): small version - PastToTotalSeqlen<<>>(seqlens_k, seqlens_k_buff, add_seqlen); +// Currently, interactive decoding only works for batch_size 1 +__global__ void GetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + seqlens_k_buff[tid] = seqlens_k[tid] + 1 - sequence_length; + } +} +// Calculate past sequence length for each batch entry for flash attention kernel +Status LaunchGetSeqlensInteractive(const int32_t* seqlens_k, int32_t* seqlens_k_buff, + const int batch_size, const int sequence_length, cudaStream_t stream, + const int max_threads_per_block) { + const int threads = std::min(batch_size, max_threads_per_block); + const int blocks = (threads / max_threads_per_block) + 1; + GetSeqlensInteractive<<>>(seqlens_k, seqlens_k_buff, batch_size, + sequence_length); return CUDA_CALL(cudaGetLastError()); } @@ -576,7 +598,22 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsInteractive(const int32_t* seqlens_k, int64_t* position_ids, + const int seqlen, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + const int total_seqlen = seqlens_k[b] + 1; + const int past_seqlen = total_seqlen - seqlen; + if (past_seqlen + s < total_seqlen) { + position_ids[tid] = past_seqlen + s; + } else { + position_ids[tid] = 1; + } + } +} + __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; @@ -591,7 +628,6 @@ __global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* positio } } -// Kernel to convert 
seqlens_k to position_ids __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { @@ -601,12 +637,15 @@ __global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position // Convert seqlens_k to position_ids Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, - int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { + int64_t* position_ids, cudaStream_t stream, + const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + SeqlensToPosIdsInteractive<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -650,7 +689,12 @@ Status FlashAttention( } void* seqlens_k = reinterpret_cast(data.seqlens_k); - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensInteractive(reinterpret_cast(data.seqlens_k), + reinterpret_cast(data.seqlens_k_buff), batch_size, + sequence_length, stream, max_threads_per_block)); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); + } else if (parameters.is_first_prompt) { // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value // user should use seqlens_k to index into output to get new tokens if (batch_size <= parameters.zeros_count) { @@ -659,10 +703,12 @@ Status FlashAttention( // Launch kernel to create larger seqlen tensor when batch_size > 256 constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); - seqlens_k = data.seqlens_k_total; + repeat_seqlen<<>>(data.seqlens_k_buff, 0, batch_size); + seqlens_k = reinterpret_cast(data.seqlens_k_buff); } - } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + } + + if (!parameters.kv_share_buffer || parameters.is_first_prompt) { // copy past kv to present kv ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block, true)); } @@ -682,7 +728,7 @@ Status FlashAttention( reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv)); - // if (parameters.left_padding && parameters.is_prompt) { + // if (parameters.left_padding && parameters.is_first_prompt) { // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); // } @@ -766,15 +812,16 @@ Status EfficientAttention( key = reinterpret_cast(k_buffer); } - if (parameters.is_prompt) { + if (parameters.is_subsequent_prompt || !parameters.is_first_prompt) { + ORT_RETURN_IF_ERROR(LaunchGetSeqlensTotal(data.seqlens_k, data.seqlens_k_buff, batch_size, stream, 256)); + } else { // Launch kernel to copy seqlen constexpr int thr_per_blk = 256; int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, + repeat_seqlen<<>>(data.seqlens_k_buff, parameters.sequence_length, 
batch_size); - } else { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); } + int* seqlens_k = data.seqlens_k_buff; if (parameters.kv_share_buffer) { // Share buffer case @@ -815,7 +862,7 @@ Status EfficientAttention( } DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", data.seqlens_k_total, batch_size, 1); + DUMP_TENSOR("seqlens_k", seqlens_k, batch_size, 1); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -823,14 +870,14 @@ Status EfficientAttention( p.batch_size = batch_size; p.num_heads = num_heads; p.sequence_length = sequence_length; - p.kv_sequence_length = present_sequence_length; // TOTALLY UNNECESSARY IF WE HAVE SEQLENS_K, maybe remove + p.kv_sequence_length = present_sequence_length; // maybe remove p.max_sequence_length = present_sequence_length; p.qk_head_size = head_size; p.v_head_size = head_size; p.causal = true; p.scale = scale; p.softcap = parameters.softcap; - p.seqlen_k_ptr = data.seqlens_k_total; // Note: seqlens_k is total sequence length for efficient + p.seqlen_k_ptr = seqlens_k; // Note: seqlens_k is total sequence length for efficient p.seqstart_q_ptr = nullptr; p.seqstart_k_ptr = nullptr; p.query = query; @@ -912,7 +959,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const half* new_key, @@ -928,7 +975,7 @@ template Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, - const int* past_seqlens_k, + const int* seqlens_k, const int* total_seqlens_k, int new_seq_len, const BFloat16* new_key, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index e8dc69188b95f..8593ecede2bab 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -27,7 +27,7 @@ struct GroupQueryAttentionData { T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; T* out_accum = nullptr; - int* seqlens_k_total = nullptr; + int* seqlens_k_buff = nullptr; // Memory Efficient buffers T* fmha_buffer = nullptr; T* unpacked_qkv_buffer = nullptr; @@ -61,7 +61,7 @@ Status LaunchConcatKVInPlace(int batch_size, int kv_num_heads, int head_size, int max_sequence_length, // max sequence length of present_key or present_value. - const int* past_seqlens_k, // it is not used when total_seqlens_k is available. + const int* seqlens_k, // it is not used when total_seqlens_k is available. const int* total_seqlens_k, // optional, nullptr means it is not available. 
int new_seq_len, const T* new_key, diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu index 7a16eb38181aa..e644b7e903138 100644 --- a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -5,7 +5,7 @@ #include "core/providers/rocm/rocm_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/rocm/bert/group_query_attention.h" -#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/cpu/bert/group_query_attention_helper.h" #include "contrib_ops/rocm/bert/rotary_embedding_impl.h" #include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" @@ -115,7 +115,7 @@ Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int batch_size = parameters.batch_size; const int threads = max_threads_per_block; const int blocks = (batch_size * seqlen + threads - 1) / threads; - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); } else { SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); @@ -325,7 +325,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { // build present kv cache auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); - if (parameters.is_prompt) { + if (parameters.is_first_prompt) { // copy prompt kv to present kv ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); @@ -383,7 +383,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { return ret; } - if (parameters.is_prompt && is_unidirectional_) { + if (parameters.is_first_prompt && is_unidirectional_) { return mask_info::decode("t", sequence_length, kv_sequence_length); } @@ -496,7 +496,7 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { parameters.head_size, parameters.head_size, // v head size GetCkFmhaDataTypeString(), - !parameters.is_prompt, // true, // is_group_mode + !parameters.is_first_prompt, // true, // is_group_mode true, // is_v_rowmajor ? dim is fastest : seq is fastest mask.type, bias_type, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 5185205f1dde1..c706c6fc5ff5f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1049,6 +1049,8 @@ Supports different number of heads for q and kv for CPU and CUDA. Only supports causal and local attention. Supports rotary position embedding for CPU and CUDA. Supports packed input for CPU and CUDA. +Supports continuous decoding for batch_size == 1 for CPU and CUDA. + )DOC"; ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1110,12 +1112,12 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(5, "seqlens_k", - // For prompt, the value is number of tokens (excluding padding) - 1. - "1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.", + "1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).", "M") .Input(6, "total_sequence_length", - "Scalar tensor of total sequence length (past + new).", + "Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. 
Used for " + "checking inputs and determining prompt vs token generation case.", "M") .Input(7, "cos_cache", diff --git a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py index c04929a3b603e..46ab905977f48 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn_cuda.py +++ b/onnxruntime/test/python/transformers/test_flash_attn_cuda.py @@ -223,6 +223,7 @@ def create_group_query_attention_graph_prompt( rotary=False, rotary_interleaved=False, packed=False, + interactive=False, softcap=0.0, use_smooth_softmax=False, ): @@ -1224,7 +1225,7 @@ def parity_check_gqa_prompt( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1422,7 +1423,7 @@ def parity_check_gqa_prompt_no_buff( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1597,7 +1598,7 @@ def parity_check_gqa_past( config, causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1667,7 +1668,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1696,7 +1696,6 @@ def parity_check_gqa_past( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1730,6 +1729,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1783,15 +1784,14 @@ def parity_check_gqa_past( numpy.testing.assert_allclose( present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) def parity_check_gqa_past_no_buff( config, - causal=False, + causal=True, local=False, - past_format=Formats.BSNH, + past_format=Formats.BNSH, rotary=False, rotary_interleaved=False, packed=False, @@ -1864,7 +1864,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cuda").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1896,7 +1895,6 @@ def parity_check_gqa_past_no_buff( "b 1 (s h) d -> b s h d", s=config.sequence_length, ) - # q_ro = q k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) else: cos, sin = None, None @@ -1930,6 +1928,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1976,6 +1976,23 @@ def parity_check_gqa_past_no_buff( f" with {config}, causal={causal}, local={local}, 
past_format={past_format}," f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" ) + for b in range(config.batch_size): + numpy.testing.assert_allclose( + present_k[b, :, : (cache_seqlens + 1)[b]], + k_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) + numpy.testing.assert_allclose( + present_v[b, :, : (cache_seqlens + 1)[b]], + v_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + err_msg=err_msg, + ) numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) @@ -2229,6 +2246,86 @@ def gqa_past_flash_attention_test_cases(): ) +def gqa_interactive_one_batch_flash_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", + config, + local, + rotary, + rotary_interleaved, + packed, + ) + + +def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (32, 128), + (128, 2048), + (1235, 5000), + (40, 800), + (1, 256), + (2, 799), + (41, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for rotary, rotary_interleaved in rotary_options_for_current_os(): + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + yield ( + str(config) + f"{rotary}_{rotary_interleaved}_{packed}", + config, + rotary, + rotary_interleaved, + packed, + ) + + class TestGQA(unittest.TestCase): @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) def test_gqa_no_past_memory_efficient(self, _, config, rotary, rotary_interleaved, packed, softcap): @@ -2350,6 +2447,60 @@ def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interle use_smooth_softmax=True, ) + @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) + def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + if not has_flash_attention(): + return + print("------- FLASH ATTENTION (INTERACTIVE) -------") + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + 
past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + + @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) + def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): + if not has_memory_efficient(): + return + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") + + parity_check_gqa_past( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + past_format=Formats.BNSH, + rtol=5e-3, + atol=5e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index cc9d7ff51a5c6..dc21d4e4a5890 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -121,8 +121,12 @@ def rotate_tensor( else: x_rot = torch.cat((real, imag), dim=-1) else: - cos_x = cos[:, 0:seq_len, :, :] - sin_x = sin[:, 0:seq_len, :, :] + batch_size = x.shape[0] + cos_x = torch.zeros((batch_size, seq_len, 1, cos.shape[3]), device=x.device) + sin_x = torch.zeros((batch_size, seq_len, 1, sin.shape[3]), device=x.device) + for b in range(x.shape[0]): + cos_x[b] = cos[0, pos[b] : pos[b] + seq_len, :, :] + sin_x[b] = sin[0, pos[b] : pos[b] + seq_len, :, :] real = cos_x * x1 - sin_x * x2 imag = sin_x * x1 + cos_x * x2 if interleaved: @@ -716,7 +720,6 @@ def gqa_prompt_func( ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - # TODO: do we need io binding for cpu input? 
io_binding.bind_cpu_input("query", ort_inputs["query"]) io_binding.bind_input( "past_key", "cpu", 0, numpy.float32, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() @@ -788,6 +791,7 @@ def gqa_past_func( softcap=0.0, use_smooth_softmax=False, ): + assert seqlens_k is not None onnx_model_str = create_group_query_attention_graph_past( config, past_kv_format, @@ -819,12 +823,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -867,12 +871,12 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CPUExecutionProvider"]) io_binding = ort_session.io_binding() - if new_k is not None: + if new_k is not None and new_v is not None: ort_inputs["key"] = new_k.detach().cpu().numpy() ort_inputs["value"] = new_v.detach().cpu().numpy() io_binding.bind_cpu_input("key", ort_inputs["key"]) io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: + if cos is not None and sin is not None: ort_inputs["cos_cache"] = cos.detach().cpu().numpy() ort_inputs["sin_cache"] = sin.detach().cpu().numpy() io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) @@ -1518,7 +1522,6 @@ def parity_check_gqa_past( if past_format == Formats.BNSH: k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length - config.sequence_length + 1, @@ -1576,6 +1579,8 @@ def parity_check_gqa_past( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # ORT function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -1739,7 +1744,6 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) k_cache_ref = torch.cat((k_cache_ref, new_k), 1) v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - # cache_seqlens = torch.tensor([config.past_sequence_length], device="cpu").repeat(config.batch_size) cache_seqlens = torch.randint( 0, config.kv_sequence_length, @@ -1800,6 +1804,8 @@ def parity_check_gqa_past_no_buff( k_cache_ref = k_cache_ref.transpose(1, 2) v_cache_ref = v_cache_ref.transpose(1, 2) + cache_seqlens += config.sequence_length - 1 + # Flash function if packed: packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) @@ -2000,6 +2006,61 @@ def test_gqa_past(self): ) self.assertTrue(all_close) + def test_gqa_interactive_one_batch(self): + print("-------- TEST GQA INTERACTIVE ---------") + batches = [1] + seqs = ( + [(2, 128), (128, 129), (32, 128), (256, 2048)] + if pipeline_mode + else [ + (1, 128), + (1, 339), + (1, 1024), + (1, 5000), + (1, 800), + (1, 256), + (1, 799), + (1, 2048), + # (1, 128 * 512), + # (16, 128 * 512), + # (128, 128), + ] + ) + num_h = [(32, 8), (9, 3), (4, 4)] if 
pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] + h_sizes = [16, 64, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] + random.seed(69) + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for local in [False, True]: + for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]: + for packed in [False, True]: + config = Config(b, s, s2, -1, n, n2, h) + past_kv_format = Formats.BNSH + all_close = parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + all_close = parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + self.assertTrue(all_close) + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index 6a08d2101b100..5dbb9a277e45a 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -890,7 +890,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, ) yield config @@ -929,7 +929,7 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot dtype=dtype, is_packed_qkv=packed_qkv, do_rotary=do_rotary, - rotary_interleaved=sequence_length <= 128, + rotary_interleaved=do_rotary and sequence_length <= 128, max_cache_sequence_length=None if sequence_length >= 128 else 128, # test smaller kv cache buffer. ) yield config @@ -940,7 +940,6 @@ def get_test_cases(provider: str, has_past_kv: bool, comprehensive: bool, do_rot class TestSparseAttention(unittest.TestCase): - @unittest.skipUnless(has_cuda_support(), "cuda not available") def test_sparse_attention_cuda(self): major, minor = torch.cuda.get_device_capability() @@ -1056,7 +1055,7 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool): vert_stride=4, softmax_scale=None, do_rotary=do_rotary, - rotary_interleaved=(past_seq_len % 2 == 1), + rotary_interleaved=do_rotary and (past_seq_len % 2 == 1), device=device, is_packed_qkv=packed_qkv, max_rotary_sequence_length=None if past_seq_len >= 128 else 128, # test smaller rotary buffer. 
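A minimal standalone sketch of the seqlens_k arithmetic this patch uses throughout: seqlens_k holds total_sequence_length - 1 (see the updated bert_defs.cc description above), and the interactive-decoding kernels recover the past length as seqlens_k + 1 - new_sequence_length. The helper and the values below are illustrative only, not code from the repository.

```cpp
// Illustration of the seqlens_k convention; names here are made up.
#include <cassert>

// Mirrors the kernel arithmetic: past = (seqlens_k + 1) - new_sequence_length.
int PastSeqLen(int seqlens_k_b, int new_seq_len) {
  return seqlens_k_b + 1 - new_seq_len;
}

int main() {
  // Regular token generation: 1 new token on top of 7 cached tokens (total = 8).
  assert(PastSeqLen(/*seqlens_k=*/7, /*new_seq_len=*/1) == 7);
  // Continuous decoding (batch_size == 1): 3 new tokens on top of 5 cached tokens (total = 8).
  assert(PastSeqLen(/*seqlens_k=*/7, /*new_seq_len=*/3) == 5);
  return 0;
}
```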
From a89bddd5c224c045510d09537a95d32602e021cc Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Fri, 13 Sep 2024 14:55:08 -0700 Subject: [PATCH 3/9] Matmul_nbits kernel for mlas sqnbits to support Fp16 inputs (#21807) --- cmake/onnxruntime_mlas.cmake | 4 +- docs/OperatorKernels.md | 2 +- .../cpu/quantization/matmul_nbits.cc | 246 +++++++++++++----- .../cpu/quantization/matmul_nbits_impl.cc | 11 +- onnxruntime/core/mlas/inc/mlas.h | 36 ++- onnxruntime/core/mlas/lib/cast.cpp | 42 ++- onnxruntime/core/mlas/lib/mlasi.h | 11 +- onnxruntime/core/mlas/lib/platform.cpp | 4 + .../core/mlas/lib/sqnbitgemm_kernel_avx2.cpp | 45 ++++ .../core/providers/cpu/tensor/cast_op.cc | 2 +- .../test/contrib_ops/matmul_4bits_test.cc | 54 +++- 11 files changed, 341 insertions(+), 116 deletions(-) diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index b612b3ead4658..e35c83ba45952 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -580,10 +580,10 @@ message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mavxvnni") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") else() message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") endif() set(mlas_platform_srcs_avx512f ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d57394b3e7b97..121240e6e18f9 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -488,7 +488,7 @@ Do not modify directly.* |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**<br/> *in* B_shape:**T3**<br/> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
 |MatMulInteger16|*in* A:**T1**<br/> *in* B:**T2**<br/> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
 |MatMulIntegerToFloat|*in* A:**T1**<br/> *in* B:**T2**<br/> *in* a_scale:**T3**<br/> *in* b_scale:**T3**<br/> *in* a_zero_point:**T1**<br/> *in* b_zero_point:**T2**<br/> *in* bias:**T3**<br/> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
-|MatMulNBits|*in* A:**T1**<br/> *in* B:**T2**<br/> *in* scales:**T1**<br/> *in* zero_points:**T3**<br/> *in* g_idx:**T4**<br/> *in* bias:**T1**<br/> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(uint8)<br/> **T4** = tensor(int32)|
+|MatMulNBits|*in* A:**T1**<br/> *in* B:**T2**<br/> *in* scales:**T1**<br/> *in* zero_points:**T3**<br/> *in* g_idx:**T4**<br/> *in* bias:**T1**<br/> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(float), tensor(float16), tensor(uint8)<br/> **T4** = tensor(int32)|
 |MaxpoolWithMask|*in* X:**T**<br/> *in* M:**tensor(int32)**<br/> *out* Y:**T**|1+|**T** = tensor(float)|
 |MultiHeadAttention|*in* query:**T**<br/> *in* key:**T**<br/> *in* value:**T**<br/> *in* bias:**T**<br/> *in* key_padding_mask:**M**<br/> *in* attention_bias:**T**<br/> *in* past_key:**T**<br/> *in* past_value:**T**<br/> *out* output:**T**<br/> *out* present_key:**T**<br/> *out* present_value:**T**|1+|**T** = tensor(float)|
 |MurmurHash3|*in* X:**T1**<br/> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/>
**T2** = tensor(int32), tensor(uint32)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index bf43aca73ef3a..ccb779721d006 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -146,8 +146,15 @@ class MatMulNBits final : public OpKernel { bool all_constant_{false}; #endif // defined(ORT_NEURAL_SPEED) + + template + Status ComputeTyped(OpKernelContext* ctx) const; }; +bool IsATypeFloat16(const Tensor& tensor) { + return tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16; +} + Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { @@ -211,10 +218,10 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat #else // defined(ORT_NEURAL_SPEED) ORT_UNUSED_PARAMETER(prepacked_weights); const auto compute_type = static_cast(accuracy_level_); + if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { + return Status::OK(); + } if (input_idx == InputIndex::B) { - if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) { - return Status::OK(); - } packed_b_size_ = MlasSQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, compute_type); if (packed_b_size_ == 0) { return Status::OK(); @@ -226,8 +233,15 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat } else if (compute_type == CompInt8) { #ifdef MLAS_TARGET_AMD64_IX86 if (input_idx == InputIndex::scales && packed_b_ != nullptr) { - auto sptr = tensor.Data(); - MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + if (IsATypeFloat16(tensor)) { + auto sptr = tensor.Data(); + std::vector scales_v(static_cast(tensor.Shape().Size())); + MlasConvertHalfToFloatBuffer(sptr, &scales_v[0], scales_v.size()); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), &scales_v[0], has_zp_input_, nullptr, nullptr); + } else { + auto sptr = tensor.Data(); + MlasSQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type, nullptr, packed_b_.get(), sptr, has_zp_input_, nullptr, nullptr); + } is_packed = false; } else if (input_idx == InputIndex::zero_points && packed_b_ != nullptr) { auto zptr = tensor.Data(); @@ -274,9 +288,20 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& prep } Status MatMulNBits::Compute(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(InputIndex::A); + + if (IsATypeFloat16(*a)) { + return ComputeTyped(ctx); + } else { + return ComputeTyped(ctx); + } +} + +template +Status MatMulNBits::ComputeTyped(OpKernelContext* ctx) const { concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); const Tensor* a = ctx->Input(InputIndex::A); - const auto* a_data = a->Data(); + const auto* a_data = a->Data(); TensorShape b_shape({static_cast(N_), static_cast(K_)}); MatMulComputeHelper helper; @@ -289,7 +314,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { return Status::OK(); } - auto* y_data = y->MutableData(); + auto* y_data = y->MutableData(); const size_t batch_count = helper.OutputOffsets().size(); const size_t M = static_cast(helper.M()); @@ -297,9 +322,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(false); - 
const bool has_single_b_matrix = std::all_of(helper.RightOffsets().begin(), - helper.RightOffsets().end(), - [](size_t offset) { return offset == 0; }); + // clang-format off + const bool has_single_b_matrix = std::all_of( + helper.RightOffsets().begin(), + helper.RightOffsets().end(), + [](size_t offset) { return offset == 0; }); + // clang-format on #if defined(ORT_NEURAL_SPEED) @@ -336,9 +364,9 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* bias = ctx->Input(InputIndex::bias); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); - const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); + const auto* bias_data = bias == nullptr ? nullptr : bias->Data(); IAllocatorUniquePtr workspace{}; const size_t workspace_size = MlasSQNBitGemmBatchWorkspaceSize( @@ -349,26 +377,64 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { workspace = IAllocator::MakeUniquePtr(allocator, workspace_size); } - InlinedVector data(batch_count); - for (size_t i = 0; i < batch_count; ++i) { - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; + if constexpr (std::is_same::value) { + InlinedVector data(batch_count); + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); + + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); + + auto tmp_scales_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(scales->Shape().Size())); + MlasConvertHalfToFloatBuffer(scales_data, tmp_scales_data_ptr.get(), static_cast(scales->Shape().Size())); + + std::vector bias_data_v; + if (bias_data != nullptr) { + bias_data_v.resize((const unsigned int)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias_data, &bias_data_v[0], bias_data_v.size()); + } + std::vector C_v((const unsigned int)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; #ifdef MLAS_TARGET_AMD64_IX86 - if (compute_type == CompInt8) { - data[i].QuantBDataWorkspace = packed_b_.get(); + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } +#endif + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = tmp_scales_data_ptr.get(); + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data != nullptr ? 
&bias_data_v[0] : nullptr; + data[i].C = &C_v[0] + helper.OutputOffsets()[i]; + data[i].ldc = N; } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + MlasConvertFloatToHalfBuffer(&C_v[0], y_data, C_v.size()); + return Status::OK(); + } else { + InlinedVector data(batch_count); + for (size_t i = 0; i < batch_count; ++i) { + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; +#ifdef MLAS_TARGET_AMD64_IX86 + if (compute_type == CompInt8) { + data[i].QuantBDataWorkspace = packed_b_.get(); + } #endif - data[i].PackedQuantBData = static_cast(packed_b_.get()); - data[i].QuantBScale = scales_data; - data[i].QuantBZeroPoint = zero_points_data; - data[i].Bias = bias_data; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; + data[i].PackedQuantBData = static_cast(packed_b_.get()); + data[i].QuantBScale = scales_data; + data[i].QuantBZeroPoint = zero_points_data; + data[i].Bias = bias_data; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + } + MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), + thread_pool); + return Status::OK(); } - MlasSQNBitGemmBatch(M, N, K, batch_count, nbits_, block_size_, compute_type, data.data(), workspace.get(), - thread_pool); - - return Status::OK(); } } @@ -380,7 +446,17 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); - const auto* scales_data = scales->Data(); + const auto* scales_data = scales->Data(); + const float* scales_data_; + std::vector scales_data_v; + if constexpr (std::is_same::value) { + scales_data_v.resize((const unsigned int)scales->Shape().Size()); + MlasConvertHalfToFloatBuffer(scales_data, &scales_data_v[0], scales_data_v.size()); + scales_data_ = &scales_data_v[0]; + } else { + scales_data_ = scales_data; + } + const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw(); const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data(); @@ -391,12 +467,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&allocator)); auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); - if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { + if ((reorder_idx_data == nullptr) && (!zero_points || !zero_points->IsDataType())) { // dequantize b, only 4b quantization is supported for now MlasDequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -406,12 +482,12 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { } else { ORT_ENFORCE(column_wise_quant_, "Row-wise quantization is not supported for now"); // !!!!!!!!!!!!!! naive implementation, need to be optimized !!!!!!!!!!!!!! 
- if ((zero_points && zero_points->IsDataType())) { - DequantizeBlockwise( + if ((zero_points && zero_points->IsDataType())) { + DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales - static_cast(zero_points_data), // quantization zero points + scales_data_, // quantization scales + static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size column_wise_quant_, // columnwise quantization or row-wise @@ -422,7 +498,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { DequantizeBlockwise( tmp_b_data_ptr.get(), // dequantized output b_data, // quantized input - scales_data, // quantization scales + scales_data_, // quantization scales static_cast(zero_points_data), // quantization zero points reorder_idx_data, static_cast(block_size_), // quantization block size @@ -436,40 +512,80 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { auto tm_b_data_ptr_trans = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); MlasTranspose(tmp_b_data_ptr.get(), tm_b_data_ptr_trans.get(), N_, K_); #endif + if constexpr (std::is_same::value) { + std::vector data(batch_count); - std::vector data(batch_count); - for (size_t i = 0; i < batch_count; i++) { - data[i].BIsPacked = false; - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = 1.f; - data[i].beta = 0.0f; - } + auto tmp_a_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(a->Shape().Size())); + MlasConvertHalfToFloatBuffer(a_data, tmp_a_data_ptr.get(), static_cast(a->Shape().Size())); - // if there is a bias input, copy bias values into C and set beta to 1.0f - if (const Tensor* bias = ctx->Input(InputIndex::bias); - bias != nullptr) { - gsl::span bias_span = bias->DataAsSpan(); - for (size_t i = 0; i < batch_count; ++i) { - float* C_row = data[i].C; - const size_t ldc = data[i].ldc; - for (size_t m = 0; m < M; ++m) { - memcpy(C_row, bias_span.data(), bias_span.size_bytes()); - C_row += ldc; + auto tmp_c_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(y->Shape().Size())); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = tmp_a_data_ptr.get() + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = tmp_c_ptr.get() + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + auto tmp_bias_data_ptr = IAllocator::MakeUniquePtr(allocator, (size_t)(bias->Shape().Size())); + MlasConvertHalfToFloatBuffer(bias->Data(), tmp_bias_data_ptr.get(), static_cast(bias->Shape().Size())); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + std::copy(tmp_bias_data_ptr.get(), tmp_bias_data_ptr.get() + bias->Shape().Size(), C_row); + C_row += ldc; + } + data[i].beta = 1.0f; } + } - data[i].beta = 1.0f; + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + MlasConvertFloatToHalfBuffer(tmp_c_ptr.get(), y_data, static_cast(y->Shape().Size())); + return Status::OK(); + } else 
{ + std::vector data(batch_count); + for (size_t i = 0; i < batch_count; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; } - } - MlasGemmBatch(CblasNoTrans, CblasTrans, - M, N, K, data.data(), batch_count, thread_pool); + // if there is a bias input, copy bias values into C and set beta to 1.0f + if (const Tensor* bias = ctx->Input(InputIndex::bias); + bias != nullptr) { + gsl::span bias_span = bias->DataAsSpan(); + for (size_t i = 0; i < batch_count; ++i) { + float* C_row = data[i].C; + const size_t ldc = data[i].ldc; + for (size_t m = 0; m < M; ++m) { + memcpy(C_row, bias_span.data(), bias_span.size_bytes()); + C_row += ldc; + } - return Status::OK(); + data[i].beta = 1.0f; + } + } + + MlasGemmBatch(CblasNoTrans, CblasTrans, + M, N, K, data.data(), batch_count, thread_pool); + + return Status::OK(); + } } ONNX_OPERATOR_KERNEL_EX( @@ -478,9 +594,9 @@ ONNX_OPERATOR_KERNEL_EX( 1, kCpuExecutionProvider, KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .TypeConstraint("T3", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) .TypeConstraint("T4", DataTypeImpl::GetTensorType()), MatMulNBits); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index b28f3758f89b5..6a19a741c3028 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -54,12 +54,12 @@ void Dequantize4BitsKernelReOrder( T scale = *(scale_data + n_idx * scales_shape_x + rid); float zp_f = 8; if (zero_points) { - if constexpr (std::is_same_v) { - zp_f = *(zero_points + n_idx * scales_shape_x + rid); - } else { + if constexpr (std::is_same_v) { uint8_t zp = 8; zp = zero_points[n_idx * zero_point_shape_x + rid / 2]; zp = (rid & 0x01) ? (zp >> 4) : (zp & 0x0f); + } else { + zp_f = *(zero_points + static_cast(n_idx) * static_cast(scales_shape_x) + static_cast(rid)); } } @@ -112,5 +112,10 @@ template void DequantizeBlockwise( const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); +template void DequantizeBlockwise( + float* output, const uint8_t* quant_data, const float* scales_data, + const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, + bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); + } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index 8b3156d77e57c..28ae64c4d5b3e 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -20,6 +20,7 @@ Module Name: #include #include #include +#include // // Define the calling convention for Windows targets. @@ -1025,18 +1026,6 @@ MlasComputeTanh( size_t N ); -// -// Half-precision floating-point routines. 
-// - -void -MLASCALL -MlasConvertHalfToFloatBuffer( - const unsigned short* Source, - float* Destination, - size_t Count -); - // // Transpose routines. // @@ -1426,7 +1415,27 @@ using MLAS_FP16 = onnxruntime::MLFloat16; constexpr size_t FP16_SIZE = sizeof(uint16_t); -/** +// +// Half-precision floating-point routines. +// + +void +MLASCALL +MlasConvertHalfToFloatBuffer( + const MLAS_FP16* Source, + float* Destination, + size_t Count +); + +void +MLASCALL +MlasConvertFloatToHalfBuffer( +const float* Source, +MLAS_FP16* Destination, +size_t Count +); + + /** * @brief Whether current CPU supports FP16 acceleration. */ bool MLASCALL @@ -1787,6 +1796,7 @@ MlasTranspose( M, N); } + #ifdef MLAS_F16VEC_INTRINSICS_SUPPORTED /** * @brief Max Pooling for fp16 NHWC diff --git a/onnxruntime/core/mlas/lib/cast.cpp b/onnxruntime/core/mlas/lib/cast.cpp index 24af4064bbd9b..a6138e29bd796 100644 --- a/onnxruntime/core/mlas/lib/cast.cpp +++ b/onnxruntime/core/mlas/lib/cast.cpp @@ -23,37 +23,35 @@ union fp32_bits { void MLASCALL MlasConvertHalfToFloatBuffer( - const unsigned short* Source, + const MLAS_FP16* Source, float* Destination, size_t Count ) { - if (GetMlasPlatform().CastF16ToF32Kernel == nullptr) { - // If there is no kernel use the reference implementation, adapted from mlas_float16.h. - constexpr fp32_bits magic = {113 << 23}; - constexpr uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift + for (size_t i = 0; i < Count; ++i) { + Destination[i] = Source[i].ToFloat(); + } + } else { + // If the kernel is available, use it to perform the conversion. + GetMlasPlatform().CastF16ToF32Kernel(reinterpret_cast(Source), Destination, Count); + } +} +void +MLASCALL +MlasConvertFloatToHalfBuffer( + const float* Source, + MLAS_FP16* Destination, + size_t Count +) +{ + if (GetMlasPlatform().CastF32ToF16Kernel == nullptr) { for (size_t i = 0; i < Count; ++i) { - fp32_bits o; - o.u = (Source[i] & 0x7fff) << 13; // exponent/mantissa bits - uint32_t exp = shifted_exp & o.u; // just the exponent - o.u += (127 - 15) << 23; // exponent adjust - - // handle exponent special cases - if (exp == shifted_exp) { // Inf/NaN? - o.u += (128 - 16) << 23; // extra exp adjust - } else if (exp == 0) { // Zero/Denormal? - o.u += 1 << 23; // extra exp adjust - o.f -= magic.f; // renormalize - } - - o.u |= (Source[i] & 0x8000) << 16; // sign bit - Destination[i] = o.f; + Destination[i] = MLAS_FP16(Source[i]); } - } else { // If the kernel is available, use it to perform the conversion. 
- GetMlasPlatform().CastF16ToF32Kernel(Source, Destination, Count); + GetMlasPlatform().CastF32ToF16Kernel(Source, reinterpret_cast(Destination), Count); } } diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 6f5db766b7def..8e8f46b8a102e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -610,13 +610,19 @@ void size_t N ); -typedef +typedef void(MLASCALL MLAS_CAST_F16_TO_F32_KERNEL)( const unsigned short* Source, float* Destination, size_t Count ); +typedef void(MLASCALL MLAS_CAST_F32_TO_F16_KERNEL)( + const float* Source, + unsigned short* Destination, + size_t Count +); + typedef void (MLASCALL MLAS_QLINEAR_BINARY_OP_S8_KERNEL)( @@ -880,6 +886,8 @@ extern "C" { #if defined(MLAS_TARGET_AMD64) MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelSse; MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx; + MLAS_CAST_F16_TO_F32_KERNEL MlasCastF16ToF32KernelAvx2; + MLAS_CAST_F32_TO_F16_KERNEL MlasCastF32ToF16KernelAvx2; #endif } @@ -1165,6 +1173,7 @@ struct MLAS_PLATFORM { const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; + MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; }; inline diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 4cd7faaa9e6ff..2b4d99800c546 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -245,6 +245,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernel; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernel; this->CastF16ToF32Kernel = nullptr; + this->CastF32ToF16Kernel = nullptr; #if defined(MLAS_TARGET_AMD64_IX86) @@ -387,6 +388,9 @@ Return Value: this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; + this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + // // Check if the processor supports Hybrid core architecture. 
diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp index 55d86bb9cc18e..baaa4ba1a3b1f 100644 --- a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp @@ -29,6 +29,51 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +void +MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) +{ + size_t i = 0; + + // Process 16 elements at a time using AVX2 + for (; i + 15 < size; i += 16) { + // Load 16 FP16 values into an AVX2 register + __m256i fp16_values = _mm256_loadu_si256(reinterpret_cast(src_fp16 + i)); + + // Convert FP16 values to FP32 + __m256 fp32_values1 = _mm256_cvtph_ps(_mm256_castsi256_si128(fp16_values)); + __m256 fp32_values2 = _mm256_cvtph_ps(_mm256_extracti128_si256(fp16_values, 1)); + + // Store the converted FP32 values into the output vector + _mm256_storeu_ps(dst_fp32 + i, fp32_values1); + _mm256_storeu_ps(dst_fp32 + i + 8, fp32_values2); + } + + // Process any remaining elements + const MLAS_FP16* fp16 = reinterpret_cast(src_fp16); + for (; i < size; ++i) { + dst_fp32[i] = fp16[i].ToFloat(); + } +} + +void +MlasCastF32ToF16KernelAvx2(const float* src_fp32, unsigned short* dst_fp16, size_t size) +{ + size_t i = 0; + + // Process 8 elements at a time using AVX2 + for (; i + 8 <= size; i += 8) { + __m256 fp32_chunk = _mm256_loadu_ps(&src_fp32[i]); + __m128i fp16_chunk = _mm256_cvtps_ph(fp32_chunk, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(&dst_fp16[i]), fp16_chunk); + } + + // Process any remaining elements + for (; i < size; ++i) { + MLAS_FP16 fp16(src_fp32[i]); + dst_fp16[i] = fp16.val; + } +} + MLAS_FORCEINLINE __m256 load_float_n_avx2(const float* data, int n) diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index f2aaa75cadd8d..35f3b12aeba35 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -258,7 +258,7 @@ struct TensorCaster { auto out_data = out.MutableData(); auto in_data = in.Data(); const size_t shape_size = narrow(shape.Size()); - MlasConvertHalfToFloatBuffer(&in_data[0].val, out_data, shape_size); + MlasConvertHalfToFloatBuffer(in_data, out_data, shape_size); } }; diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index 548f24e8ac69e..fa7c6bce7c23e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -262,8 +262,8 @@ void RunTest(const TestOptions& opts, } // namespace -TEST(MatMulNBits, Float32) { - // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); +template +void TestMatMulNBitsTyped() { for (auto M : {1, 2, 100}) { for (auto N : {/*2560, */ 1, 2, 32, 288}) { for (auto K : {/*2560, */ 16, 32, 64, 128, 256, 1024, 93, 1234}) { @@ -276,30 +276,53 @@ TEST(MatMulNBits, Float32) { if (base_opts.accuracy_level == 4) { base_opts.output_abs_error = 0.1f; + } else { + if constexpr (std::is_same::value) { + base_opts.output_abs_error = 0.01f; + } } { TestOptions opts = base_opts; - RunTest(opts); + RunTest(opts); } { TestOptions opts = base_opts; opts.has_zero_point = true; - RunTest(opts); + RunTest(opts); } #if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) { TestOptions opts = base_opts; 
opts.has_g_idx = true; - RunTest(opts); + RunTest(opts); + } + + { + TestOptions opts = base_opts; + opts.has_g_idx = true; + opts.has_bias = true; + if constexpr (std::is_same::value) { + if (opts.accuracy_level == 0 || opts.accuracy_level == 1) { + // CI failure (not able to repro on either local machines): + // M:100, N:288, K:1234, block_size:16, accuracy_level:0, has_zero_point:0, zp_is_4bit:1, has_g_idx:1, has_bias:1 + // The difference between cur_expected[i] and cur_actual[i] is 1.0401010513305664e-05, which exceeds tolerance, + // tolerance evaluates to 1.006456386676291e-05. + opts.output_abs_error = 0.0001f; + } + } + // only enabled for CPU EP for now + std::vector> explicit_eps; + explicit_eps.emplace_back(DefaultCpuExecutionProvider()); + RunTest(opts, std::move(explicit_eps)); } { TestOptions opts = base_opts; opts.has_zero_point = true, opts.zp_is_4bit = false; - RunTest(opts); + RunTest(opts); } #endif // !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) @@ -311,7 +334,7 @@ TEST(MatMulNBits, Float32) { std::vector> explicit_eps; explicit_eps.emplace_back(DefaultCpuExecutionProvider()); - RunTest(opts, std::move(explicit_eps)); + RunTest(opts, std::move(explicit_eps)); } } } @@ -320,6 +343,21 @@ TEST(MatMulNBits, Float32) { } } +TEST(MatMulNBits, Float32) { + // onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling("profile.json"); + TestMatMulNBitsTyped(); +} + +#ifdef MLAS_TARGET_AMD64_IX86 +#if !defined(ORT_NEURAL_SPEED) && !defined(USE_DML) +// Actual and expected difference is over 0.01 with DmlExecutionProvider. +// Skip the tests instead of raising the tolerance to make is pass. +TEST(MatMulNBits, Float16) { + TestMatMulNBitsTyped(); +} +#endif +#endif + #if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_DML) namespace { @@ -367,7 +405,7 @@ void RunTest(int64_t M, int64_t N, int64_t K, int64_t block_size, int64_t accura } } // namespace -TEST(MatMulNBits, Float16) { +TEST(MatMulNBits, Float16Cuda) { #if defined(USE_CUDA) || defined(USE_ROCM) auto has_gidx_options = {true, false}; #else From c63dd0234b4e0236b24fabdca005bbeb75ff4eb9 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 14 Sep 2024 12:36:20 +0800 Subject: [PATCH 4/9] [WebNN EP] Use opSupportLimits to dynamically check data type support (#22025) - Remove hard code data type checks and use WebNN's opSupportLimits instead - Add HasSupportedOutputsImpl for output data type validation - Get preferred layout info from opSupportLimits - Move Not op to logical_op_builder.cc because it should be there. This avoid the inconsistent input names in `unary_op_builder.cc`. 
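A minimal standalone sketch of the data type check introduced below: the ONNX element type is mapped to a WebNN type string and looked up in the dataTypes array exposed by opSupportLimits(). It assumes an Emscripten build where the limits object was obtained from MLContext.opSupportLimits() in JavaScript; only a few mapping entries are shown and all names are illustrative rather than the EP's actual helpers.

```cpp
// Illustrative only; `supported_data_types` is e.g. limits["gemm"]["a"]["dataTypes"].
#include <emscripten/val.h>
#include <string>
#include <unordered_map>

bool IsOnnxTypeSupported(int32_t onnx_elem_type, const emscripten::val& supported_data_types) {
  // Subset of the ONNX -> WebNN data type mapping (TensorProto enum values).
  static const std::unordered_map<int32_t, std::string> onnx_to_webnn = {
      {1, "float32"},   // FLOAT
      {10, "float16"},  // FLOAT16
      {6, "int32"},     // INT32
      {7, "int64"},     // INT64
  };
  const auto it = onnx_to_webnn.find(onnx_elem_type);
  if (it == onnx_to_webnn.end()) {
    return false;
  }
  // Ask the JS array whether it contains the WebNN data type string.
  return supported_data_types.call<emscripten::val>("includes", emscripten::val(it->second)).as<bool>();
}
```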
--- .../core/providers/webnn/builders/helper.cc | 61 ++++++++++++++++--- .../core/providers/webnn/builders/helper.h | 43 +++++++++---- .../builders/impl/activation_op_builder.cc | 40 ------------ .../builders/impl/argmax_min_op_builder.cc | 27 -------- .../webnn/builders/impl/base_op_builder.cc | 52 ++++++++-------- .../webnn/builders/impl/base_op_builder.h | 9 ++- .../webnn/builders/impl/binary_op_builder.cc | 36 +++-------- .../webnn/builders/impl/cast_op_builder.cc | 32 +++++----- .../webnn/builders/impl/clip_op_builder.cc | 29 --------- .../webnn/builders/impl/concat_op_builder.cc | 28 +++++++++ .../webnn/builders/impl/conv_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gather_op_builder.cc | 26 +++----- .../webnn/builders/impl/gemm_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gru_op_builder.cc | 40 ++++-------- .../webnn/builders/impl/logical_op_builder.cc | 42 +++++++------ .../webnn/builders/impl/max_min_op_builder.cc | 29 ++++----- .../builders/impl/normalization_op_builder.cc | 35 ++++------- .../webnn/builders/impl/pad_op_builder.cc | 27 -------- .../builders/impl/reduction_op_builder.cc | 52 ---------------- .../webnn/builders/impl/resize_op_builder.cc | 26 -------- .../webnn/builders/impl/shape_op_builder.cc | 27 -------- .../webnn/builders/impl/slice_op_builder.cc | 26 -------- .../webnn/builders/impl/softmax_op_builder.cc | 26 -------- .../webnn/builders/impl/ternary_op_builder.cc | 23 ++----- .../builders/impl/transpose_op_builder.cc | 27 -------- .../webnn/builders/impl/unary_op_builder.cc | 43 ------------- .../providers/webnn/builders/model_builder.cc | 7 ++- .../providers/webnn/builders/model_builder.h | 5 +- .../providers/webnn/builders/op_builder.h | 3 +- .../webnn/builders/op_builder_factory.cc | 2 +- .../webnn/webnn_execution_provider.cc | 22 ++++--- .../webnn/webnn_execution_provider.h | 1 + 32 files changed, 281 insertions(+), 635 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d3c1d06818db2..c4a633fcc92bb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector& shape, const loggin return true; } -bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, - const WebnnDeviceType device_type, const logging::Logger& logger) { +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { const auto& op_builders = GetOpBuilders(); if (Contains(op_builders, node.OpType())) { const auto* op_builder = op_builders.at(node.OpType()); - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger); } else { return false; } @@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -105,7 +106,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v // Firstly check if platform supports the WebNN op. 
if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; - supported = IsNodeSupported(*node, graph_viewer, device_type, logger); + supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() @@ -130,10 +131,54 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v return supported_node_groups; } -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types) { - return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != - supported_data_types.end(); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger) { + for (size_t i = 1; i < input_types.size(); i++) { + if (input_types[0] != input_types[i]) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same, but [" + << input_types[0] << "] does not match " + << input_types[i] << "]."; + return false; + } + } + return true; +} + +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { + auto it = onnx_to_webnn_data_type_map.find(static_cast(onnx_data_type)); + if (it == onnx_to_webnn_data_type_map.end()) + return false; + + std::string webnn_data_type = it->second; + + // Check if WebNN supports the data type. + emscripten::val is_supported = webnn_supported_data_types.call("includes", + emscripten::val(webnn_data_type)); + return is_supported.as(); +} + +// Check if the input or output data type of ONNX node is supported by the WebNN operator. +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + std::string webnn_op_type; + if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) + return false; + + if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type + << "] " << onnx_input_output_name + << " type: [" << onnx_data_type + << "] is not supported for now"; + return false; + } + + return true; } bool GetBidirectionalBroadcastShape(std::vector& shape_a, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index b51092619db22..257fcff9ef50c 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger); static const InlinedHashMap op_map = { {"Abs", "abs"}, @@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn return true; } -static const std::unordered_set webnn_supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - 
ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, +inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) { + auto it = op_map.find(op_type); + // Returns false if the op_type is not listed in the op_map. + if (it == op_map.end()) { + return false; + } + webnn_op_type = it->second; + return true; +} + +static const InlinedHashMap onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, }; -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger); +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 626aaf5c71b74..781ddcb896155 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi return true; } -bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - // WebNN relu op supports float32, float16, int32, int8 input data types. - if (op_type == "Relu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend does not support int32 data type for relu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { // Others only support float32 and float16. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 05f3a742a3775..d61ae1a1f6be7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index fa535889299ea..8da255a288f17 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { ORT_RETURN_IF_NOT( - IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger), - "Unsupported operator ", - node.OpType()); + IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), + model_builder.GetOpSupportLimits(), logger), + "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; @@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& // Operator support related. bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, device_type, logger)) + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, wnn_limits, logger)) + return false; + + if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; // We do not support external initializers for now. @@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d } } - // WebNN CPU backend (TFLite) will enable float16 input data type soon, - // temporarily fallback float16 input data type for WebNN CPU. 
- if (device_type == WebnnDeviceType::CPU) { - const auto& input = *node.InputDefs()[0]; - - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) - return false; - } - - return HasSupportedInputsImpl(node, device_type, logger); + return HasSupportedInputsImpl(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, - const WebnnDeviceType /* device_type */, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. const auto& input = *node.InputDefs()[0]; - + const auto& op_type = node.OpType(); int32_t input_type; if (!GetType(input, input_type, logger)) return false; - if (!IsSupportedDataType(input_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); +} + +bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + // We only check the type of output 0 by default, specific op builder can override this. + const auto& output = *node.OutputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) return false; - } - return true; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 85e38b668cee4..584455f62cb4e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related. public: bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; protected: virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, @@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const; // ONNX Runtime only *guarantees* support for models stamped // with opset version 7 or above for opset domain 'ai.onnx'. 
@@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 555de68cd60fe..af82a01b14de5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice !GetType(*input_defs[1], input1_type, logger)) return false; - std::unordered_set supported_data_types; - // WebNN prelu op only supports float32, float16, int32, int8 input data types. - if (op_type == "Prelu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend doesn't support int32 for prelu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { - supported_data_types = webnn_supported_data_types; - } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; - } - - return true; + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? 
"X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a08e1681a8464..3c4fc822f3d01 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. +bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input_type; -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType device_type, - const logging::Logger& logger) const { - NodeAttrHelper helper(node); - // Check cast output type. - const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); - - // WebNN CPU backend doesn't support casting to uint64 data type. - if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) { - LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend."; + if (!GetType(*input_defs[0], input_type, logger)) return false; - } - if (!IsSupportedDataType(to_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << "."; + + if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger)) return false; - } - return true; + NodeAttrHelper helper(node); + // Check cast to type. + const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); + return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger); } void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index b5c3206072d50..374143c886849 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. 
@@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, }; } -bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index dedc76b80e978..48dd6f3beb020 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. 
@@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + + if (!GetType(*input_defs[0], input0_type, logger)) + return false; + + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } + + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); +} + void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 76a8a178678df..35498c2e9b8b7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Conv" || op_type == "ConvTranspose") { - // WebNN conv2d and convTranspose2d only support float32 and float16 input data types. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "ConvInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 23233539d34c7..ae9fe3e3f3bd1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t indices_type; + if (!GetType(input, input_type, logger) || + !GetType(indices, indices_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for gather. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index bd452b118fe3e..30e024792ed42 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -233,35 +233,18 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Gemm" || op_type == "MatMul") { - // WebNN gemm and matmul only support float32 and float16 input data types. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "MatMulInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 23cc7f1b11459..c92fe7366d494 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -208,37 +208,21 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp return false; } - std::unordered_set supported_data_types; - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend only support float32 input data type. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - }; - } else if (device_type == WebnnDeviceType::GPU) { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; + InlinedVector input_types = {input0_type, input1_type, input2_type}; + if (has_input3) { + input_types.push_back(input3_type); } - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input4) { + input_types.push_back(input4_type); } - - if (input0_type != input1_type || - input0_type != input2_type || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type) || - (has_input5 && input0_type != input5_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input5) { + input_types.push_back(input5_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 23f3a938fee5e..ea7f70b4598e6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder { Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { + const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input1 = emscripten::val::undefined(); + if (input_defs.size() > 1) { + input1 = model_builder.GetOperand(input_defs[1]->Name()); + } + emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); @@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); + } else if (op_type == "Not") { + output = model_builder.GetBuilder().call("logicalNot", input0, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const 
InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2) { + if (input_defs.size() < 2 && op_type != "Not") { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " << input_defs.size(); return false; @@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) - return false; - - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + if (!GetType(*input_defs[0], input0_type, logger)) return false; - } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + if (op_type != "Not") { + if (!GetType(*input_defs[1], input1_type, logger)) + return false; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + std::string onnx_input_name = op_type == "Not" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { @@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& "GreaterOrEqual", "Less", "LessOrEqual", + "Not", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 5d88afda7b6a7..e111ca412c6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. 
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; - int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) + if (!GetType(*input_defs[0], input0_type, logger)) return false; - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; - } + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d068baf35e72..a3c6b8fdcea9b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn return false; } - // WebNN batchNormalization, instanceNormalization, layerNormalization - // only support float32 and float16 input data types. 
- std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 071155a2fb372..d8373a45e4423 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } // namespace webnn -bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for pad. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 3e6d4d9820e9a..93ad933d71c34 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return true; } -bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "ReduceL1" || op_type == "ReduceProd" || - op_type == "ReduceSum" || op_type == "ReduceSumSquare") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, - }; - - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend doesn't support uint32 and uint64 for reduceL1, - // reduceProd, reduceSum and reduceSumSquare. - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" || - op_type == "ReduceLogSumExp" || op_type == "ReduceMean") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else { // ReduceMax and ReduceMin - supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 2218c858951d3..9dc79f4f52f46 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Helper functions @@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN resample2d op only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 0eb7dafdffe4d..6b56d2c740f40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. 
- -bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (!IsSupportedDataType(output_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Output type: [" << output_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index bef13841c646c..3f0d633ac888b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 input data type for slice. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 798cfabae65db..b1b737b114998 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -63,30 +61,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN softmax only supports float32 and float16 input data types. 
- std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 2ed8330bf25be..4b6cf312074ba 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic !GetType(*input_defs[2], input2_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 X, Y data type for where. - if (device_type == WebnnDeviceType::CPU && op_type == "Where") { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } // ONNX's condition data type is bool which is same as WebNN. // Only need to check X, Y data types. 
- if (!IsSupportedDataType(input1_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input1_type - << "] is not supported for now"; - return false; - } - - if (input1_type != input2_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input X, Y data types should be the same."; + std::array input_types{input1_type, input2_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 03c88ad9db88a..3a5e39f7f7a56 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for transpose. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 061404c8a9ce0..8e64e98445f03 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Add operator related. 
@@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { output = model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { @@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "Identity") { - supported_data_types = webnn_supported_data_types; - } else if (op_type == "Abs" || op_type == "Neg") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - } else if (op_type == "Not") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - }; - } else { // Others only support float32, float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; @@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Identity", "Log", "Neg", - "Not", "Reciprocal", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 44bec1fb6fd48..b58bf8233692e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -21,12 +21,13 @@ namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type) + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), preferred_layout_(preferred_layout), - wnn_device_type_(wnn_device_type) { + wnn_device_type_(wnn_device_type), + wnn_limits_(wnn_limits) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. 
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); @@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("dimensions", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, webnn_supported_data_types)) { + if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2d686070cdcc1..256337baeba7e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -23,7 +23,7 @@ class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type); + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -35,6 +35,8 @@ class ModelBuilder { const emscripten::val& GetBuilder() const { return wnn_builder_; } const emscripten::val& GetContext() const { return wnn_context_; } const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } + void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant(const std::string& data_type); // Use the buffers to persist WebNN allocated data like transposed weight. @@ -66,6 +68,7 @@ class ModelBuilder { emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; + emscripten::val wnn_limits_ = emscripten::val::undefined(); InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h index 6ecc5d1068963..bb69a6a545597 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -29,7 +29,8 @@ class IOpBuilder { public: // Check if an operator is supported. 
virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const = 0; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const = 0; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 01761290f07e3..3dc1c7966ae41 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Identity", op_registrations); CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); - CreateUnaryOpBuilder("Not", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); @@ -118,6 +117,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); + CreateLogicalOpBuilder("Not", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b918daf838c99..b729623c5d3d8 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { - preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { - preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } + + // Retrieve the level of support for different WebNN operators. + // This varies across implementations and is obtained via the WebNN's opSupportLimits() function. 
+ // https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits + wnn_limits_ = wnn_context_.call("opSupportLimits"); + + if (wnn_limits_["preferredInputLayout"].as().compare("nhwc") == 0) { + preferred_layout_ = DataLayout::NHWC; + } else { + preferred_layout_ = DataLayout::NCHW; + } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { @@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); @@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector Date: Sun, 15 Sep 2024 18:31:55 -0400 Subject: [PATCH 5/9] [java] Adding ability to load a model from a memory mapped byte buffer (#20062) ### Description Adds support for constructing an `OrtSession` from a `java.nio.ByteBuffer`. These buffers can be memory mapped from files which means there doesn't need to be copies of the model protobuf held in Java, reducing peak memory usage during session construction. ### Motivation and Context Reduces memory usage on model construction by not requiring as many copies on the Java side. Should help with #19599. --- .../java/ai/onnxruntime/OrtEnvironment.java | 49 ++++++++++++++++++- .../main/java/ai/onnxruntime/OrtSession.java | 35 +++++++++++++ .../main/native/ai_onnxruntime_OrtSession.c | 25 +++++++++- .../java/ai/onnxruntime/InferenceTest.java | 31 ++++++++++++ 4 files changed, 138 insertions(+), 2 deletions(-) diff --git a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java index 26137e88478b5..8382ef06e26e5 100644 --- a/java/src/main/java/ai/onnxruntime/OrtEnvironment.java +++ b/java/src/main/java/ai/onnxruntime/OrtEnvironment.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2023 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ package ai.onnxruntime; @@ -7,6 +7,7 @@ import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtTrainingSession.OrtCheckpointState; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.EnumSet; import java.util.Objects; import java.util.logging.Logger; @@ -236,6 +237,52 @@ OrtSession createSession(String modelPath, OrtAllocator allocator, SessionOption return new OrtSession(this, modelPath, allocator, options); } + /** + * Create a session using the specified {@link SessionOptions}, model and the default memory + * allocator. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer, SessionOptions options) + throws OrtException { + return createSession(modelBuffer, defaultAllocator, options); + } + + /** + * Create a session using the default {@link SessionOptions}, model and the default memory + * allocator. 
+ * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + public OrtSession createSession(ByteBuffer modelBuffer) throws OrtException { + return createSession(modelBuffer, new OrtSession.SessionOptions()); + } + + /** + * Create a session using the specified {@link SessionOptions} and model buffer. + * + * @param modelBuffer Byte buffer representing an ONNX model. Must be a direct byte buffer. + * @param allocator The memory allocator to use. + * @param options The session options. + * @return An {@link OrtSession} with the specified model. + * @throws OrtException If the model failed to parse, wasn't compatible or caused an error. + */ + OrtSession createSession(ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + Objects.requireNonNull(modelBuffer, "model array must not be null"); + if (modelBuffer.remaining() == 0) { + throw new OrtException("Invalid model buffer, no elements remaining."); + } else if (!modelBuffer.isDirect()) { + throw new OrtException("ByteBuffer is not direct."); + } + return new OrtSession(this, modelBuffer, allocator, options); + } + /** * Create a session using the specified {@link SessionOptions}, model and the default memory * allocator. diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 8fe73ff69e169..f87cbc76ef141 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -11,6 +11,7 @@ import ai.onnxruntime.providers.OrtFlags; import ai.onnxruntime.providers.OrtTensorRTProviderOptions; import java.io.IOException; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -94,6 +95,31 @@ public class OrtSession implements AutoCloseable { allocator); } + /** + * Creates a session reading the model from the supplied byte buffer. + * + *
Must be a direct byte buffer. + * + * @param env The environment. + * @param modelBuffer The model protobuf as a byte buffer. + * @param allocator The allocator to use. + * @param options Session configuration options. + * @throws OrtException If the model was corrupted or some other error occurred in native code. + */ + OrtSession( + OrtEnvironment env, ByteBuffer modelBuffer, OrtAllocator allocator, SessionOptions options) + throws OrtException { + this( + createSession( + OnnxRuntime.ortApiHandle, + env.getNativeHandle(), + modelBuffer, + modelBuffer.position(), + modelBuffer.remaining(), + options.getNativeHandle()), + allocator); + } + /** * Private constructor to build the Java object wrapped around a native session. * @@ -514,6 +540,15 @@ private static native long createSession( private static native long createSession( long apiHandle, long envHandle, byte[] modelArray, long optsHandle) throws OrtException; + private static native long createSession( + long apiHandle, + long envHandle, + ByteBuffer modelBuffer, + int bufferPos, + int bufferSize, + long optsHandle) + throws OrtException; + private native long getNumInputs(long apiHandle, long nativeHandle) throws OrtException; private native String[] getInputNames(long apiHandle, long nativeHandle, long allocatorHandle) diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index f4d5ab080cd31..ee8cdee659296 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019, 2020, 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2019, 2024 Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -48,6 +48,29 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_la return (jlong)session; } +/* + * Class: ai_onnxruntime_OrtSession + * Method: createSession + * Signature: (JJLjava/nio/ByteBuffer;IIJ)J + */ +JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtSession_createSession__JJLjava_nio_ByteBuffer_2IIJ(JNIEnv* jniEnv, jclass jclazz, jlong apiHandle, jlong envHandle, jobject buffer, jint bufferPos, jint bufferSize, jlong optsHandle) { + (void)jclazz; // Required JNI parameter not needed by functions which don't need to access their host object. 
+ const OrtApi* api = (const OrtApi*)apiHandle; + OrtEnv* env = (OrtEnv*)envHandle; + OrtSessionOptions* opts = (OrtSessionOptions*)optsHandle; + OrtSession* session = NULL; + + // Extract the buffer + char* bufferArr = (char*)(*jniEnv)->GetDirectBufferAddress(jniEnv, buffer); + // Increment by bufferPos bytes + bufferArr = bufferArr + bufferPos; + + // Create the session + checkOrtStatus(jniEnv, api, api->CreateSessionFromArray(env, bufferArr, bufferSize, opts, &session)); + + return (jlong)session; +} + /* * Class: ai_onnxruntime_OrtSession * Method: createSession diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 3340a2e5e9f3a..f76e1b3b20e19 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -20,10 +20,14 @@ import ai.onnxruntime.OrtSession.SessionOptions.OptLevel; import java.io.File; import java.io.IOException; +import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.LongBuffer; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.FileChannel.MapMode; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -338,6 +342,33 @@ public void partialInputsTest() throws OrtException { } } + @Test + public void createSessionFromByteBuffer() throws IOException, OrtException { + Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); + try (RandomAccessFile file = new RandomAccessFile(modelPath.toFile(), "r"); + FileChannel channel = file.getChannel()) { + MappedByteBuffer modelBuffer = channel.map(MapMode.READ_ONLY, 0, channel.size()); + try (OrtSession.SessionOptions options = new SessionOptions(); + OrtSession session = env.createSession(modelBuffer, options)) { + assertNotNull(session); + assertEquals(1, session.getNumInputs()); // 1 input node + Map inputInfoList = session.getInputInfo(); + assertNotNull(inputInfoList); + assertEquals(1, inputInfoList.size()); + NodeInfo input = inputInfoList.get("data_0"); + assertEquals("data_0", input.getName()); // input node name + assertTrue(input.getInfo() instanceof TensorInfo); + TensorInfo inputInfo = (TensorInfo) input.getInfo(); + assertEquals(OnnxJavaType.FLOAT, inputInfo.type); + int[] expectedInputDimensions = new int[] {1, 3, 224, 224}; + assertEquals(expectedInputDimensions.length, inputInfo.shape.length); + for (int i = 0; i < expectedInputDimensions.length; i++) { + assertEquals(expectedInputDimensions[i], inputInfo.shape[i]); + } + } + } + } + @Test public void createSessionFromByteArray() throws IOException, OrtException { Path modelPath = TestHelpers.getResourcePath("/squeezenet.onnx"); From 6d7235ba5ab995e42a0e251874e65e9d7eaa2997 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sun, 15 Sep 2024 21:55:38 -0400 Subject: [PATCH 6/9] [Java] Exposing SessionOptions.SetDeterministicCompute (#18998) ### Description Exposes `SetDeterministicCompute` in Java, added to the C API by #18944. ### Motivation and Context Parity between C and Java APIs. 
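For illustration, a minimal usage sketch of the new option from the Java API (the model path is a placeholder; this snippet is not part of the patch itself):

```java
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public final class DeterministicComputeExample {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();
    try (OrtSession.SessionOptions opts = new OrtSession.SessionOptions()) {
      // Request deterministic compute for GPU kernels where possible.
      // Note that this will most likely have a performance cost.
      opts.setDeterministicCompute(true);
      // "model.onnx" is a placeholder path for whatever model is being run.
      try (OrtSession session = env.createSession("model.onnx", opts)) {
        // Run inference as usual; repeated runs should now use deterministic
        // kernels where the execution provider offers them.
      }
    }
  }
}
```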
--- .../main/java/ai/onnxruntime/OrtSession.java | 17 +++++++++++++++++ .../ai_onnxruntime_OrtSession_SessionOptions.c | 13 +++++++++++++ .../test/java/ai/onnxruntime/InferenceTest.java | 1 + 3 files changed, 31 insertions(+) diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index f87cbc76ef141..6d146d5857d3c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -942,6 +942,20 @@ public void setSymbolicDimensionValue(String dimensionName, long dimensionValue) OnnxRuntime.ortApiHandle, nativeHandle, dimensionName, dimensionValue); } + /** + * Set whether to use deterministic compute. + * + *
Default is false. If set to true, this will enable deterministic compute for GPU kernels + * where possible. Note that this most likely will have a performance cost. + * + * @param value Should the compute be deterministic? + * @throws OrtException If there was an error in native code. + */ + public void setDeterministicCompute(boolean value) throws OrtException { + checkClosed(); + setDeterministicCompute(OnnxRuntime.ortApiHandle, nativeHandle, value); + } + /** * Disables the per session thread pools. Must be used in conjunction with an environment * containing global thread pools. @@ -1327,6 +1341,9 @@ private native void registerCustomOpsUsingFunction( private native void closeOptions(long apiHandle, long nativeHandle); + private native void setDeterministicCompute( + long apiHandle, long nativeHandle, boolean isDeterministic) throws OrtException; + private native void addFreeDimensionOverrideByName( long apiHandle, long nativeHandle, String dimensionName, long dimensionValue) throws OrtException; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index ff9348c299e90..ff6b7fa703e6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -259,6 +259,19 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setSes checkOrtStatus(jniEnv,api,api->SetSessionLogVerbosityLevel(options,logLevel)); } +/* + * Class: ai_onnxruntime_OrtSession_SessionOptions + * Method: setDeterministicCompute + * Signature: (JJZ)V + */ +JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_setDeterministicCompute + (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong optionsHandle, jboolean isDeterministic) { + (void) jobj; // Required JNI parameters not needed by functions which don't need to access their host object. + const OrtApi* api = (const OrtApi*)apiHandle; + OrtSessionOptions* options = (OrtSessionOptions*) optionsHandle; + checkOrtStatus(jniEnv,api,api->SetDeterministicCompute(options, isDeterministic)); +} + /* * Class: ai_onnxruntime_OrtSession_SessionOptions * Method: registerCustomOpLibrary diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index f76e1b3b20e19..11141a3a65a3e 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -1263,6 +1263,7 @@ public void testExtraSessionOptions() throws OrtException, IOException { options.setLoggerId("monkeys"); options.setSessionLogLevel(OrtLoggingLevel.ORT_LOGGING_LEVEL_FATAL); options.setSessionLogVerbosityLevel(5); + options.setDeterministicCompute(true); Map configEntries = options.getConfigEntries(); assertTrue(configEntries.isEmpty()); options.addConfigEntry("key", "value"); From 1a1669fe817232e7d19c6459da0fc610e0c74b0a Mon Sep 17 00:00:00 2001 From: George Wu Date: Mon, 16 Sep 2024 09:12:13 -0700 Subject: [PATCH 7/9] use node name in transpose optimizer when adding nodes rather than optype (#22084) patch from @john-dance "The main change is simple: Use the original node name rather than the original node op_type when creating new nodes. Here are my comments on the change: ------ The onnx runtime uses the op_type as the basis for a new node name, so a node claimed by QNN EP might be named Conv_token_1 with no relation to the original /conv1/Conv. This patch: 1. 
Adds OpName as a virtual function in NodeRef and implements it in ApiNode. 2. AddNode now takes an op_name and op_type and passes them both to CreateNodeHelper. 3. CreateNodeHelper uses the op_name rather than the op_type in GenerateNodeName 4. Direct calls to AddNode are modified to either use the NodeRef if available, or just repeat the op_type if not available. The result is that the new nodes are named something like /conv1/Conv_token_1, allowing a straight forward mapping back to the original model node (if they exist in the original graph)." --- .../onnx_transpose_optimization.cc | 18 +++++++++--------- .../transpose_optimization/optimizer_api.h | 6 +++++- .../ort_optimizer_api_impl.cc | 17 +++++++++++------ .../internal_testing/internal_testing_tests.cc | 4 ++-- 4 files changed, 27 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index df81367c5bbee..5d689a9d933e8 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -78,7 +78,7 @@ static std::unique_ptr MakeNode1Attr(api::GraphRef& graph, std::st std::string_view input, std::string_view attr_name, const std::vector& attr_val) { std::vector inputs{input}; - std::unique_ptr node = graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + std::unique_ptr node = graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); node->SetAttributeInts(attr_name, attr_val); return node; } @@ -102,7 +102,7 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: std::vector inputs{input, axes_initializer}; - return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); + return graph.AddNode(op_type, op_type, inputs, /*num_outputs*/ 1); } ///
@@ -136,7 +136,7 @@ static std::unique_ptr MakeQuantizeOp(api::GraphRef& graph, std::s std::optional block_size, std::optional output_dtype, std::optional saturate) { - std::unique_ptr node = graph.AddNode("QuantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("QuantizeLinear", "QuantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -170,7 +170,7 @@ static std::unique_ptr MakeDequantizeOp(api::GraphRef& graph, std: std::vector inputs, std::optional axis, std::optional block_size) { - std::unique_ptr node = graph.AddNode("DequantizeLinear", inputs, /* num_outputs */ 1, domain); + std::unique_ptr node = graph.AddNode("DequantizeLinear", "DequantizeLinear", inputs, /* num_outputs */ 1, domain); SetAttrIfNotDefault(*node, "axis", axis, 1); @@ -1724,7 +1724,7 @@ static bool HandleShape(HandlerArgs& args) { // X -> Shape -> Y, Gather std::vector gather_inputs{"", perm_const}; - auto gather_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; gather.SetAttributeInt("axis", 0); @@ -1767,7 +1767,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con // inputs that would never be quantized. std::string_view gather_indices_const = AddInitializerInt64(graph, /*shape*/ {rank_int}, perm); std::vector gather_inputs{input_name, gather_indices_const}; - auto gather_ptr = graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_ptr = graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather = *gather_ptr; std::string_view gather_output = gather.Outputs()[0]; graph.CopyValueInfo(input_name, gather_output); @@ -2215,7 +2215,7 @@ static bool HandleTile(HandlerArgs& args) { // Case 2: Repeats is computed. Insert Gather node. std::string_view perm_inv_const = AddInitializerInt64(args.ctx.graph, perm_shape, args.perm_inv); std::vector gather_inputs{repeats_inp, perm_inv_const}; - auto gather_node_ptr = args.ctx.graph.AddNode("Gather", gather_inputs, /*num_outputs*/ 1); + auto gather_node_ptr = args.ctx.graph.AddNode("Gather", "Gather", gather_inputs, /*num_outputs*/ 1); api::NodeRef& gather_node = *gather_node_ptr; std::string_view gather_output = gather_node.Outputs()[0]; args.ctx.graph.CopyValueInfo(repeats_inp, gather_output); @@ -2265,7 +2265,7 @@ static void RemoveCancelingTransposeNodes(HandlerArgs& args) { // Worst-case scenario: Both parent output and 2nd transpose/reshape output cannot be removed (both graph outputs) // despite computing the same value. Use an Identity op instead. std::vector single_empty_input{""}; - auto identity_ptr = args.ctx.graph.AddNode("Identity", single_empty_input, /*num_outputs*/ 1); + auto identity_ptr = args.ctx.graph.AddNode("Identity", "Identity", single_empty_input, /*num_outputs*/ 1); api::NodeRef& identity = *identity_ptr; args.ctx.graph.MoveOutput(args.node, 0, identity, 0); identity.SetInput(0, transpose_input); @@ -2297,7 +2297,7 @@ static bool HandleTransposeImpl(HandlerArgs& args, const std::vector& n // replace Reshape with Transpose to simplify the logic. // use the same input as the 1st Transpose, move the output from the Reshape to the new Transpose node, // and remove the Reshape node. 
- new_node = args.ctx.graph.AddNode("Transpose", {args.transpose.Inputs()[0]}, 1); + new_node = args.ctx.graph.AddNode("Transpose", "Transpose", {args.transpose.Inputs()[0]}, 1); args.ctx.graph.MoveOutput(args.node, 0, *new_node, 0); args.ctx.graph.RemoveNode(args.node); } else { diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 211734f4bacc8..7122aec45e61a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -146,6 +146,9 @@ class ValueInfoRef { /// class NodeRef { public: + /// Node name + virtual std::string_view Name() const = 0; + /// Op computed by the node virtual std::string_view OpType() const = 0; @@ -361,6 +364,7 @@ class GraphRef { /// generated. Outputs of created node have unspecified shapes/dtypes. They will be populated afterwards using /// CopyValueInfo. /// + /// The new node's name /// The new node's op type /// Inputs for the node. "" for missing optional inputs. /// @@ -368,7 +372,7 @@ class GraphRef { /// /// The new node's domain. Empty string signifies default onnx domain. /// The new node - virtual std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + virtual std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain = /*kOnnxDomain*/ "") = 0; /// diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index 33408474f92a6..f87df746234fa 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -80,6 +80,10 @@ class ApiNode final : public api::NodeRef { return node_; } + std::string_view Name() const override { + return node_.Name(); + } + std::string_view OpType() const override { return node_.OpType(); } @@ -134,7 +138,7 @@ class ApiGraph final : public api::GraphRef { std::unique_ptr GetNodeProducingOutput(std::string_view name) const override; void TransposeInitializer(std::string_view name, const std::vector& perm) override; void ReshapeInitializer(std::string_view name, const std::vector& shape) override; - std::unique_ptr AddNode(std::string_view op_type, const std::vector& inputs, + std::unique_ptr AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs = 1, std::string_view domain = "") override; std::unique_ptr CopyNode(const api::NodeRef& source_node, std::string_view op_type, @@ -621,11 +625,12 @@ void ApiGraph::ReshapeInitializer(std::string_view name, const std::vectorSetShape(new_shape); } -static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_type, +static Node& CreateNodeHelper(onnxruntime::Graph& graph, std::string_view op_name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain, int since_version, std::string_view node_ep) { const std::string op_type_str(op_type); - std::string name = graph.GenerateNodeName(op_type_str); + const std::string op_name_str(op_name); + std::string name = graph.GenerateNodeName(op_name_str); std::vector input_args; std::vector output_args; @@ -731,11 +736,11 @@ static int GetSinceVersionForNewOp(std::string_view op_type, std::string_view do return *since_version; } 
-std::unique_ptr ApiGraph::AddNode(std::string_view op_type, +std::unique_ptr ApiGraph::AddNode(std::string_view name, std::string_view op_type, const std::vector& inputs, size_t num_outputs, std::string_view domain) { int since_version = GetSinceVersionForNewOp(op_type, domain, graph_.DomainToVersionMap()); - Node& node = CreateNodeHelper(graph_, op_type, inputs, num_outputs, + Node& node = CreateNodeHelper(graph_, name, op_type, inputs, num_outputs, domain, since_version, new_node_ep_ != nullptr ? new_node_ep_ : ""); return std::make_unique(node, graph_); @@ -744,7 +749,7 @@ std::unique_ptr ApiGraph::AddNode(std::string_view op_type, std::unique_ptr ApiGraph::CopyNode(const api::NodeRef& source_node, std::string_view op_type, std::string_view domain, std::optional since_version) { const int new_node_since_version = since_version.has_value() ? *since_version : source_node.SinceVersion(); - Node& node = CreateNodeHelper(graph_, op_type, source_node.Inputs(), + Node& node = CreateNodeHelper(graph_, source_node.Name(), op_type, source_node.Inputs(), source_node.Outputs().size(), domain, new_node_since_version, source_node.GetExecutionProviderType()); diff --git a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc index 9f7be524daa34..67fb35d26e6dc 100644 --- a/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc +++ b/onnxruntime/test/providers/internal_testing/internal_testing_tests.cc @@ -196,7 +196,7 @@ TEST(InternalTestingEP, TestMixOfStaticAndCompiledKernels) { // Error message should come from the Conv implementation with the statically registered kernel ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); } @@ -242,7 +242,7 @@ TEST(InternalTestingEP, TestNhwcConversionOfStaticKernels) { std::vector fetches; ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(session.Run(feeds, output_names, &fetches), - "Non-zero status code returned while running Conv node. Name:'Conv' " + "Non-zero status code returned while running Conv node. Name:'_token_2' " "Status Message: TODO: add NHWC implementation here."); }; From e93f14e00d09b0c62ba0869bc87f14ee5f1cf4c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Erick=20Mu=C3=B1oz?= Date: Mon, 16 Sep 2024 10:20:06 -0600 Subject: [PATCH 8/9] Check partial conversion on FP16 to FP32 AVX Cast kernel (#22091) ### Description Added checks to convert partial vectors in the early stages of the FP16 to FP32 cast using AVX NE CONVERT ISA. ### Motivation and Context Avoid storing data in sections outside of the output buffer, these checks are missing on the [original PR](https://github.com/microsoft/onnxruntime/pull/21183). 
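For illustration only, a rough scalar sketch (plain C, not the actual MLAS assembly) of the control flow the fixed kernel is meant to follow; the fp16 helper below covers zeros and normal values only:

```c
#include <stddef.h>
#include <stdint.h>
#include <string.h>

/* Scalar stand-in for one fp16 -> fp32 conversion; the real kernel converts
 * 16 or 8 elements at a time with AVX NE CONVERT instructions. */
static float half_to_float(uint16_t h) {
  uint32_t sign = (uint32_t)(h & 0x8000u) << 16;
  uint32_t exponent = (h >> 10) & 0x1Fu;
  uint32_t mantissa = h & 0x3FFu;
  uint32_t bits;
  if (exponent == 0 && mantissa == 0) {
    bits = sign; /* signed zero */
  } else {
    bits = sign | ((exponent + 112u) << 23) | (mantissa << 13); /* normals */
  }
  float f;
  memcpy(&f, &bits, sizeof(f));
  return f;
}

/* Outline of MlasCastF16ToF32KernelAvx after the fix: convert 16 elements per
 * iteration, re-check that at least 8 remain before taking the 8-wide path,
 * and route anything smaller to the masked tail so no more than `count`
 * floats are ever written to `dst`. */
void cast_f16_to_f32_sketch(const uint16_t* src, float* dst, size_t count) {
  while (count >= 16) { /* Convert256Vectors */
    for (int i = 0; i < 16; ++i) dst[i] = half_to_float(src[i]);
    src += 16; dst += 16; count -= 16;
  }
  if (count >= 8) { /* the size check added by this change */
    for (int i = 0; i < 8; ++i) dst[i] = half_to_float(src[i]);
    src += 8; dst += 8; count -= 8;
  }
  for (size_t i = 0; i < count; ++i) { /* ConvertMaskedVectors */
    dst[i] = half_to_float(src[i]);
  }
}
```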
This fix prevents memory corruption when the output buffer has a size [n*16 + 1, n*16 + 7] with 0< n --- onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm | 4 +++- onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm index c7f6342c527bf..800863c77a230 100644 --- a/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm +++ b/onnxruntime/core/mlas/lib/amd64/cvtfp16Avx.asm @@ -54,7 +54,7 @@ HIGH_SELECTOR equ 00110001b LEAF_ENTRY MlasCastF16ToF32KernelAvx, _TEXT - test r8, r8 ; Check if we have any elements to convert + test r8, r8 ; Check if we have any elements to convert jz ExitRoutine cmp r8, 8 jb ConvertMaskedVectors @@ -80,6 +80,8 @@ Convert256Vectors: jz ExitRoutine ; If we are done, exit cmp r8, 16 ; If the vector is big enough, we go again jae Convert256Vectors + cmp r8, 8 ; Check if we have enough elements to convert + jb ConvertMaskedVectors diff --git a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S index 1a70061460e50..a4d730fa513ab 100644 --- a/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S +++ b/onnxruntime/core/mlas/lib/x86_64/cvtfp16Avx.S @@ -51,8 +51,6 @@ FUNCTION_ENTRY MlasCastF16ToF32KernelAvx test rdx, rdx // Check if we have any elements to convert jz ExitRoutine - -AVX_NE_CONVERT: cmp rdx, 8 jb ConvertMaskedVectors cmp rdx, 16 @@ -75,6 +73,8 @@ Convert256Vectors: jz ExitRoutine // If we are done, exit cmp rdx, 16 // If the vector is big enough, we go again jae Convert256Vectors + cmp rdx, 8 // Check if we have enough elements to convert + jb ConvertMaskedVectors From 291a5352b27ded5714e5748b381f2efb88f28fb9 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:56:22 -0700 Subject: [PATCH 9/9] [js/web] remove training release (#22103) ### Description Remove training from onnxruntime-web Following up of #22082 --- js/web/lib/backend-wasm-inference.ts | 5 - js/web/lib/backend-wasm-training.ts | 29 - js/web/lib/backend-wasm.ts | 2 + js/web/lib/index.ts | 4 +- js/web/lib/wasm/session-handler-training.ts | 198 ------ js/web/lib/wasm/wasm-core-impl.ts | 9 +- js/web/lib/wasm/wasm-training-core-impl.ts | 631 ------------------ js/web/lib/wasm/wasm-types.ts | 76 +-- js/web/lib/wasm/wasm-utils-import.ts | 16 +- js/web/package.json | 7 - js/web/script/build.ts | 13 +- js/web/script/pull-prebuilt-wasm-artifacts.ts | 2 - js/web/test/training/e2e/browser-test-wasm.js | 21 - js/web/test/training/e2e/common.js | 248 ------- js/web/test/training/e2e/data/model.onnx | 16 - js/web/test/training/e2e/karma.conf.js | 54 -- js/web/test/training/e2e/package.json | 14 - js/web/test/training/e2e/run.js | 143 ---- .../test/training/e2e/simple-http-server.js | 67 -- js/web/types.d.ts | 4 - 20 files changed, 15 insertions(+), 1544 deletions(-) delete mode 100644 js/web/lib/backend-wasm-inference.ts delete mode 100644 js/web/lib/backend-wasm-training.ts delete mode 100644 js/web/lib/wasm/session-handler-training.ts delete mode 100644 js/web/lib/wasm/wasm-training-core-impl.ts delete mode 100644 js/web/test/training/e2e/browser-test-wasm.js delete mode 100644 js/web/test/training/e2e/common.js delete mode 100644 js/web/test/training/e2e/data/model.onnx delete mode 100644 js/web/test/training/e2e/karma.conf.js delete mode 100644 js/web/test/training/e2e/package.json delete mode 100644 js/web/test/training/e2e/run.js delete mode 100644 
js/web/test/training/e2e/simple-http-server.js diff --git a/js/web/lib/backend-wasm-inference.ts b/js/web/lib/backend-wasm-inference.ts deleted file mode 100644 index 7dfe7ee05a1d3..0000000000000 --- a/js/web/lib/backend-wasm-inference.ts +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts deleted file mode 100644 index 7332b3f97eba0..0000000000000 --- a/js/web/lib/backend-wasm-training.ts +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { InferenceSession, TrainingSessionHandler } from 'onnxruntime-common'; - -import { OnnxruntimeWebAssemblyBackend } from './backend-wasm'; -import { OnnxruntimeWebAssemblyTrainingSessionHandler } from './wasm/session-handler-training'; - -class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { - async createTrainingSessionHandler( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ): Promise { - const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); - await handler.createTrainingSession( - checkpointStateUriOrBuffer, - trainModelUriOrBuffer, - evalModelUriOrBuffer, - optimizerModelUriOrBuffer, - options, - ); - return Promise.resolve(handler); - } -} - -export const wasmBackend = new OnnxruntimeTrainingWebAssemblyBackend(); diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index 7bef538b26063..766937dc4c4cf 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -99,3 +99,5 @@ export class OnnxruntimeWebAssemblyBackend implements Backend { return Promise.resolve(handler); } } + +export const wasmBackend = new OnnxruntimeWebAssemblyBackend(); diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 321394466b365..776c0d026bc97 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -20,9 +20,7 @@ if (!BUILD_DEFS.DISABLE_WEBGL) { } if (!BUILD_DEFS.DISABLE_WASM) { - const wasmBackend = BUILD_DEFS.DISABLE_TRAINING - ? require('./backend-wasm-inference').wasmBackend - : require('./backend-wasm-training').wasmBackend; + const wasmBackend = require('./backend-wasm').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { registerBackend('webgpu', wasmBackend, 5); registerBackend('webnn', wasmBackend, 5); diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts deleted file mode 100644 index 8bbfb9cf06668..0000000000000 --- a/js/web/lib/wasm/session-handler-training.ts +++ /dev/null @@ -1,198 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessionHandler } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { decodeTensorMetadata, encodeTensorMetadata } from './session-handler-inference'; -import { copyFromExternalBuffer } from './wasm-core-impl'; -import { - createCheckpointHandle, - createTrainingSessionHandle, - getContiguousParameters, - getModelInputOutputNames, - getParametersSize, - lazyResetGrad, - loadParametersBuffer, - releaseTrainingSessionAndCheckpoint, - runEvalStep, - runOptimizerStep, - runTrainStep, -} from './wasm-training-core-impl'; - -export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { - private sessionId: number; - private checkpointId: number; - - inputNames: string[]; - outputNames: string[]; - - evalInputNames: string[] = []; - evalOutputNames: string[] = []; - - async uriOrBufferToHeap(uriOrBuffer: string | Uint8Array): Promise { - let buffer: Uint8Array; - if (typeof uriOrBuffer === 'string') { - const response = await fetch(uriOrBuffer); - const arrayBuffer = await response.arrayBuffer(); - buffer = new Uint8Array(arrayBuffer); - } else { - buffer = uriOrBuffer; - } - return copyFromExternalBuffer(buffer); - } - - async createTrainingSession( - checkpointStateUriOrBuffer: string | Uint8Array, - trainModelUriOrBuffer: string | Uint8Array, - evalModelUriOrBuffer: string | Uint8Array, - optimizerModelUriOrBuffer: string | Uint8Array, - options: InferenceSession.SessionOptions, - ) { - const checkpointData: SerializableInternalBuffer = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); - const trainModelData: SerializableInternalBuffer = await this.uriOrBufferToHeap(trainModelUriOrBuffer); - // 0 is supposed to be the nullptr - let evalModelData: SerializableInternalBuffer = [0, 0]; - let optimizerModelData: SerializableInternalBuffer = [0, 0]; - - if (evalModelUriOrBuffer !== '') { - evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); - } - if (optimizerModelUriOrBuffer !== '') { - optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); - } - - this.checkpointId = createCheckpointHandle(checkpointData); - this.sessionId = createTrainingSessionHandle( - this.checkpointId, - trainModelData, - evalModelData, - optimizerModelData, - options, - ); - [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false); - if (evalModelUriOrBuffer !== '') { - [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true); - } - } - - /** - * Helper method that converts a feeds or fetches datatype to two arrays, one of values and one that stores the - * corresponding name as a number referring to the index in the list of names provided. - * - * @param feeds meant to match either SessionHandler.FeedsType or SessionHandler.FetchesType - * @param names either inputNames or outputNames - * @returns a tuple of a list of values and a list of indices. 
- */ - convertMapIntoValuesArrayAndIndicesArray( - feeds: { [name: string]: T }, - names: string[], - mapFunc: (val: T, index: number) => U, - ): [T[], number[], U[]] { - const values: T[] = []; - const indices: number[] = []; - Object.entries(feeds).forEach((kvp) => { - const name = kvp[0]; - const tensor = kvp[1]; - const index = names.indexOf(name); - if (index === -1) { - throw new Error(`invalid input '${name}`); - } - values.push(tensor); - indices.push(index); - }); - - const uList = values.map(mapFunc); - return [values, indices, uList]; - } - - /** - * Helper method that converts the TensorMetadata that the wasm-core functions return to the - * SessionHandler.ReturnType. Any outputs in the provided outputArray that are falsy will be populated with the - * corresponding result. - * - * @param results used to populate the resultMap if there is no value for that outputName already - * @param outputArray used to populate the resultMap. If null or undefined, use the corresponding result from results - * @param outputIndices specifies which outputName the corresponding value for outputArray refers to. - * @returns a map of output names and OnnxValues. - */ - convertTensorMetadataToReturnType( - results: TensorMetadata[], - outputArray: Array, - outputIndices: number[], - ): SessionHandler.ReturnType { - const resultMap: SessionHandler.ReturnType = {}; - for (let i = 0; i < results.length; i++) { - resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); - } - return resultMap; - } - - async lazyResetGrad(): Promise { - await lazyResetGrad(this.sessionId); - } - - async runTrainStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.inputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.outputNames, (t, i): TensorMetadata | null => - t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runTrainStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async runOptimizerStep(options: InferenceSession.RunOptions): Promise { - await runOptimizerStep(this.sessionId, options); - } - - async runEvalStep( - feeds: SessionHandler.FeedsType, - fetches: SessionHandler.FetchesType, - options: InferenceSession.RunOptions, - ): Promise { - const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray( - feeds, - this.evalInputNames, - (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`), - ); - - const [outputArray, outputIndices, outputs] = this.convertMapIntoValuesArrayAndIndicesArray< - Tensor | null, - TensorMetadata | null - >(fetches, this.evalOutputNames, (t, i): TensorMetadata | null => - t ? 
encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null, - ); - - const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices); - } - - async getParametersSize(trainableOnly: boolean): Promise { - return getParametersSize(this.sessionId, trainableOnly); - } - - async loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise { - await loadParametersBuffer(this.sessionId, array, trainableOnly); - } - async getContiguousParameters(trainableOnly: boolean): Promise { - const tensorResult = await getContiguousParameters(this.sessionId, trainableOnly); - return decodeTensorMetadata(tensorResult); - } - - async dispose(): Promise { - return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId); - } -} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..ed001cfa90f59 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -41,8 +41,8 @@ import { loadFile } from './wasm-utils-load-file'; * Refer to web/lib/index.ts for the backend registration. * * 2. WebAssembly artifact initialization. - * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` or - * `ort.TrainingSession.create()` is called). In this step, onnxruntime-web does the followings: + * This happens when any registered wasm backend is used for the first time (ie. `ort.InferenceSession.create()` is + * called). In this step, onnxruntime-web does the followings: * - create a proxy worker and make sure the proxy worker is ready to receive messages, if proxy is enabled. * - perform feature detection, locate correct WebAssembly artifact path and call the Emscripten generated * JavaScript code to initialize the WebAssembly runtime. @@ -57,9 +57,8 @@ import { loadFile } from './wasm-utils-load-file'; * - logging level (ort.env.logLevel) and thread number (ort.env.wasm.numThreads) are set in this step. * * 4. Session initialization. - * This happens when `ort.InferenceSession.create()` or `ort.TrainingSession.create()` is called. Unlike the first 3 - * steps (they only called once), this step will be done for each session. In this step, onnxruntime-web does the - * followings: + * This happens when `ort.InferenceSession.create()` is called. Unlike the first 3 steps (they only called once), + * this step will be done for each session. In this step, onnxruntime-web does the followings: * If the parameter is a URL: * - download the model data from the URL. * - copy the model data to the WASM heap. (proxy: 'copy-from') diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts deleted file mode 100644 index 22cd6ec30732c..0000000000000 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ /dev/null @@ -1,631 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -import { InferenceSession, Tensor } from 'onnxruntime-common'; - -import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; -import { setRunOptions } from './run-options'; -import { setSessionOptions } from './session-options'; -import { - dataLocationStringToEnum, - tensorDataTypeEnumToString, - tensorDataTypeStringToEnum, - tensorTypeToTypedArrayConstructor, -} from './wasm-common'; -import { prepareInputOutputTensor } from './wasm-core-impl'; -import { getInstance } from './wasm-factory'; -import { checkLastError } from './wasm-utils'; - -const NO_TRAIN_FUNCS_MSG = - "Built without training API's enabled. Use the onnxruntime-web/training import for training " + - 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + - 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; - -/** - * Runs the checkLastError function which will throw an error, if the provided error code matches the specified - * pattern for an error code. - * @param errCode number to evaluated for if it's an error - * @param message message to pass into checkLastError - * @param checkNeqZero when true, treats not equal to zero as an error. - * When false, treats equal to zero as an error. - */ -const ifErrCodeCheckLastError = (errCode: number, message: string, checkNeqZero = true) => { - if (checkNeqZero && errCode !== 0) { - checkLastError(message); - } else if (!checkNeqZero && errCode === 0) { - checkLastError(message); - } -}; - -export const createCheckpointHandle = (checkpointData: SerializableInternalBuffer): number => { - const wasm = getInstance(); - - const [checkpointDataOffset, checkpointDataLength] = checkpointData; - let checkpointHandle = 0; - - try { - if (wasm._OrtTrainingLoadCheckpoint) { - checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(checkpointHandle, 'Error occurred when trying to create a CheckpointState', false); - return checkpointHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { - wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); - } - throw e; - } finally { - // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); - } -}; - -const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - try { - const dataOffset = wasm.stackAlloc(8); - if (wasm._OrtTrainingGetModelInputOutputCount) { - const errorCode = wasm._OrtTrainingGetModelInputOutputCount( - trainingSessionId, - dataOffset, - dataOffset + 4, - isEvalModel, - ); - ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -const getModelInputOutputNamesLoop = ( - trainingSessionId: number, - count: number, - isInput: boolean, - isEvalModel: boolean, -): string[] => { - const names = []; - const wasm = getInstance(); - - for (let i = 0; i < count; i++) { - if (wasm._OrtTrainingGetModelInputOutputName) { - const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); - ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false); - - 
names.push(wasm.UTF8ToString(name)); - wasm._free(name); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } - return names; -}; - -export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => { - let inputNames: string[] = []; - let outputNames: string[] = []; - - const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel); - - inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel); - outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel); - - return [inputNames, outputNames]; -}; - -export const createTrainingSessionHandle = ( - checkpointHandle: number, - trainModelData: SerializableInternalBuffer, - evalModelData: SerializableInternalBuffer, - optimizerModelData: SerializableInternalBuffer, - options: InferenceSession.SessionOptions, -): number => { - const wasm = getInstance(); - - let trainingSessionHandle = 0; - let sessionOptionsHandle = 0; - let allocs: number[] = []; - - try { - [sessionOptionsHandle, allocs] = setSessionOptions(options); - if (wasm._OrtTrainingCreateSession) { - trainingSessionHandle = wasm._OrtTrainingCreateSession( - sessionOptionsHandle, - checkpointHandle, - trainModelData[0], - trainModelData[1], - evalModelData[0], - evalModelData[1], - optimizerModelData[0], - optimizerModelData[1], - ); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false); - return trainingSessionHandle; - } catch (e) { - if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { - wasm._OrtTrainingReleaseSession(trainingSessionHandle); - } - throw e; - } finally { - wasm._free(trainModelData[0]); - wasm._free(evalModelData[0]); - wasm._free(optimizerModelData[0]); - - if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); - } - allocs.forEach((alloc) => wasm._free(alloc)); - } -}; - -/** - * Prepares input and output tensors by creating the tensors in the WASM side then creates a list of the handles of the - * WASM tensors. - * - * @param trainingSessionId - * @param indices for each tensor, the index of the input or output name that the tensor corresponds with - * @param tensors list of TensorMetaData - * @param tensorHandles should pass in an empty list of numbers; modified in-place by this method & stores the resulting - * handles of the allocated tensors on the heap - * @param inputOutputAllocs modified in-place by this method - * @param indexAdd constant to add to the index that is passed to prepareInputOutputTensor - */ -const createAndAllocateTensors = ( - trainingSessionId: number, - indices: number[], - tensors: Array, - tensorHandles: number[], - inputOutputAllocs: number[], - indexAdd: number, -) => { - const count = indices.length; - - // creates the tensors - for (let i = 0; i < count; i++) { - prepareInputOutputTensor(tensors[i], tensorHandles, inputOutputAllocs, trainingSessionId, indexAdd + indices[i]); - } - - // moves to heap - const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; - for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; - } - - return valuesOffset; -}; - -/** - * Retrieves the information from the output tensor handles, copies to an array, and frees the WASM information - * associated with the tensor handle. 
- * - * @param outputValuesOffset - * @param outputCount - * @returns list of TensorMetadata retrieved from the output handles. - */ -const moveOutputToTensorMetadataArr = ( - outputValuesOffset: number, - outputCount: number, - outputTensorHandles: number[], - outputTensors: Array, -) => { - const wasm = getInstance(); - const output: TensorMetadata[] = []; - - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; - if (tensor === outputTensorHandles[i]) { - // output tensor is pre-allocated. no need to copy data. - output.push(outputTensors[i]!); - continue; - } - - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); - - let type: Tensor.Type | undefined, - dataOffset = 0; - try { - const errorCode = wasm._OrtGetTensorData( - tensor, - tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, - ); - ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); - } - wasm._OrtFree(dimsOffset); - - const size = dims.reduce((a, b) => a * b, 1); - type = tensorDataTypeEnumToString(dataType); - - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); - } - output.push([type, dims, stringData, 'cpu']); - } else { - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); - const data = new typedArrayConstructor(size); - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength), - ); - output.push([type, dims, data, 'cpu']); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); - } - wasm._OrtReleaseTensor(tensor); - } - } - - return output; -}; - -export const lazyResetGrad = async (trainingSessionId: number): Promise => { - const wasm = getInstance(); - - if (wasm._OrtTrainingLazyResetGrad) { - const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId); - ifErrCodeCheckLastError(errorCode, "Can't call lazyResetGrad."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } -}; - -export const runTrainStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't 
want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingRunTrainStep) { - const errorCode = wasm._OrtTrainingRunTrainStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingRunTrainStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runOptimizerStep = async ( - trainingSessionId: number, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - try { - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - if (wasm._OrtTrainingOptimizerStep) { - const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle); - ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const runEvalStep = async ( - trainingSessionId: number, - inputIndices: number[], - inputTensors: TensorMetadata[], - outputIndices: number[], - outputTensors: Array, - options: InferenceSession.RunOptions, -): Promise => { - const wasm = getInstance(); - - const inputCount = inputIndices.length; - const outputCount = outputIndices.length; - - let runOptionsHandle = 0; - let runOptionsAllocs: number[] = []; - - const inputTensorHandles: number[] = []; - const outputTensorHandles: number[] = []; - const inputOutputAllocs: number[] = []; - - const beforeRunStack = wasm.stackSave(); - - try { - // prepare parameters by moving them to heap - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); - - // handle inputs -- you don't want anything added to the index - const inputValuesOffset = createAndAllocateTensors( - trainingSessionId, - inputIndices, - inputTensors, - inputTensorHandles, - inputOutputAllocs, - 0, - ); - // handle outputs - // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor - const outputValuesOffset = createAndAllocateTensors( - trainingSessionId, - outputIndices, - outputTensors, - outputTensorHandles, - inputOutputAllocs, - inputCount, - ); - - if (wasm._OrtTrainingEvalStep) { - const errorCode = wasm._OrtTrainingEvalStep( - trainingSessionId, - inputValuesOffset, - inputCount, - outputValuesOffset, - outputCount, - runOptionsHandle, - ); - - 
ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer'); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors); - } finally { - wasm.stackRestore(beforeRunStack); - - inputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - outputTensorHandles.forEach((v) => wasm._OrtReleaseTensor(v)); - inputOutputAllocs.forEach((p) => wasm._free(p)); - - if (runOptionsHandle !== 0) { - wasm._OrtReleaseRunOptions(runOptionsHandle); - } - runOptionsAllocs.forEach((p) => wasm._free(p)); - } -}; - -export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - try { - const sizeOffset = wasm.stackAlloc(4); - if (wasm._OrtTrainingGetParametersSize) { - const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); - ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); - - return wasm.HEAP32[sizeOffset / 4]; - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - wasm.stackRestore(stack); - } -}; - -export const getContiguousParameters = async ( - trainingSessionId: number, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - const parametersSize = getParametersSize(trainingSessionId, trainableOnly); - let tensor = 0; - - // allocates a buffer of the correct size on the WASM heap - const paramsByteLength = 4 * parametersSize; - const paramsOffset = wasm._malloc(paramsByteLength); - - // handles the dimensions-related createTensor parameters - const dims = [parametersSize]; - - const dimsOffset = wasm.stackAlloc(4); - const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; - - try { - // wraps allocated array in a tensor - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - paramsOffset, - paramsByteLength, - dimsOffset, - dims.length, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError( - tensor, - `Can't create tensor for getContiguousParameters. session=${trainingSessionId}.`, - false, - ); - - if (wasm._OrtTrainingCopyParametersToBuffer) { - const errCode = wasm._OrtTrainingCopyParametersToBuffer(trainingSessionId, tensor, parametersSize, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't get contiguous parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - - // copies from WASM memory to a JavaScript typed array, which is then put into a TensorMetadata object - const typedArrayConstructor = tensorTypeToTypedArrayConstructor(tensorTypeAsString); - const data = new typedArrayConstructor(parametersSize); - const output: TensorMetadata[] = []; - new Uint8Array(data.buffer, data.byteOffset, data.byteLength).set( - wasm.HEAPU8.subarray(paramsOffset, paramsOffset + paramsByteLength), - ); - output.push([tensorTypeAsString, dims, data, locationAsString]); - if (output.length !== 1) { - throw new Error(`something unexpected happened in the getContiguousParameters function. 
Expected output length of - one, got ${output.length}`); - } else { - return output[0]; - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm._free(paramsOffset); - wasm._free(dimsOffset); - wasm.stackRestore(stack); - } -}; - -export const loadParametersBuffer = async ( - trainingSessionId: number, - buffer: Uint8Array, - trainableOnly: boolean, -): Promise => { - const wasm = getInstance(); - const stack = wasm.stackSave(); - - const tensorTypeAsString = 'float32'; - const locationAsString = 'cpu'; - - // allocates & copies JavaScript buffer to WASM heap - const bufferByteLength = buffer.length; - const bufferCount = bufferByteLength / 4; - const bufferOffset = wasm._malloc(bufferByteLength); - wasm.HEAPU8.set(buffer, bufferOffset); - - // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; - const dimsLength = 1; - let tensor = 0; - - try { - tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(tensorTypeAsString), - bufferOffset, - bufferByteLength, - dimsOffset, - dimsLength, - dataLocationStringToEnum(locationAsString), - ); - ifErrCodeCheckLastError(tensor, `Can't create tensor for input/output. session=${trainingSessionId}`, false); - - if (wasm._OrtTrainingCopyParametersFromBuffer) { - const errCode = wasm._OrtTrainingCopyParametersFromBuffer(trainingSessionId, tensor, bufferCount, trainableOnly); - ifErrCodeCheckLastError(errCode, "Can't copy buffer to parameters."); - } else { - throw new Error(NO_TRAIN_FUNCS_MSG); - } - } finally { - if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); - } - wasm.stackRestore(stack); - wasm._free(bufferOffset); - wasm._free(dimsOffset); - } -}; - -export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => { - const wasm = getInstance(); - - if (wasm._OrtTrainingReleaseSession) { - wasm._OrtTrainingReleaseSession(sessionId); - } - if (wasm._OrtTrainingReleaseCheckpoint) { - wasm._OrtTrainingReleaseCheckpoint(checkpointId); - } -}; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..828cd3cfd94fa 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -213,84 +213,10 @@ export interface OrtInferenceAPIs { _OrtEndProfiling(sessionHandle: number): number; } -export interface OrtTrainingAPIs { - _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; - - _OrtTrainingCreateSession( - sessionOptionsHandle: number, - checkpointHandle: number, - trainOffset: number, - trainLength: number, - evalOffset: number, - evalLength: number, - optimizerOffset: number, - optimizerLength: number, - ): number; - - _OrtTrainingLazyResetGrad(trainingHandle: number): number; - - _OrtTrainingRunTrainStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingOptimizerStep(trainingHandle: number, runOptionsHandle: number): number; - - _OrtTrainingEvalStep( - trainingHandle: number, - inputsOffset: number, - inputCount: number, - outputsOffset: number, - outputCount: number, - runOptionsHandle: number, - ): number; - - _OrtTrainingGetParametersSize(trainingHandle: number, paramSizeT: number, trainableOnly: boolean): number; - - _OrtTrainingCopyParametersToBuffer( - trainingHandle: number, - parametersBuffer: number, 
- parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingCopyParametersFromBuffer( - trainingHandle: number, - parametersBuffer: number, - parameterCount: number, - trainableOnly: boolean, - ): number; - - _OrtTrainingGetModelInputOutputCount( - trainingHandle: number, - inputCount: number, - outputCount: number, - isEvalModel: boolean, - ): number; - _OrtTrainingGetModelInputOutputName( - trainingHandle: number, - index: number, - isInput: boolean, - isEvalModel: boolean, - ): number; - - _OrtTrainingReleaseSession(trainingHandle: number): void; -} - /** * The interface of the WebAssembly module for ONNX Runtime, compiled from C++ source code by Emscripten. */ -export interface OrtWasmModule - extends EmscriptenModule, - OrtInferenceAPIs, - Partial, - Partial { +export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial { // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index 008b9b41b1592..bd9e0ce083ef0 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -135,11 +135,9 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = BUILD_DEFS.IS_ESM && BUILD_DEFS.DISABLE_DYNAMIC_IMPORT ? // eslint-disable-next-line @typescript-eslint/no-require-imports, @typescript-eslint/no-var-requires require( - !BUILD_DEFS.DISABLE_TRAINING - ? '../../dist/ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + !BUILD_DEFS.DISABLE_JSEP + ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -163,11 +161,9 @@ export const importWasmModule = async ( if (BUILD_DEFS.DISABLE_DYNAMIC_IMPORT) { return [undefined, embeddedWasmModule!]; } else { - const wasmModuleFilename = !BUILD_DEFS.DISABLE_TRAINING - ? 'ort-training-wasm-simd-threaded.mjs' - : !BUILD_DEFS.DISABLE_JSEP - ? 'ort-wasm-simd-threaded.jsep.mjs' - : 'ort-wasm-simd-threaded.mjs'; + const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP + ? 'ort-wasm-simd-threaded.jsep.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. diff --git a/js/web/package.json b/js/web/package.json index 94dd047915b05..d770499adada4 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -23,7 +23,6 @@ "build:doc": "node ./script/generate-webgl-operator-md && node ./script/generate-webgpu-operator-md", "pull:wasm": "node ./script/pull-prebuilt-wasm-artifacts", "test:e2e": "node ./test/e2e/run", - "test:training:e2e": "node ./test/training/e2e/run", "prebuild": "tsc -p . 
--noEmit && tsc -p lib/wasm/proxy-worker --noEmit", "build": "node ./script/build", "test": "tsc --build ../scripts && node ../scripts/prepare-onnx-node-tests && node ./script/test-runner-cli", @@ -101,12 +100,6 @@ "import": "./dist/ort.webgpu.bundle.min.mjs", "require": "./dist/ort.webgpu.min.js", "types": "./types.d.ts" - }, - "./training": { - "node": null, - "import": "./dist/ort.training.wasm.min.mjs", - "require": "./dist/ort.training.wasm.min.js", - "types": "./types.d.ts" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 6d1b3bdb65068..408f9e00a5cbd 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -56,7 +56,6 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_JSEP': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', - 'BUILD_DEFS.DISABLE_TRAINING': 'true', 'BUILD_DEFS.DISABLE_DYNAMIC_IMPORT': 'false', 'BUILD_DEFS.IS_ESM': 'false', @@ -253,7 +252,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort[-training]-wasm-simd-threaded[.jsep].mjs + * - ort-wasm-simd-threaded[.jsep].mjs */ async function buildOrt({ isProduction = false, @@ -630,16 +629,6 @@ async function main() { 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', }, }); - // ort.training.wasm[.min].[m]js - await addAllWebBuildTasks({ - outputName: 'ort.training.wasm', - define: { - ...DEFAULT_DEFINE, - 'BUILD_DEFS.DISABLE_TRAINING': 'false', - 'BUILD_DEFS.DISABLE_JSEP': 'true', - 'BUILD_DEFS.DISABLE_WEBGL': 'true', - }, - }); } if (BUNDLE_MODE === 'dev' || BUNDLE_MODE === 'perf') { diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index b1b2fa26b2351..5b8b0d27c88db 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -149,11 +149,9 @@ downloadJson( void jszip.loadAsync(buffer).then((zip) => { extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.wasm', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.wasm', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.mjs', folderName); extractFile(zip, WASM_FOLDER, 'ort-wasm-simd-threaded.jsep.mjs', folderName); - extractFile(zip, WASM_FOLDER, 'ort-training-wasm-simd-threaded.mjs', folderName); }); }); }, diff --git a/js/web/test/training/e2e/browser-test-wasm.js b/js/web/test/training/e2e/browser-test-wasm.js deleted file mode 100644 index 05750ed149303..0000000000000 --- a/js/web/test/training/e2e/browser-test-wasm.js +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -'use strict'; - -describe('Browser E2E testing for training package', function () { - it('Check that training package encompasses inference', async function () { - ort.env.wasm.numThreads = 1; - await testInferenceFunction(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, all options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionAll(ort, { executionProviders: ['wasm'] }); - }); - - it('Check training functionality, minimum options', async function () { - ort.env.wasm.numThreads = 1; - await testTrainingFunctionMin(ort, { executionProviders: ['wasm'] }); - }); -}); diff --git a/js/web/test/training/e2e/common.js b/js/web/test/training/e2e/common.js deleted file mode 100644 index 0574ae85aabd1..0000000000000 --- a/js/web/test/training/e2e/common.js +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const DATA_FOLDER = 'data/'; -const TRAININGDATA_TRAIN_MODEL = DATA_FOLDER + 'training_model.onnx'; -const TRAININGDATA_OPTIMIZER_MODEL = DATA_FOLDER + 'adamw.onnx'; -const TRAININGDATA_EVAL_MODEL = DATA_FOLDER + 'eval_model.onnx'; -const TRAININGDATA_CKPT = DATA_FOLDER + 'checkpoint.ckpt'; - -const trainingSessionAllOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, - evalModel: TRAININGDATA_EVAL_MODEL, - optimizerModel: TRAININGDATA_OPTIMIZER_MODEL, -}; - -const trainingSessionMinOptions = { - checkpointState: TRAININGDATA_CKPT, - trainModel: TRAININGDATA_TRAIN_MODEL, -}; - -// ASSERT METHODS - -function assert(cond) { - if (!cond) throw new Error(); -} - -function assertStrictEquals(actual, expected) { - if (actual !== expected) { - let strRep = actual; - if (typeof actual === 'object') { - strRep = JSON.stringify(actual); - } - throw new Error(`expected: ${expected}; got: ${strRep}`); - } -} - -function assertTwoListsUnequal(list1, list2) { - if (list1.length !== list2.length) { - return; - } - for (let i = 0; i < list1.length; i++) { - if (list1[i] !== list2[i]) { - return; - } - } - throw new Error(`expected ${list1} and ${list2} to be unequal; got two equal lists`); -} - -// HELPER METHODS FOR TESTS - -function generateGaussianRandom(mean = 0, scale = 1) { - const u = 1 - Math.random(); - const v = Math.random(); - const z = Math.sqrt(-2.0 * Math.log(u)) * Math.cos(2.0 * Math.PI * v); - return z * scale + mean; -} - -function generateGaussianFloatArray(length) { - const array = new Float32Array(length); - - for (let i = 0; i < length; i++) { - array[i] = generateGaussianRandom(); - } - - return array; -} - -/** - * creates the TrainingSession and verifies that the input and output names of the training model loaded into the - * training session are correct. 
- * @param {} ort - * @param {*} createOptions - * @param {*} options - * @returns - */ -async function createTrainingSessionAndCheckTrainingModel(ort, createOptions, options) { - const trainingSession = await ort.TrainingSession.create(createOptions, options); - - assertStrictEquals(trainingSession.trainingInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.trainingInputNames[1], 'labels'); - assertStrictEquals(trainingSession.trainingInputNames.length, 2); - assertStrictEquals(trainingSession.trainingOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.trainingOutputNames.length, 1); - return trainingSession; -} - -/** - * verifies that the eval input and output names associated with the eval model loaded into the given training session - * are correct. - */ -function checkEvalModel(trainingSession) { - assertStrictEquals(trainingSession.evalInputNames[0], 'input-0'); - assertStrictEquals(trainingSession.evalInputNames[1], 'labels'); - assertStrictEquals(trainingSession.evalInputNames.length, 2); - assertStrictEquals(trainingSession.evalOutputNames[0], 'onnx::loss::21273'); - assertStrictEquals(trainingSession.evalOutputNames.length, 1); -} - -/** - * Checks that accessing trainingSession.evalInputNames or trainingSession.evalOutputNames will throw an error if - * accessed - * @param {} trainingSession - */ -function checkNoEvalModel(trainingSession) { - try { - assertStrictEquals(trainingSession.evalInputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } - try { - assertStrictEquals(trainingSession.evalOutputNames, 'should have thrown an error upon accessing'); - } catch (error) { - assertStrictEquals(error.message, 'This training session has no evalModel loaded.'); - } -} - -/** - * runs the train step with the given inputs and checks that the tensor returned is of type float32 and has a length - * of 1 for the loss. 
- * @param {} trainingSession - * @param {*} feeds - * @returns - */ -var runTrainStepAndCheck = async function (trainingSession, feeds) { - const results = await trainingSession.runTrainStep(feeds); - assertStrictEquals(Object.keys(results).length, 1); - assertStrictEquals(results['onnx::loss::21273'].data.length, 1); - assertStrictEquals(results['onnx::loss::21273'].type, 'float32'); - return results; -}; - -var loadParametersBufferAndCheck = async function (trainingSession, paramsLength, constant, paramsBefore) { - // make a float32 array that is filled with the constant - const newParams = new Float32Array(paramsLength); - for (let i = 0; i < paramsLength; i++) { - newParams[i] = constant; - } - - const newParamsUint8 = new Uint8Array(newParams.buffer, newParams.byteOffset, newParams.byteLength); - - await trainingSession.loadParametersBuffer(newParamsUint8); - const paramsAfterLoad = await trainingSession.getContiguousParameters(); - - // check that the parameters have changed - assertTwoListsUnequal(paramsAfterLoad.data, paramsBefore.data); - assertStrictEquals(paramsAfterLoad.dims[0], paramsLength); - - // check that the parameters have changed to what they should be - for (let i = 0; i < paramsLength; i++) { - // round to the same number of digits (4 decimal places) - assertStrictEquals(paramsAfterLoad.data[i].toFixed(4), constant.toFixed(4)); - } - - return paramsAfterLoad; -}; - -// TESTS - -var testInferenceFunction = async function (ort, options) { - const session = await ort.InferenceSession.create('data/model.onnx', options || {}); - - const dataA = Float32Array.from([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); - const dataB = Float32Array.from([10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120]); - - const fetches = await session.run({ - a: new ort.Tensor('float32', dataA, [3, 4]), - b: new ort.Tensor('float32', dataB, [4, 3]), - }); - - const c = fetches.c; - - assert(c instanceof ort.Tensor); - assert(c.dims.length === 2 && c.dims[0] === 3 && c.dims[1] === 3); - assert(c.data[0] === 700); - assert(c.data[1] === 800); - assert(c.data[2] === 900); - assert(c.data[3] === 1580); - assert(c.data[4] === 1840); - assert(c.data[5] === 2100); - assert(c.data[6] === 2460); - assert(c.data[7] === 2880); - assert(c.data[8] === 3300); -}; - -var testTrainingFunctionMin = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionMinOptions, options); - checkNoEvalModel(trainingSession); - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - const feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - await runTrainStepAndCheck(trainingSession, feeds); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, originalParams); -}; - -var testTrainingFunctionAll = async function (ort, options) { - const trainingSession = await createTrainingSessionAndCheckTrainingModel(ort, trainingSessionAllOptions, options); - 
checkEvalModel(trainingSession); - - const input0 = new ort.Tensor('float32', generateGaussianFloatArray(2 * 784), [2, 784]); - const labels = new ort.Tensor('int32', [2, 1], [2]); - let feeds = { 'input-0': input0, labels: labels }; - - // check getParametersSize - const paramsSize = await trainingSession.getParametersSize(); - assertStrictEquals(paramsSize, 397510); - - // check getContiguousParameters - const originalParams = await trainingSession.getContiguousParameters(); - assertStrictEquals(originalParams.dims.length, 1); - assertStrictEquals(originalParams.dims[0], 397510); - assertStrictEquals(originalParams.data[0], -0.025190064683556557); - assertStrictEquals(originalParams.data[2000], -0.034044936299324036); - - const results = await runTrainStepAndCheck(trainingSession, feeds); - - await trainingSession.runOptimizerStep(feeds); - feeds = { 'input-0': input0, labels: labels }; - // check getContiguousParameters after optimizerStep -- that the parameters have been updated - const optimizedParams = await trainingSession.getContiguousParameters(); - assertTwoListsUnequal(originalParams.data, optimizedParams.data); - - const results2 = await runTrainStepAndCheck(trainingSession, feeds); - - // check that loss decreased after optimizer step and training again - assert(results2['onnx::loss::21273'].data < results['onnx::loss::21273'].data); - - await loadParametersBufferAndCheck(trainingSession, 397510, -1.2, optimizedParams); -}; - -if (typeof module === 'object') { - module.exports = [testInferenceFunction, testTrainingFunctionMin, testTrainingFunctionAll, testTest]; -} diff --git a/js/web/test/training/e2e/data/model.onnx b/js/web/test/training/e2e/data/model.onnx deleted file mode 100644 index 088124bd48624..0000000000000 --- a/js/web/test/training/e2e/data/model.onnx +++ /dev/null @@ -1,16 +0,0 @@ - backend-test:b - -a -bc"MatMultest_matmul_2dZ -a -  - -Z -b -  - -b -c -  - -B \ No newline at end of file diff --git a/js/web/test/training/e2e/karma.conf.js b/js/web/test/training/e2e/karma.conf.js deleted file mode 100644 index 74662b67676f7..0000000000000 --- a/js/web/test/training/e2e/karma.conf.js +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const args = require('minimist')(process.argv.slice(2)); -const SELF_HOST = !!args['self-host']; -const ORT_MAIN = args['ort-main']; -const TEST_MAIN = args['test-main']; -if (typeof TEST_MAIN !== 'string') { - throw new Error('flag --test-main= is required'); -} -const USER_DATA = args['user-data']; -if (typeof USER_DATA !== 'string') { - throw new Error('flag --user-data= is required'); -} - -module.exports = function (config) { - const distPrefix = SELF_HOST ? 
'./node_modules/onnxruntime-web/dist/' : 'http://localhost:8081/dist/'; - config.set({ - frameworks: ['mocha'], - files: [ - { pattern: distPrefix + ORT_MAIN }, - { pattern: './common.js' }, - { pattern: TEST_MAIN }, - { pattern: './node_modules/onnxruntime-web/dist/*.*', included: false, nocache: true }, - { pattern: './data/*', included: false }, - ], - plugins: [require('@chiragrupani/karma-chromium-edge-launcher'), ...config.plugins], - proxies: { - '/model.onnx': '/base/model.onnx', - '/data/': '/base/data/', - }, - client: { captureConsole: true, mocha: { expose: ['body'], timeout: 60000 } }, - reporters: ['mocha'], - captureTimeout: 120000, - reportSlowerThan: 100, - browserDisconnectTimeout: 600000, - browserNoActivityTimeout: 300000, - browserDisconnectTolerance: 0, - browserSocketTimeout: 60000, - hostname: 'localhost', - browsers: [], - customLaunchers: { - Chrome_default: { base: 'ChromeHeadless', chromeDataDir: USER_DATA }, - Chrome_no_threads: { - base: 'ChromeHeadless', - chromeDataDir: USER_DATA, - // TODO: no-thread flags - }, - Edge_default: { base: 'Edge', edgeDataDir: USER_DATA }, - }, - }); -}; diff --git a/js/web/test/training/e2e/package.json b/js/web/test/training/e2e/package.json deleted file mode 100644 index 5f11a27de6dfc..0000000000000 --- a/js/web/test/training/e2e/package.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "devDependencies": { - "@chiragrupani/karma-chromium-edge-launcher": "^2.2.2", - "fs-extra": "^11.1.0", - "globby": "^13.1.3", - "karma": "^6.4.1", - "karma-chrome-launcher": "^3.1.1", - "karma-mocha": "^2.0.1", - "karma-mocha-reporter": "^2.2.5", - "light-server": "^2.9.1", - "minimist": "^1.2.7", - "mocha": "^10.2.0" - } -} diff --git a/js/web/test/training/e2e/run.js b/js/web/test/training/e2e/run.js deleted file mode 100644 index d12bcc7aa66ed..0000000000000 --- a/js/web/test/training/e2e/run.js +++ /dev/null @@ -1,143 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -const path = require('path'); -const fs = require('fs-extra'); -const { spawn } = require('child_process'); -const startServer = require('./simple-http-server'); -const minimist = require('minimist'); - -// copy whole folder to out-side of /js/ because we need to test in a folder that no `package.json` file -// exists in its parent folder. 
-// here we use /build/js/e2e-training/ for the test - -const TEST_E2E_SRC_FOLDER = __dirname; -const JS_ROOT_FOLDER = path.resolve(__dirname, '../../../..'); -const TEST_E2E_RUN_FOLDER = path.resolve(JS_ROOT_FOLDER, '../build/js/e2e-training'); -const NPM_CACHE_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../npm_cache'); -const CHROME_USER_DATA_FOLDER = path.resolve(TEST_E2E_RUN_FOLDER, '../user_data'); -fs.emptyDirSync(TEST_E2E_RUN_FOLDER); -fs.emptyDirSync(NPM_CACHE_FOLDER); -fs.emptyDirSync(CHROME_USER_DATA_FOLDER); -fs.copySync(TEST_E2E_SRC_FOLDER, TEST_E2E_RUN_FOLDER); - -// training data to copy -const ORT_ROOT_FOLDER = path.resolve(JS_ROOT_FOLDER, '..'); -const TRAINING_DATA_FOLDER = path.resolve(ORT_ROOT_FOLDER, 'onnxruntime/test/testdata/training_api'); -const TRAININGDATA_DEST = path.resolve(TEST_E2E_RUN_FOLDER, 'data'); - -// always use a new folder as user-data-dir -let nextUserDataDirId = 0; -function getNextUserDataDir() { - const dir = path.resolve(CHROME_USER_DATA_FOLDER, nextUserDataDirId.toString()); - nextUserDataDirId++; - fs.emptyDirSync(dir); - return dir; -} - -// commandline arguments -const BROWSER = minimist(process.argv.slice(2)).browser || 'Chrome_default'; - -async function main() { - // find packed package - const { globbySync } = await import('globby'); - - const ORT_COMMON_FOLDER = path.resolve(JS_ROOT_FOLDER, 'common'); - const ORT_COMMON_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-common-*.tgz', { cwd: ORT_COMMON_FOLDER }); - - const PACKAGES_TO_INSTALL = []; - - if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length === 1) { - PACKAGES_TO_INSTALL.push(path.resolve(ORT_COMMON_FOLDER, ORT_COMMON_PACKED_FILEPATH_CANDIDATES[0])); - } else if (ORT_COMMON_PACKED_FILEPATH_CANDIDATES.length > 1) { - throw new Error('multiple packages found for onnxruntime-common.'); - } - - const ORT_WEB_FOLDER = path.resolve(JS_ROOT_FOLDER, 'web'); - const ORT_WEB_PACKED_FILEPATH_CANDIDATES = globbySync('onnxruntime-web-*.tgz', { cwd: ORT_WEB_FOLDER }); - if (ORT_WEB_PACKED_FILEPATH_CANDIDATES.length !== 1) { - throw new Error('cannot find exactly single package for onnxruntime-web.'); - } - PACKAGES_TO_INSTALL.push(path.resolve(ORT_WEB_FOLDER, ORT_WEB_PACKED_FILEPATH_CANDIDATES[0])); - - // we start here: - - // install dev dependencies - await runInShell(`npm install`); - - // npm install with "--cache" to install packed packages with an empty cache folder - await runInShell(`npm install --cache "${NPM_CACHE_FOLDER}" ${PACKAGES_TO_INSTALL.map((i) => `"${i}"`).join(' ')}`); - - // prepare training data - prepareTrainingDataByCopying(); - - console.log('==============================================================='); - console.log('Running self-hosted tests'); - console.log('==============================================================='); - // test cases with self-host (ort hosted in same origin) - await testAllBrowserCases({ hostInKarma: true }); - - console.log('==============================================================='); - console.log('Running not self-hosted tests'); - console.log('==============================================================='); - // test cases without self-host (ort hosted in cross origin) - const server = startServer(path.join(TEST_E2E_RUN_FOLDER, 'node_modules', 'onnxruntime-web'), 8081); - try { - await testAllBrowserCases({ hostInKarma: false }); - } finally { - // close the server after all tests - await server.close(); - } -} - -async function testAllBrowserCases({ hostInKarma }) { - await runKarma({ hostInKarma, main: 
'./browser-test-wasm.js' }); -} - -async function runKarma({ hostInKarma, main, browser = BROWSER, ortMain = 'ort.training.wasm.min.js' }) { - console.log('==============================================================='); - console.log(`Running karma with the following binary: ${ortMain}`); - console.log('==============================================================='); - const selfHostFlag = hostInKarma ? '--self-host' : ''; - await runInShell( - `npx karma start --single-run --browsers ${browser} ${selfHostFlag} --ort-main=${ - ortMain - } --test-main=${main} --user-data=${getNextUserDataDir()}`, - ); -} - -async function runInShell(cmd) { - console.log('==============================================================='); - console.log(' Running command in shell:'); - console.log(' > ' + cmd); - console.log('==============================================================='); - let complete = false; - const childProcess = spawn(cmd, { shell: true, stdio: 'inherit', cwd: TEST_E2E_RUN_FOLDER }); - childProcess.on('close', function (code) { - if (code !== 0) { - process.exit(code); - } else { - complete = true; - } - }); - while (!complete) { - await delay(100); - } -} - -async function delay(ms) { - return new Promise(function (resolve) { - setTimeout(function () { - resolve(); - }, ms); - }); -} - -function prepareTrainingDataByCopying() { - fs.copySync(TRAINING_DATA_FOLDER, TRAININGDATA_DEST); - console.log(`Copied ${TRAINING_DATA_FOLDER} to ${TRAININGDATA_DEST}`); -} - -main(); diff --git a/js/web/test/training/e2e/simple-http-server.js b/js/web/test/training/e2e/simple-http-server.js deleted file mode 100644 index ef9cced681cc8..0000000000000 --- a/js/web/test/training/e2e/simple-http-server.js +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -'use strict'; - -// this is a simple HTTP server that enables CORS. 
-// following code is based on https://developer.mozilla.org/en-US/docs/Learn/Server-side/Node_server_without_framework - -const http = require('http'); -const fs = require('fs'); -const path = require('path'); - -const getRequestData = (url, dir) => { - const pathname = new URL(url, 'http://localhost').pathname; - - let filepath; - let mimeType; - if (pathname.startsWith('/test-wasm-path-override/') || pathname.startsWith('/dist/')) { - filepath = path.resolve(dir, pathname.substring(1)); - } else { - return null; - } - - if (filepath.endsWith('.wasm')) { - mimeType = 'application/wasm'; - } else if (filepath.endsWith('.js') || filepath.endsWith('.mjs')) { - mimeType = 'text/javascript'; - } else { - return null; - } - - return [filepath, mimeType]; -}; - -module.exports = function (dir, port) { - const server = http - .createServer(function (request, response) { - const url = request.url.replace(/\n|\r/g, ''); - console.log(`request ${url}`); - - const requestData = getRequestData(url, dir); - if (!request || !requestData) { - response.writeHead(404); - response.end('404'); - } else { - const [filePath, contentType] = requestData; - fs.readFile(path.resolve(dir, filePath), function (error, content) { - if (error) { - if (error.code == 'ENOENT') { - response.writeHead(404); - response.end('404'); - } else { - response.writeHead(500); - response.end('500'); - } - } else { - response.setHeader('access-control-allow-origin', '*'); - response.writeHead(200, { 'Content-Type': contentType }); - response.end(content, 'utf-8'); - } - }); - } - }) - .listen(port); - console.log(`Server running at http://localhost:${port}/`); - return server; -}; diff --git a/js/web/types.d.ts b/js/web/types.d.ts index 735b6a89a2a86..b82248c0c83b8 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -20,7 +20,3 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } - -declare module 'onnxruntime-web/training' { - export * from 'onnxruntime-web'; -}