From 1c2dca95d813e3bf7a2b59a70fcedae9c84bed7d Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Wed, 3 Jan 2024 04:38:33 +0000
Subject: [PATCH 001/677] pass rotary embedding to attention op (#18846)
### Description
### Motivation and Context
---
docs/ContribOperators.md | 2 ++
.../contrib_ops/cpu/bert/attention_base.cc | 1 +
.../contrib_ops/cpu/bert/attention_base.h | 2 ++
.../contrib_ops/cpu/bert/attention_common.h | 1 +
.../cuda/bert/add_bias_transpose.cu | 19 +++++-----
.../cuda/bert/add_bias_transpose.h | 2 +-
.../cuda/bert/attention_prepare_qkv.cu | 3 +-
.../core/graph/contrib_ops/bert_defs.cc | 4 +++
.../test_parity_neox_attention.py | 36 +++++++++++--------
9 files changed, 45 insertions(+), 25 deletions(-)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 131db5d8d9b37..38fceef67de25 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+rotary_embedding_dim : int
+Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index 5d224bdc2235f..515a967aa2386 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
+ output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index 5ee40c4b98664..a6782daa58f1a 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -38,6 +38,7 @@ class AttentionBase {
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+ rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
@@ -72,6 +73,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
bool do_rotary_; // whether or not to use rotary embeddings
+ int rotary_embedding_; // rotary embedding dimension
float mask_filter_value_; // the value to be used for filtered out positions
float scale_; // the scale to be used for softmax
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index a7f83469a768d..c9ed23895b60c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -56,6 +56,7 @@ struct AttentionParameters {
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
+ int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index 626e4c0b87a3c..1ea2540db486f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -640,7 +640,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count,
- bool do_rotary = false, int past_sequence_length = 0) {
+ bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) {
assert(num_heads <= max_threads_per_block);
if (do_rotary) {
@@ -650,20 +650,20 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
- if (qk_head_size != 64 && qk_head_size != 128) {
- ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
+ if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
+ ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention");
}
const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length;
- size_t smem_size = 2 * qk_head_size * sizeof(T);
+ size_t smem_size = 2 * rotary_embedding * sizeof(T);
const dim3 grid(sequence_length, num_heads, batch_size);
const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1);
AddBiasTransposeQKV<T><<<grid, block, smem_size, stream>>>(total_matrix_count, input, biases, output,
- qkv_add_bias, qk_head_size, qk_head_size,
+ qkv_add_bias, rotary_embedding, qk_head_size,
step, format);
#else
ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__);
@@ -727,7 +727,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size,
- half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) {
+ half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -753,7 +753,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
- qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+ qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length);
}
}
@@ -763,7 +763,7 @@ void LaunchAddBiasTranspose(
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output, bool /*enable_half4*/,
const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary,
- int past_sequence_length) {
+ int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -789,7 +789,8 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
- qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+ qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding,
+ past_sequence_length);
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
index d903267c99a01..efc31db43bcdb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -33,7 +33,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
- int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0);
+ int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);
// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// For self attention:
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 5c65a30918ece..a513d9e8d2211 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
- 3, parameters.do_rotary, parameters.past_sequence_length);
+ 3, parameters.do_rotary, parameters.rotary_embedding,
+ parameters.past_sequence_length);
}
return Status::OK();
}
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index ea67218b5c927..f8f63650615fd 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -333,6 +333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
+ .Attr("rotary_embedding_dim",
+ "Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
.Attr("mask_filter_value",
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
index 8c8e871a854b0..a98bb623beaea 100644
--- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py
+++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
qkv_weight,
qkv_bias,
num_heads,
+ rotary_embedding,
):
nodes = [
helper.make_node(
@@ -43,6 +44,7 @@ def create_neox_attention_graph(
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
+ rotary_embedding=rotary_embedding,
domain="com.microsoft",
),
]
@@ -174,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
class GPTNeoXAttention(nn.Module):
- def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
+ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
super().__init__()
self.do_rotary = True
self.num_attention_heads = num_head
self.hidden_size = hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
- self.rotary_ndims = int(self.head_size)
+ self.rotary_ndims = rotary_ndims
max_positions = 2048
self.register_buffer(
"bias",
@@ -197,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
# self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))
if past_seq_len > 0:
+ assert self.rotary_ndims == self.head_size
self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
batch_size,
seq_len,
@@ -220,6 +223,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
.transpose(0, 1),
self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
self.num_attention_heads,
+ self.rotary_ndims,
)
@classmethod
@@ -422,17 +426,21 @@ def test_gpt_neox_attention(self):
for batch_size in [1, 2, 4, 8]:
for seq_len in [32, 128, 512, 1024, 2048]:
for num_head in [12]:
- for hidden_size in [768]:
- attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)
-
- hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
- torch.float32
- )
-
- torch_output = attn.torch_forward(hidden_states)
- ort_output = attn.onnx_forward(hidden_states)
- if ort_output is not None:
- assert torch.allclose(torch_output, ort_output, atol=1e-4)
+ for rotary_ndims in [32, 64]:
+ for hidden_size in [768, 960]:
+ attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)
+
+ hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
+ torch.float32
+ )
+
+ torch_output = attn.torch_forward(hidden_states)
+ ort_output = attn.onnx_forward(hidden_states)
+ if ort_output is not None:
+ assert torch.allclose(torch_output, ort_output, atol=1e-3)
+ print(
+ f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
+ )
def test_gpt_neox_decoder_masked_self_attention(self):
for batch_size in [1, 2, 4, 8]:
@@ -466,7 +474,7 @@ def test_gpt_neox_decoder_masked_self_attention(self):
hidden_states, attention_mask=attention_mask, layer_past=layer_past
)
if ort_output is not None:
- assert torch.allclose(torch_output, ort_output, atol=1e-4)
+ assert torch.allclose(torch_output, ort_output, atol=1e-3)
if __name__ == "__main__":
From c97e3f48216d66dfbc6aa951ddcb7f32e313d314 Mon Sep 17 00:00:00 2001
From: Yi Zhang
Date: Wed, 3 Jan 2024 14:53:31 +0800
Subject: [PATCH 002/677] [Fix] exception in Fuzz Test pipeline (#18984)
### Description
### Motivation and Context
The file path is not correct.
---
.../github/azure-pipelines/win-ci-fuzz-testing.yml | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml
index 98f1bf7ea1a16..b8f9566274acc 100644
--- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml
+++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml
@@ -20,7 +20,11 @@ jobs:
workspace:
clean: all
steps:
- - template: win-ci-prebuild-steps.yml
+ - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
+ displayName: 'Clean Agent Directories'
+ condition: always()
+
+ - template: templates/jobs/win-ci-prebuild-steps.yml
parameters:
EnvSetupScript: $(EnvSetupScript)
DownloadCUDA: false
@@ -69,7 +73,3 @@ jobs:
script: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)\onnxruntime_security_fuzz.exe /t /f "$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)\testdata\mnist.onnx" 1 m'
workingDirectory: $(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)
failOnStderr: false # Optional
-
- - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3
- displayName: 'Clean Agent Directories'
- condition: always()
From 7a454acd6197f4ba1ffca13ec9948915ce82d20e Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Wed, 3 Jan 2024 17:25:15 +0800
Subject: [PATCH 003/677] [ROCm] Update CI/Packaging pipeline to ROCm6.0
(#18985)
Update CI/Packaging pipeline to ROCm6.0
---
...mi200.huggingface.bert-large-rocm6.0.json} | 28 +++++++++----------
.../linux-migraphx-ci-pipeline.yml | 2 +-
.../orttraining-pai-ci-pipeline.yml | 2 +-
...orttraining-py-packaging-pipeline-rocm.yml | 24 ++++++++--------
.../docker/Dockerfile.manylinux2_28_rocm | 2 +-
.../migraphx-ci-pipeline-env.Dockerfile | 2 +-
.../docker/scripts/setup_rocm_yum_repo.sh | 6 ++--
.../pai/rocm-ci-pipeline-env.Dockerfile | 4 +--
8 files changed, 35 insertions(+), 35 deletions(-)
rename orttraining/tools/ci_test/results/{ci-mi200.huggingface.bert-large-rocm5.7.json => ci-mi200.huggingface.bert-large-rocm6.0.json} (61%)
diff --git a/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm5.7.json b/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.0.json
similarity index 61%
rename from orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm5.7.json
rename to orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.0.json
index a4ac02b566848..05fcf08cd3232 100644
--- a/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm5.7.json
+++ b/orttraining/tools/ci_test/results/ci-mi200.huggingface.bert-large-rocm6.0.json
@@ -2,56 +2,56 @@
"steps": [
{
"step": 20,
- "loss": 2.0017
+ "loss": 2.0136
},
{
"step": 40,
- "loss": 1.8337
+ "loss": 1.8466
},
{
"step": 60,
- "loss": 1.7538
+ "loss": 1.7525
},
{
"step": 80,
- "loss": 1.6728
+ "loss": 1.6682
},
{
"step": 100,
- "loss": 1.6656
+ "loss": 1.658
},
{
"step": 120,
- "loss": 1.6752
+ "loss": 1.6749
},
{
"step": 140,
- "loss": 1.6335
+ "loss": 1.6263
},
{
"step": 160,
- "loss": 1.6815
+ "loss": 1.6828
},
{
"step": 180,
- "loss": 1.6155
+ "loss": 1.6145
},
{
"step": 200,
- "loss": 1.6177
+ "loss": 1.6197
},
{
"step": 220,
- "loss": 1.632
+ "loss": 1.6353
},
{
"step": 240,
- "loss": 1.5161
+ "loss": 1.5266
},
{
"step": 260,
- "loss": 1.5433
+ "loss": 1.5441
}
],
- "samples_per_second": 32.335
+ "samples_per_second": 34.561
}
diff --git a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
index 5dac8fc9cda63..f7571a3b7eab6 100644
--- a/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/linux-migraphx-ci-pipeline.yml
@@ -36,7 +36,7 @@ variables:
- name: render
value: 109
- name: RocmVersion
- value: 5.7
+ value: 6.0
jobs:
- job: Linux_Build
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml
index 8d02a5e5809a2..a53f91fb317cb 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml
@@ -25,7 +25,7 @@ variables:
- name: render
value: 109
- name: RocmVersion
- value: 5.7
+ value: 6.0
- name: BuildConfig
value: Release
diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
index f2ba99369c144..bbdbe0fd8e376 100644
--- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
+++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml
@@ -9,51 +9,51 @@ resources:
ref: 5eda9aded5462201e6310105728d33016e637ea7
stages:
-- stage: "Python_Packaging_ROCm57_Release"
+- stage: "Python_Packaging_ROCm60_Release"
jobs:
- template: templates/rocm.yml
parameters:
PythonVersion: '3.8'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.9'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.10'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
-- stage: "Python_Packaging_ROCm57_Debug"
+- stage: "Python_Packaging_ROCm60_Debug"
jobs:
- template: templates/rocm.yml
parameters:
PythonVersion: '3.8'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
BuildConfig: 'Debug'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.9'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
BuildConfig: 'Debug'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.10'
- RocmVersion: '5.7'
+ RocmVersion: '6.0'
BuildConfig: 'Debug'
-- stage: "Python_Packaging_ROCm56_Release"
+- stage: "Python_Packaging_ROCm57_Release"
condition: ne(variables['ORT_DISABLE_PYTHON_PACKAGE_LOCAL_VERSION'], 'true')
jobs:
- template: templates/rocm.yml
parameters:
PythonVersion: '3.8'
- RocmVersion: '5.6'
+ RocmVersion: '5.7'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.9'
- RocmVersion: '5.6'
+ RocmVersion: '5.7'
- template: templates/rocm.yml
parameters:
PythonVersion: '3.10'
- RocmVersion: '5.6'
+ RocmVersion: '5.7'
diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
index 9e12fe8c75451..b9fd88083f218 100644
--- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
+++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm
@@ -31,7 +31,7 @@ RUN yum install -y hipify-clang
RUN yum -y install wget
# rocm lib
-RUN yum install -y miopen-hip-devel rocblas-devel rocrand-devel rccl-devel hipsparse-devel hipfft-devel hipcub-devel hipblas-devel rocthrust-devel migraphx-devel
+RUN yum install -y migraphx-devel
ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM}
ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8
diff --git a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
index d02e7d8b91d11..85d738d2167e1 100644
--- a/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
+++ b/tools/ci_build/github/linux/docker/migraphx-ci-pipeline-env.Dockerfile
@@ -1,7 +1,7 @@
# Refer to https://github.com/RadeonOpenCompute/ROCm-docker/blob/master/dev/Dockerfile-ubuntu-22.04-complete
FROM ubuntu:22.04
-ARG ROCM_VERSION=5.7
+ARG ROCM_VERSION=6.0
ARG AMDGPU_VERSION=${ROCM_VERSION}
ARG APT_PREF='Package: *\nPin: release o=repo.radeon.com\nPin-Priority: 600'
diff --git a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
index fcd9086061227..269337bbba042 100755
--- a/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
+++ b/tools/ci_build/github/linux/docker/scripts/setup_rocm_yum_repo.sh
@@ -2,7 +2,7 @@
set -e -x
# version
-ROCM_VERSION=5.6
+ROCM_VERSION=6.0
while getopts "r:" parameter_Option
do case "${parameter_Option}"
@@ -14,7 +14,7 @@ done
tee /etc/yum.repos.d/amdgpu.repo <
Date: Thu, 4 Jan 2024 02:13:17 +0800
Subject: [PATCH 004/677] [js/webgpu] Introduce trace support (#18928)
This change leverages `console.timeStamp` to add single markers to the
browser's performance tool (only Chromium and Firefox support it). With
this support, we can dump both CPU and GPU timestamps and use a
post-processing tool to clearly understand the calibrated timeline. A
demo tool can be found at https://github.com/webatintel/ort-test, and
more detailed info can be found at
https://docs.google.com/document/d/1TuVxjE8jnELBXdhI4QGFgMnUqQn6Q53QA9y4a_dH688/edit.
---
js/common/lib/env.ts | 7 +++
js/common/lib/index.ts | 1 +
js/common/lib/inference-session-impl.ts | 5 +++
js/common/lib/trace.ts | 44 +++++++++++++++++++
js/web/lib/backend-wasm.ts | 4 ++
js/web/lib/wasm/jsep/backend-webgpu.ts | 4 +-
.../lib/wasm/jsep/webgpu/program-manager.ts | 6 +++
js/web/lib/wasm/session-handler-inference.ts | 6 ++-
8 files changed, 75 insertions(+), 2 deletions(-)
create mode 100644 js/common/lib/trace.ts
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index 0cded7e5edbcb..b007b5e164bf3 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -33,6 +33,13 @@ export declare namespace Env {
*/
simd?: boolean;
+ /**
+ * Set or get a boolean value indicating whether to enable tracing.
+ *
+ * @defaultValue `false`
+ */
+ trace?: boolean;
+
/**
* Set or get a number specifying the timeout for initialization of WebAssembly backend, in milliseconds. A zero
* value indicates no timeout is set.
diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts
index 9cbfcc4e8bcdc..d7c98380f3fa4 100644
--- a/js/common/lib/index.ts
+++ b/js/common/lib/index.ts
@@ -21,5 +21,6 @@ export * from './backend.js';
export * from './env.js';
export * from './inference-session.js';
export * from './tensor.js';
+export * from './trace.js';
export * from './onnx-value.js';
export * from './training-session.js';
diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts
index 9bc2088f2088a..55f40c8907a89 100644
--- a/js/common/lib/inference-session-impl.ts
+++ b/js/common/lib/inference-session-impl.ts
@@ -6,6 +6,7 @@ import {InferenceSessionHandler} from './backend.js';
import {InferenceSession as InferenceSessionInterface} from './inference-session.js';
import {OnnxValue} from './onnx-value.js';
import {Tensor} from './tensor.js';
+import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from './trace.js';
type SessionOptions = InferenceSessionInterface.SessionOptions;
type RunOptions = InferenceSessionInterface.RunOptions;
@@ -20,6 +21,7 @@ export class InferenceSession implements InferenceSessionInterface {
run(feeds: FeedsType, options?: RunOptions): Promise;
run(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise;
async run(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise {
+ TRACE_FUNC_BEGIN();
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
@@ -117,6 +119,7 @@ export class InferenceSession implements InferenceSessionInterface {
}
}
}
+ TRACE_FUNC_END();
return returnValue;
}
@@ -132,6 +135,7 @@ export class InferenceSession implements InferenceSessionInterface {
static async create(
arg0: string|ArrayBufferLike|Uint8Array, arg1?: SessionOptions|number, arg2?: number,
arg3?: SessionOptions): Promise {
+ TRACE_FUNC_BEGIN();
// either load from a file or buffer
let filePathOrUint8Array: string|Uint8Array;
let options: SessionOptions = {};
@@ -196,6 +200,7 @@ export class InferenceSession implements InferenceSessionInterface {
const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
const backend = await resolveBackend(backendHints);
const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options);
+ TRACE_FUNC_END();
return new InferenceSession(handler);
}
diff --git a/js/common/lib/trace.ts b/js/common/lib/trace.ts
new file mode 100644
index 0000000000000..404f7ef8089af
--- /dev/null
+++ b/js/common/lib/trace.ts
@@ -0,0 +1,44 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {env} from './env-impl.js';
+
+export const TRACE = (deviceType: string, label: string) => {
+ if (!env.wasm.trace) {
+ return;
+ }
+ // eslint-disable-next-line no-console
+ console.timeStamp(`${deviceType}::ORT::${label}`);
+};
+
+const TRACE_FUNC = (msg: string, extraMsg?: string) => {
+ const stack = new Error().stack?.split(/\r\n|\r|\n/g) || [];
+ let hasTraceFunc = false;
+ for (let i = 0; i < stack.length; i++) {
+ if (hasTraceFunc && !stack[i].includes('TRACE_FUNC')) {
+ let label = `FUNC_${msg}::${stack[i].trim().split(' ')[1]}`;
+ if (extraMsg) {
+ label += `::${extraMsg}`;
+ }
+ TRACE('CPU', label);
+ return;
+ }
+ if (stack[i].includes('TRACE_FUNC')) {
+ hasTraceFunc = true;
+ }
+ }
+};
+
+export const TRACE_FUNC_BEGIN = (extraMsg?: string) => {
+ if (!env.wasm.trace) {
+ return;
+ }
+ TRACE_FUNC('BEGIN', extraMsg);
+};
+
+export const TRACE_FUNC_END = (extraMsg?: string) => {
+ if (!env.wasm.trace) {
+ return;
+ }
+ TRACE_FUNC('END', extraMsg);
+};
diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts
index 2d123cdb71290..d9f63fec9c492 100644
--- a/js/web/lib/backend-wasm.ts
+++ b/js/web/lib/backend-wasm.ts
@@ -26,6 +26,10 @@ export const initializeFlags = (): void => {
env.wasm.proxy = false;
}
+ if (typeof env.wasm.trace !== 'boolean') {
+ env.wasm.trace = false;
+ }
+
if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) {
const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency;
env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2));
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 6c3d22352772e..0148f32cdd91b 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {Env, Tensor} from 'onnxruntime-common';
+import {Env, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
import {configureLogger, LOG_DEBUG} from './log';
import {createView, TensorView} from './tensor-view';
@@ -263,6 +263,7 @@ export class WebGpuBackend {
run(program: ProgramInfo, inputTensorViews: readonly TensorView[], outputIndices: readonly number[],
createKernelOutput: (index: number, dataType: number, dims: readonly number[]) => TensorView,
createIntermediateOutput: (dataType: number, dims: readonly number[]) => TensorView): TensorView[] {
+ TRACE_FUNC_BEGIN(program.name);
// create info for inputs
const inputDatas: GpuData[] = [];
for (let i = 0; i < inputTensorViews.length; ++i) {
@@ -387,6 +388,7 @@ export class WebGpuBackend {
artifact, inputTensorViews, outputTensorViews, inputDatas, outputDatas, normalizedDispatchGroup,
uniformBufferBinding);
+ TRACE_FUNC_END(program.name);
return outputTensorViews;
}
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index ae5bf68483b46..0d699326366b3 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import {TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
+
import {tensorDataTypeEnumToString} from '../../wasm-common';
import {WebGpuBackend} from '../backend-webgpu';
import {LOG_DEBUG} from '../log';
@@ -35,6 +37,7 @@ export class ProgramManager {
run(buildArtifact: Artifact, inputTensorViews: readonly TensorView[], outputTensorViews: readonly TensorView[],
inputs: GpuData[], outputs: GpuData[], dispatchGroup: [number, number, number],
uniformBufferBinding: GPUBindingResource|undefined): void {
+ TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
const device = this.backend.device;
const computePassEncoder = this.backend.getComputePassEncoder();
@@ -128,11 +131,13 @@ export class ProgramManager {
if (this.backend.pendingDispatchNumber >= 16) {
this.backend.flush();
}
+ TRACE_FUNC_END(buildArtifact.programInfo.name);
}
dispose(): void {
// this.repo.forEach(a => this.glContext.deleteProgram(a.program));
}
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
+ TRACE_FUNC_BEGIN(programInfo.name);
const device = this.backend.device;
const extensions: string[] = [];
if (device.features.has('shader-f16')) {
@@ -147,6 +152,7 @@ export class ProgramManager {
const computePipeline = device.createComputePipeline(
{compute: {module: shaderModule, entryPoint: 'main'}, layout: 'auto', label: programInfo.name});
+ TRACE_FUNC_END(programInfo.name);
return {programInfo, computePipeline};
}
diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts
index b62287483208a..e17ec37e3e612 100644
--- a/js/web/lib/wasm/session-handler-inference.ts
+++ b/js/web/lib/wasm/session-handler-inference.ts
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
import {readFile} from 'node:fs/promises';
-import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
+import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor, TRACE_FUNC_BEGIN, TRACE_FUNC_END} from 'onnxruntime-common';
import {SerializableInternalBuffer, TensorMetadata} from './proxy-messages';
import {copyFromExternalBuffer, createSession, endProfiling, releaseSession, run} from './proxy-wrapper';
@@ -54,6 +54,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}
async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise {
+ TRACE_FUNC_BEGIN();
let model: Parameters[0];
if (typeof pathOrBuffer === 'string') {
@@ -70,6 +71,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}
[this.sessionId, this.inputNames, this.outputNames] = await createSession(model, options);
+ TRACE_FUNC_END();
}
async dispose(): Promise {
@@ -78,6 +80,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise {
+ TRACE_FUNC_BEGIN();
const inputArray: Tensor[] = [];
const inputIndices: number[] = [];
Object.entries(feeds).forEach(kvp => {
@@ -115,6 +118,7 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
for (let i = 0; i < results.length; i++) {
resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]);
}
+ TRACE_FUNC_END();
return resultMap;
}
From 3b8b9147fa4f8f6348e171a257bbc325744301df Mon Sep 17 00:00:00 2001
From: Jiajie Hu
Date: Thu, 4 Jan 2024 06:15:26 +0800
Subject: [PATCH 005/677] [js/webgpu] Mitigate floating point accuracy issue in
Resize (#18956)
### Description
The patch fixes a floating point accuracy issue in Resize by preferring
integer indices and integer arithmetic where possible.
### Motivation and Context
Model test `test_resize_upsample_sizes_nearest_floor_align_corners` was
observed to be failing on certain platforms. The root cause is the
inaccurate floating point evaluation of 21 / 7 (2.999... vs 3), which
results in the wrong input element being indexed (floor(2.999...) vs
floor(3)).
---
js/web/lib/wasm/jsep/webgpu/ops/resize.ts | 83 ++++++++++++-----------
1 file changed, 45 insertions(+), 38 deletions(-)
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index bea3e8625b41b..d359580904a7b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -110,41 +110,48 @@ const validateInputs =
const getOriginalCoordinateFromResizedCoordinate =
(coordinateTransferMode: CoordinateTransformMode, dType: string): string =>
- `fn getOriginalCoordinateFromResizedCoordinate(xResized: ${dType}, xScale: ${dType}, lengthResized: ${dType},
- lengthOriginal: ${dType}, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
+ `fn getOriginalCoordinateFromResizedCoordinate(xResized: u32, xScale: ${dType}, lengthResized: u32,
+ lengthOriginal: u32, roiStart: ${dType}, roiEnd: ${dType}) -> ${dType} { ` +
(() => {
switch (coordinateTransferMode) {
case 'asymmetric':
- return 'return xResized / xScale;';
+ return `return ${dType}(xResized) / xScale;`;
case 'pytorch_half_pixel':
- return 'if (lengthResized > 1) { \
- return (xResized + 0.5) / xScale - 0.5; \
- } else { \
- return 0.0; \
- }';
+ return `if (lengthResized > 1) {
+ return (${dType}(xResized) + 0.5) / xScale - 0.5;
+ } else {
+ return 0.0;
+ }`;
case 'tf_half_pixel_for_nn':
- return 'return (xResized + 0.5) / xScale;';
+ return `return (${dType}(xResized) + 0.5) / xScale;`;
case 'align_corners':
- return 'if (lengthResized == 1) { \
- return 0.0; \
- } else { \
- return xResized * (lengthOriginal - 1) / (lengthResized - 1); \
- }';
+ return `if (lengthResized == 1) {
+ return 0.0;
+ } else {
+ // The whole part and the fractional part are calculated separately due to inaccuracy of floating
+ // point division. As an example, f32(21) / f32(7) may evaluate to 2.99... instead of 3, causing an
+ // offset-by-one error later in floor().
+ let whole = ${dType}(xResized * (lengthOriginal - 1) / (lengthResized - 1));
+ let fract =
+ ${dType}(xResized * (lengthOriginal - 1) % (lengthResized - 1)) / ${dType}(lengthResized - 1);
+ return whole + fract;
+ }`;
case 'tf_crop_and_resize':
- return `if (lengthResized > 1) { \
- return roiStart * (lengthOriginal - 1) + \
- (xResized * (roiEnd - roiStart) * (lengthOriginal - 1)) / (lengthResized - 1); \
- } else { \
- return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1); \
+ return `if (lengthResized > 1) {
+ return roiStart * ${dType}(lengthOriginal - 1) +
+ (${dType}(xResized) * (roiEnd - roiStart) * ${dType}(lengthOriginal - 1)) /
+ ${dType}(lengthResized - 1);
+ } else {
+ return 0.5 * (roiStart + roiEnd) * ${dType}(lengthOriginal - 1);
}`;
case 'half_pixel_symmetric':
- return [
- 'const outputWidth = xScale * lengthResized;', 'const adjustment = lengthResized / outputWidth;',
- 'const center = lengthOriginal / 2;', 'const offset = center * (1 - adjustment);',
- 'return offset + ((xResized + 0.5) / xScale) - 0.5;'
- ].join('\n');
+ return `const outputWidth = xScale * ${dType}(lengthResized);
+ const adjustment = ${dType}(lengthResized) / outputWidth;
+ const center = ${dType}(lengthOriginal) / 2;
+ const offset = center * (1 - adjustment);
+ return offset + ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
case 'half_pixel':
- return 'return ((xResized + 0.5) / xScale) - 0.5;';
+ return `return ((${dType}(xResized) + 0.5) / xScale) - 0.5;`;
default:
throw new Error(`Coordinate transform mode ${coordinateTransferMode} is not supported`);
}
@@ -254,15 +261,15 @@ const calculateOriginalIndicesFromOutputIndices =
output.type.value}, ${outputShape.length}> {
var original_indices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
- var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
+ var output_index = ${output.indicesGet('output_indices', 'i')};
var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
if (scale == 1.0) {
- original_indices[i] = output_index;
+ original_indices[i] = ${output.type.value}(output_index);
} else {
- var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
- var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
+ var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
+ var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
input_shape_i, roi_low, roi_hi);
}
@@ -276,23 +283,23 @@ const calculateInputIndicesFromOutputIndices =
fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
var input_indices: ${input.type.indices};
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
- var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
+ var output_index = ${output.indicesGet('output_indices', 'i')};
var input_index: u32;
var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
if (scale == 1.0) {
- input_index = u32(output_index);
+ input_index = output_index;
} else {
var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
- var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
- var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
+ var input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
+ var output_shape_i = ${getElementAt('uniforms.output_shape', 'i', outputShape.length)};
var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
input_shape_i, roi_low, roi_hi);
- if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) {
+ if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${output.type.value}(input_shape_i))) {
if (original_idx < 0) {
input_index = 0;
- } else if (original_idx > (input_shape_i - 1)) {
- input_index = u32(input_shape_i) - 1;
+ } else if (original_idx > ${output.type.value}(input_shape_i - 1)) {
+ input_index = input_shape_i - 1;
} else {
input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
}
@@ -391,8 +398,8 @@ const bicubicInterpolation =
fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${
output.type.indices}) -> ${dType} {
var output_index = ${output.indicesGet('output_indices', idx)};
- var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]},
- ${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
+ var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(output_index, ${scales[idx]},
+ ${outputShape[idx]}, ${inputShape[idx]}, ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
From c3d96a7b35c975a4eb4ad5a4c94349797defbc78 Mon Sep 17 00:00:00 2001
From: Jeff Bloomfield <38966965+jeffbloo@users.noreply.github.com>
Date: Tue, 2 Jan 2024 18:06:26 -0800
Subject: [PATCH 006/677] Update DML version to 1.13.0 (#18978)
Update DML nuget version to 1.13.0
---
.pipelines/nuget_config/x64/packages.config | 2 +-
.pipelines/nuget_config/x86/packages.config | 2 +-
cmake/external/dml.cmake | 2 +-
packages.config | 2 +-
tools/nuget/generate_nuspec_for_native_nuget.py | 2 +-
5 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/.pipelines/nuget_config/x64/packages.config b/.pipelines/nuget_config/x64/packages.config
index 2ac650b0e6dc9..2583e0d1b2ead 100644
--- a/.pipelines/nuget_config/x64/packages.config
+++ b/.pipelines/nuget_config/x64/packages.config
@@ -1,6 +1,6 @@
-
+
diff --git a/.pipelines/nuget_config/x86/packages.config b/.pipelines/nuget_config/x86/packages.config
index f80f96194a230..5ca659941c159 100644
--- a/.pipelines/nuget_config/x86/packages.config
+++ b/.pipelines/nuget_config/x86/packages.config
@@ -1,6 +1,6 @@
-
+
diff --git a/cmake/external/dml.cmake b/cmake/external/dml.cmake
index 5d25b9529e030..d777306722cd6 100644
--- a/cmake/external/dml.cmake
+++ b/cmake/external/dml.cmake
@@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
- set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.1)
+ set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0)
# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
diff --git a/packages.config b/packages.config
index da61a10adfa74..b67219d6d6913 100644
--- a/packages.config
+++ b/packages.config
@@ -1,6 +1,6 @@
-
+
diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py
index 66248565a3e3a..56e50750ac153 100644
--- a/tools/nuget/generate_nuspec_for_native_nuget.py
+++ b/tools/nuget/generate_nuspec_for_native_nuget.py
@@ -219,7 +219,7 @@ def add_common_dependencies(xml_text, package_name, version):
def generate_dependencies(xml_text, package_name, version):
- dml_dependency = ''
+ dml_dependency = ''
if package_name == "Microsoft.AI.MachineLearning":
xml_text.append("")
From 9bbe425d7f805d4dfaad2e69c3edca40063fe673 Mon Sep 17 00:00:00 2001
From: Xiang Zhang
Date: Thu, 27 Jul 2023 19:13:15 -0700
Subject: [PATCH 007/677] Register LPpool18 and AvgPool 19 (#16880)
---
.../src/External/DirectMLHelpers/ApiTraits.h | 30 ++++++++++++++
.../External/DirectMLHelpers/DirectMLSchema.h | 40 +++++++++++++++++++
.../DirectMLHelpers/GeneratedSchemaHelpers.h | 39 ++++++++++++++++++
.../src/Operators/DmlOperatorPooling.cpp | 40 +++++++++++++++++--
.../src/Operators/OperatorRegistration.cpp | 2 +
.../OperatorAuthorHelper/OperatorVersions.h | 6 +++
.../test/providers/cpu/nn/pool_op_test.cc | 25 ------------
7 files changed, 153 insertions(+), 29 deletions(-)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index c75b662af788d..94f2220fcc168 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -459,12 +459,24 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING;
};
+template <>
+struct OperatorDescTraits
+{
+ static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_AVERAGE_POOLING1;
+};
+
template <>
struct OperatorDescTraits
{
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING;
};
+template <>
+struct OperatorDescTraits
+{
+ static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_LP_POOLING1;
+};
+
template <>
struct OperatorDescTraits
{
@@ -1448,12 +1460,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING>
using DescType = DML_AVERAGE_POOLING_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_AVERAGE_POOLING1>
+{
+ using DescType = DML_AVERAGE_POOLING1_OPERATOR_DESC;
+};
+
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING>
{
using DescType = DML_LP_POOLING_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_LP_POOLING1>
+{
+ using DescType = DML_LP_POOLING1_OPERATOR_DESC;
+};
+
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MAX_POOLING>
{
@@ -2259,8 +2283,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
return std::invoke(std::forward(visitor), DML_ARGMAX_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_AVERAGE_POOLING:
return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...);
+ case DML_OPERATOR_AVERAGE_POOLING1:
+ return std::invoke(std::forward(visitor), DML_AVERAGE_POOLING1_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_LP_POOLING:
return std::invoke(std::forward(visitor), DML_LP_POOLING_OPERATOR_DESC{}, std::forward(args)...);
+ case DML_OPERATOR_LP_POOLING1:
+ return std::invoke(std::forward(visitor), DML_LP_POOLING1_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_MAX_POOLING:
return std::invoke(std::forward(visitor), DML_MAX_POOLING_OPERATOR_DESC{}, std::forward(args)...);
case DML_OPERATOR_MAX_POOLING1:
@@ -2554,7 +2582,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN";
case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX";
case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING";
+ case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1";
case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING";
+ case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1";
case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING";
case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1";
case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING";
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 1ebd52d4ed427..9eae1c1fe8158 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -757,6 +757,26 @@ constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING_OPERATOR_SCHEMA {
DML_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
};
+constexpr DML_SCHEMA_FIELD DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS[9] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_AVERAGE_POOLING1_OPERATOR_SCHEMA {
+ "DML_OPERATOR_AVERAGE_POOLING1",
+ DML_OPERATOR_AVERAGE_POOLING1,
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 9,
+ DML_AVERAGE_POOLING1_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS[8] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
@@ -776,6 +796,26 @@ constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING_OPERATOR_SCHEMA {
DML_LP_POOLING_OPERATOR_SCHEMA_FIELDS,
};
+constexpr DML_SCHEMA_FIELD DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS[9] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "P", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_LP_POOLING1_OPERATOR_SCHEMA {
+ "DML_OPERATOR_LP_POOLING1",
+ DML_OPERATOR_LP_POOLING1,
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 9,
+ DML_LP_POOLING1_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_MAX_POOLING_OPERATOR_SCHEMA_FIELDS[7] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index 833871de0bbd9..ad4cceb85cfd2 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -425,6 +425,21 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D
OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))),
};
}
+
+inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc)
+{
+ return {
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)),
+ OperatorField(&DML_AVERAGE_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.IncludePadding))),
+ };
+}
inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC& desc)
{
return {
@@ -438,6 +453,20 @@ inline std::vector GetFields(const DML_LP_POOLING_OPERATOR_DESC&
OperatorField(&DML_LP_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.P))),
};
}
+inline std::vector GetFields(const DML_LP_POOLING1_OPERATOR_DESC& desc)
+{
+ return {
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.DimensionCount))),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)),
+ OperatorField(&DML_LP_POOLING1_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.P))),
+ };
+}
inline std::vector GetFields(const DML_MAX_POOLING_OPERATOR_DESC& desc)
{
return {
@@ -1684,7 +1713,9 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType)
case DML_OPERATOR_ARGMIN: return DML_ARGMIN_OPERATOR_SCHEMA;
case DML_OPERATOR_ARGMAX: return DML_ARGMAX_OPERATOR_SCHEMA;
case DML_OPERATOR_AVERAGE_POOLING: return DML_AVERAGE_POOLING_OPERATOR_SCHEMA;
+ case DML_OPERATOR_AVERAGE_POOLING1: return DML_AVERAGE_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_LP_POOLING: return DML_LP_POOLING_OPERATOR_SCHEMA;
+ case DML_OPERATOR_LP_POOLING1: return DML_LP_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING: return DML_MAX_POOLING_OPERATOR_SCHEMA;
case DML_OPERATOR_MAX_POOLING1: return DML_MAX_POOLING1_OPERATOR_SCHEMA;
case DML_OPERATOR_ROI_POOLING: return DML_ROI_POOLING_OPERATOR_SCHEMA;
@@ -2002,10 +2033,18 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_AVERAGE_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast(opDesc.Desc)));
+ case DML_OPERATOR_AVERAGE_POOLING1:
+ return AbstractOperatorDesc(
+ &DML_AVERAGE_POOLING1_OPERATOR_SCHEMA,
+ GetFields(*static_cast(opDesc.Desc)));
case DML_OPERATOR_LP_POOLING:
return AbstractOperatorDesc(
&DML_LP_POOLING_OPERATOR_SCHEMA,
GetFields(*static_cast(opDesc.Desc)));
+ case DML_OPERATOR_LP_POOLING1:
+ return AbstractOperatorDesc(
+ &DML_LP_POOLING1_OPERATOR_SCHEMA,
+ GetFields(*static_cast(opDesc.Desc)));
case DML_OPERATOR_MAX_POOLING:
return AbstractOperatorDesc(
&DML_MAX_POOLING_OPERATOR_SCHEMA,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
index e8d5b2746aa13..10ff1d8be8a29 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPooling.cpp
@@ -34,7 +34,7 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
kernelOutputIndices.emplace_back(1);
}
DmlOperator::Initialize(kernelInfo, std::nullopt, kernelOutputIndices);
-
+
std::vector inputDescs = GetDmlInputDescs();
std::vector outputDescs = GetDmlOutputDescs();
ML_CHECK_VALID_ARGUMENT(inputDescs.size() >= 1, "MaxPool input count must be >=1.");
@@ -98,6 +98,21 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
SetOpDesc(desc);
break;
}
+ case DML_OPERATOR_AVERAGE_POOLING1:
+ {
+ if (hasDilations) {
+ DML_AVERAGE_POOLING1_OPERATOR_DESC desc = {};
+ desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false);
+ desc.Dilations = m_kernel.dilations;
+ SetOpDesc(desc);
+ }
+ else {
+ DML_AVERAGE_POOLING_OPERATOR_DESC desc = {};
+ desc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false);
+ SetOpDesc(desc);
+ }
+ break;
+ }
case DML_OPERATOR_LP_POOLING:
{
DML_LP_POOLING_OPERATOR_DESC desc = {};
@@ -106,6 +121,23 @@ class DmlOperatorPooling : public DmlOperator, public PoolingHelperBase
SetOpDesc(desc);
break;
}
+ case DML_OPERATOR_LP_POOLING1:
+ {
+ if (hasDilations) {
+ DML_LP_POOLING1_OPERATOR_DESC desc = {};
+ desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2);
+ ML_CHECK_VALID_ARGUMENT(desc.P > 0);
+ desc.Dilations = m_kernel.dilations;
+ SetOpDesc(desc);
+ }
+ else {
+ DML_LP_POOLING_OPERATOR_DESC desc = {};
+ desc.P = kernelInfo.GetOptionalAttribute(AttrName::P, 2);
+ ML_CHECK_VALID_ARGUMENT(desc.P > 0);
+ SetOpDesc(desc);
+ }
+ break;
+ }
case DML_OPERATOR_MAX_POOLING:
case DML_OPERATOR_MAX_POOLING1:
case DML_OPERATOR_MAX_POOLING2:
@@ -152,7 +184,7 @@ class DmlOperatorPoolingTemplate : public DmlOperatorPooling
void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool* isSupported)
{
*isSupported = false;
-
+
MLOperatorAttributes attributes(context);
int storageOrder = attributes.GetOptionalAttribute(AttrName::StorageOrder, 0);
@@ -164,11 +196,11 @@ void CALLBACK QueryMaxPool(IMLOperatorSupportQueryContextPrivate* context, bool*
*isSupported = true;
}
-DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate);
+DML_OP_DEFINE_CREATION_FUNCTION(AveragePool, DmlOperatorPoolingTemplate);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalAveragePool, DmlOperatorPoolingTemplate);
DML_OP_DEFINE_CREATION_FUNCTION(MaxPool, DmlOperatorPoolingTemplate);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalMaxPool, DmlOperatorPoolingTemplate);
-DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate);
+DML_OP_DEFINE_CREATION_FUNCTION(LpPool, DmlOperatorPoolingTemplate);
DML_OP_DEFINE_CREATION_FUNCTION(GlobalLpPool, DmlOperatorPoolingTemplate);
} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 28360f09bcba3..dbe9f5da4f569 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -667,6 +667,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO( 19, AveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalAveragePool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 8, MaxPool, typeNameListMaxPool, supportedTypeListMaxPool, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryMaxPool)},
@@ -677,6 +678,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, GlobalMaxPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 11, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
+ {REG_INFO( 18, LpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, GlobalLpPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 7, MaxRoiPool, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_VER( 10, RoiAlign, typeNameListTwo, supportedTypeListRoiAlign, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index e18ba31def48a..3eb35faeba82f 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -406,6 +406,12 @@ namespace OperatorHelper
static const int sc_sinceVer_BitwiseNot = 18;
static const int sc_sinceVer_Pad = 18;
static const int sc_sinceVer_Split = 18;
+ static const int sc_sinceVer_LpPool = 18;
+ }
+
+ namespace OnnxOperatorSet19
+ {
+ static const int sc_sinceVer_AveragePool = 19;
}
namespace MsftOperatorSet1
diff --git a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
index 10476ada2fa69..4b194ec18b31b 100644
--- a/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
+++ b/onnxruntime/test/providers/cpu/nn/pool_op_test.cc
@@ -777,11 +777,6 @@ TEST(PoolTest, GlobalMaxPool3D) {
}
TEST(PoolTest, AveragePool) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool");
test.AddAttribute("auto_pad", "");
@@ -863,11 +858,6 @@ TEST(PoolTest, AveragePool) {
}
TEST(PoolTest, AveragePool_IncludePadPixel) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool");
test.AddAttribute("auto_pad", "");
@@ -911,11 +901,6 @@ TEST(PoolTest, AveragePool_DefaultStrides) {
}
TEST(PoolTest, AveragePool_10_ceil1_2d) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool", 10);
test.AddAttribute("auto_pad", "");
@@ -939,11 +924,6 @@ TEST(PoolTest, AveragePool_10_ceil1_2d) {
}
TEST(PoolTest, AveragePool_19_dilation_2d) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("AveragePool", 19);
test.AddAttribute("auto_pad", "");
@@ -1070,11 +1050,6 @@ TEST(PoolTest, GlobalAveragePool_Large_256) {
}
TEST(PoolTest, LpPool) {
- // TODO: Unskip when fixed #41968513
- if (DefaultDmlExecutionProvider().get() != nullptr) {
- GTEST_SKIP() << "Skipping because of the following error: MLOperatorAuthorImpl.cpp(2100): The parameter is incorrect.";
- }
-
OpTester test("LpPool");
test.AddAttribute("auto_pad", "");
From 9ff5e3b7b0f5e9683422fa3609175fe4f82ccb77 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:34:35 -0700
Subject: [PATCH 008/677] Add QLinearConcat for DML EP (#16971) (#18268)
### Description
[Cherry Pick Reviewed]
```
[ OK ] QLinearConcatS8.ExpectFail_WrongZeroPointType_1 (372 ms)
[ RUN ] QLinearConcatS8.InputOne_Dynamic
[ OK ] QLinearConcatS8.InputOne_Dynamic (255 ms)
[ RUN ] QLinearConcatS8.InputOne_Const
[ OK ] QLinearConcatS8.InputOne_Const (255 ms)
[----------] 11 tests from QLinearConcatS8 (3385 ms total)
[----------] Global test environment tear-down
[==========] 21 tests from 3 test suites ran. (9355 ms total)
[ PASSED ] 21 tests.
```
[#16971](https://github.com/microsoft/onnxruntime/pull/16971)
### Motivation and Context
Co-authored-by: Xiang Zhang
---
.../Operators/DmlOperatorQLinearConcat.cpp | 236 ++++++++++++++++++
.../src/Operators/OperatorRegistration.cpp | 16 +-
.../src/Operators/OperatorUtility.cpp | 4 +-
.../src/Operators/OperatorUtility.h | 3 +-
.../OperatorAuthorHelper/OperatorHelper.cpp | 18 +-
.../dml/OperatorAuthorHelper/OperatorHelper.h | 25 +-
.../OperatorAuthorHelper/OperatorVersions.h | 1 +
7 files changed, 290 insertions(+), 13 deletions(-)
create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
new file mode 100644
index 0000000000000..67711fdc28b84
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp
@@ -0,0 +1,236 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+// QLinearConcat = Dequantize (one node per input tensor) + Join + Quantize, submitted as one DML operator graph.
+class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper
+{
+ // This order matches the ONNX schema; the (tensor, scale, zero_point) input tuples follow from index 2.
+ enum OnnxInputIndex
+ {
+ YScale,
+ YZeroPoint,
+ Count,
+ };
+
+public:
+ DmlOperatorQLinearConcat(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext),
+ QLinearConcatHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
+ {
+ DmlOperator::Initialize(kernelCreationContext);
+
+ auto outputShape = kernelCreationContext.GetTensorShapeDescription().GetOutputTensorShape(0);
+
+ // inputs: {y_scale, y_zero_point, tuple(x_tensor, x_scale, x_zero_point), ...}
+ uint32_t inputDefinitionCount = kernelCreationContext.GetInputCount();
+ ML_CHECK_VALID_ARGUMENT(inputDefinitionCount >= 5, "Require at least 5 inputs.");
+ ML_CHECK_VALID_ARGUMENT((inputDefinitionCount - 2) % 3 == 0, "Each input must be (tensor, scale, zero_point) tuple!");
+
+ uint32_t inputCount = (inputDefinitionCount - 2) / 3;
+
+ auto yScaleDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YScale).tensorDataType;
+ auto yZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(OnnxInputIndex::YZeroPoint).tensorDataType;
+
+ // broadcast y_scale and y_zero_point to output shape
+ m_inputTensorDescs[OnnxInputIndex::YScale] = TensorDesc(
+ yScaleDataType,
+ outputShape,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YScale),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+
+ m_inputTensorDescs[OnnxInputIndex::YZeroPoint] = TensorDesc(
+ yZeroPointDataType,
+ outputShape,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(OnnxInputIndex::YZeroPoint),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+
+ // Validate each input tuple's scale/zero-point data types and rebuild their tensor descs
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ // Inputs are (input tensor, scale, zero_point) tuples, starting from index 2
+ auto tupleStartIndex = 2 + inputIndex * 3;
+ auto xScaleDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 1).tensorDataType;
+ auto xZeroPointDataType = kernelCreationContext.GetInputEdgeDescription(tupleStartIndex + 2).tensorDataType;
+ ML_CHECK_VALID_ARGUMENT(xScaleDataType == yScaleDataType, "Wrong input type encountered for scale");
+ ML_CHECK_VALID_ARGUMENT(xZeroPointDataType == yZeroPointDataType, "Wrong input type encountered for zero point");
+
+ // broadcast x_scale and x_zero_point to shape of corresponding x
+ m_inputTensorDescs[tupleStartIndex + 1] = TensorDesc(
+ xScaleDataType,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 1),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+
+ m_inputTensorDescs[tupleStartIndex + 2] = TensorDesc(
+ xZeroPointDataType,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex + 2),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+ }
+
+ uint32_t dmlAxis = GetDmlAdjustedAxis(m_axis, kernelCreationContext, m_inputTensorDescs.front().GetDimensionCount(), 2); // 2 = index of the first x tensor
+
+ std::vector inputDescs = GetDmlInputDescs();
+ std::vector outputDescs = GetDmlOutputDescs();
+
+ // 1. output edges between Dequantize and Join node
+ // 2. input edge between Join and Quantize node
+ std::vector intermediateOutputTensorDescs(inputCount);
+ std::vector namedDequantizeOperatorDescs(inputCount);
+ std::vector dequantizeOperatorDescs(inputCount);
+ std::vector dmlOpDesc(inputCount);
+ std::vector opDescs;
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ auto tupleStartIndex = 2 + inputIndex * 3;
+ intermediateOutputTensorDescs[inputIndex] = TensorDesc(
+ MLOperatorTensorDataType::Float,
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(tupleStartIndex),
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+ namedDequantizeOperatorDescs[inputIndex] = intermediateOutputTensorDescs[inputIndex].GetDmlDesc();
+
+ dequantizeOperatorDescs[inputIndex].InputTensor = &inputDescs[tupleStartIndex];
+ dequantizeOperatorDescs[inputIndex].ScaleTensor = &inputDescs[tupleStartIndex + 1];
+ dequantizeOperatorDescs[inputIndex].ZeroPointTensor = &inputDescs[tupleStartIndex + 2];
+ dequantizeOperatorDescs[inputIndex].OutputTensor = &namedDequantizeOperatorDescs[inputIndex];
+
+ dmlOpDesc[inputIndex] = {DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR, &dequantizeOperatorDescs[inputIndex]};
+ opDescs.push_back(&dmlOpDesc[inputIndex]);
+ }
+
+ TensorDesc joinOutputTensorDesc = TensorDesc(
+ MLOperatorTensorDataType::Float,
+ outputShape,
+ outputShape,
+ TensorAxis::DoNotCoerce,
+ TensorAxis::W,
+ TensorAxis::RightAligned,
+ NchwDimensionCount, // minDimensionCount
+ 0 // guaranteedBaseOffsetAlignment
+ );
+ DML_TENSOR_DESC namedJoinOutputTensorDesc = joinOutputTensorDesc.GetDmlDesc();
+
+ DML_JOIN_OPERATOR_DESC joinDesc = {};
+ joinDesc.InputCount = gsl::narrow_cast(namedDequantizeOperatorDescs.size());
+ joinDesc.InputTensors = namedDequantizeOperatorDescs.data();
+ joinDesc.OutputTensor = &namedJoinOutputTensorDesc;
+ joinDesc.Axis = dmlAxis;
+
+ const DML_OPERATOR_DESC opJoinDesc = {DML_OPERATOR_JOIN, &joinDesc};
+ opDescs.push_back(&opJoinDesc);
+
+ DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_DESC quantizeOperatorDesc = {};
+ quantizeOperatorDesc.InputTensor = joinDesc.OutputTensor;
+ quantizeOperatorDesc.ScaleTensor = &inputDescs[OnnxInputIndex::YScale];
+ quantizeOperatorDesc.ZeroPointTensor = &inputDescs[OnnxInputIndex::YZeroPoint];
+ quantizeOperatorDesc.OutputTensor = &outputDescs[0];
+ const DML_OPERATOR_DESC opQuantizeDesc = {DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR, &quantizeOperatorDesc};
+ opDescs.push_back(&opQuantizeDesc);
+
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.nodeCount = static_cast(opDescs.size());
+ operatorGraphDesc.nodesAsOpDesc = opDescs.data();
+
+ uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2;
+ uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1;
+
+ std::vector inputEdges;
+ // Input edges to Dequantize nodes
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ auto tupleStartIndex = 2 + inputIndex * 3;
+ for (auto edge_index = 0; edge_index < 3; ++edge_index)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC inputEdge = {};
+ inputEdge.GraphInputIndex = tupleStartIndex + edge_index;
+ inputEdge.ToNodeIndex = inputIndex;
+ inputEdge.ToNodeInputIndex = edge_index;
+ inputEdges.push_back(inputEdge);
+ }
+ }
+
+ // Input edge from y_scale to quantize node
+ DML_INPUT_GRAPH_EDGE_DESC yScaleInputEdge = {};
+ yScaleInputEdge.GraphInputIndex = 0; // Y_scale
+ yScaleInputEdge.ToNodeIndex = quantizeNodeIndex;
+ yScaleInputEdge.ToNodeInputIndex = 1;
+ inputEdges.push_back(yScaleInputEdge);
+
+ // Input edge from y_zero_point to quantize node
+ DML_INPUT_GRAPH_EDGE_DESC yZeroPointInputEdge = {};
+ yZeroPointInputEdge.GraphInputIndex = 1; // Y_zero_point
+ yZeroPointInputEdge.ToNodeIndex = quantizeNodeIndex;
+ yZeroPointInputEdge.ToNodeInputIndex = 2;
+ inputEdges.push_back(yZeroPointInputEdge);
+
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+
+ // set intermediate edges
+ std::vector intermediateEdges;
+ for (uint32_t inputIndex = 0; inputIndex < inputCount; ++inputIndex)
+ {
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC dequantizeToJoinEdge = {};
+ dequantizeToJoinEdge.FromNodeIndex = inputIndex;
+ dequantizeToJoinEdge.FromNodeOutputIndex = 0;
+ dequantizeToJoinEdge.ToNodeIndex = joinNodeIndex; // Join is the second-to-last node
+ dequantizeToJoinEdge.ToNodeInputIndex = inputIndex;
+ intermediateEdges.push_back(dequantizeToJoinEdge);
+ }
+
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC joinToQuantizeEdge = {};
+ joinToQuantizeEdge.FromNodeIndex = joinNodeIndex;
+ joinToQuantizeEdge.FromNodeOutputIndex = 0;
+ joinToQuantizeEdge.ToNodeIndex = quantizeNodeIndex; // Quantize is the last node
+ joinToQuantizeEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(joinToQuantizeEdge);
+
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+
+ // set the output edges
+ std::vector outputEdges;
+ DML_OUTPUT_GRAPH_EDGE_DESC outputEdge = {};
+ outputEdge.FromNodeIndex = quantizeNodeIndex;
+ outputEdge.FromNodeOutputIndex = 0;
+ outputEdge.GraphOutputIndex = 0;
+ outputEdges.push_back(outputEdge);
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+ };
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearConcat, DmlOperatorQLinearConcat);
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index dbe9f5da4f569..fa2750a22425f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -496,6 +496,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(ScatterND);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearAdd);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearConv);
DML_OP_EXTERN_CREATION_FUNCTION(QLinearMatMul);
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearConcat);
DML_OP_EXTERN_CREATION_FUNCTION(DynamicQuantizeLinear);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
@@ -547,6 +548,7 @@ constexpr static std::array typeNameListEyeLike = { "T1", "T2" }
constexpr static std::array typeNameShape = { "T", "T1" };
constexpr static std::array typeNameSize = { "T", "T1" };
constexpr static std::array typeNameListGroupNorm = {"T", "M"};
+constexpr static std::array typeNameListQLinearConcat= {"TF", "T8", "TV"};
constexpr static std::array supportedTypeListAll = {SupportedTensorDataTypes::All};
constexpr static std::array supportedTypeListFloat32 = {SupportedTensorDataTypes::Float32};
@@ -618,7 +620,18 @@ constexpr static std::array supportedTypeListQLinea
constexpr static std::array supportedTypeListDynamicQuantizeLinear = {
SupportedTensorDataTypes::Float32,
- SupportedTensorDataTypes::UInt8,
+ SupportedTensorDataTypes::Ints8Bit
+};
+
+constexpr static std::array supportedTypeListDynamicQuantizeMatMul= {
+ SupportedTensorDataTypes::Float32,
+ SupportedTensorDataTypes::Ints8Bit,
+};
+
+constexpr static std::array supportedTypeListQLinearConcat= {
+ SupportedTensorDataTypes::Float32,
+ SupportedTensorDataTypes::Ints8Bit,
+ SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};
template
@@ -1012,6 +1025,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
{REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, QLinearConcat, typeNameListQLinearConcat, supportedTypeListQLinearConcat, DmlGraphSupport::Supported)},
{REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
{REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp
index d8290bbdaee3e..2965fa32ce131 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.cpp
@@ -419,9 +419,9 @@ namespace Dml
} // namespace FusionHelpers
- uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount)
+ uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex)
{
- const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(0);
+ const std::vector inputDimensions = kernelCreationContext.GetTensorShapeDescription().GetInputTensorShape(firstInputIndex);
uint32_t onnxDimCount = gsl::narrow_cast(inputDimensions.size());
onnxAxis = HandleNegativeAxis(onnxAxis, onnxDimCount);
return GetDmlAdjustedAxis(onnxAxis, onnxDimCount, dmlDimCount);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h
index f0fad6a05ffb0..8b2da6084242d 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorUtility.h
@@ -64,8 +64,7 @@ namespace Dml
} // namespace FusionHelpers
// Given an axis in ONNX axis numbering, return the axis adjusted for DML based on how the sizes have been coerced.
- // Note this function presumes the axis attribute is relative to the first input tensor (which is always the case).
- uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount);
+ uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, const MLOperatorKernelCreationContext& kernelCreationContext, uint32_t dmlDimCount, uint32_t firstInputIndex = 0);
uint32_t GetDmlAdjustedAxis(int32_t onnxAxis, uint32_t onnxDimCount, uint32_t dmlDimCount);
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 370f336ff5203..4d59964dcc664 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -1862,7 +1862,7 @@ namespace OperatorHelper
return { std::move(outputShape) };
}
- void ConcatHelper::Initialize(
+ void ConcatHelperBase::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span inputDimensions
)
@@ -1872,13 +1872,13 @@ namespace OperatorHelper
ML_CHECK_VALID_ARGUMENT(m_axis < static_cast(inputDimensions.size()));
}
- std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ std::vector ConcatHelperBase::GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const
{
- auto outputShape = shapeInfo.GetInputTensorShape(0);
+ auto outputShape = shapeInfo.GetInputTensorShape(firstInputIndex);
uint32_t inputCount = shapeInfo.GetInputCount();
- for (uint32_t i = 1; i < inputCount; ++i)
+ for (uint32_t i = firstInputIndex + step; i < inputCount; i += step)
{
auto inputShape = shapeInfo.GetInputTensorShape(i);
for (size_t j = 0; j < outputShape.size(); ++j)
@@ -1893,6 +1893,16 @@ namespace OperatorHelper
return { EdgeShapes(outputShape) };
}
+ std::vector ConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ return ConcatHelperBase::GetOutputShapes(shapeInfo, 0, 1); // firstInputIndex = 0, step = 1: every input is visited
+ }
+
+ std::vector QLinearConcatHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ return ConcatHelperBase::GetOutputShapes(shapeInfo, 2, 3); // firstInputIndex = 2, step = 3: only inputs 2, 5, 8, ... are visited
+ }
+
void CropHelper::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span inputDimensions
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index f7e545d9d99a9..55a01c59ee4b5 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -864,7 +864,7 @@ class RecurrentHelper
int m_hiddenSize = 0;
};
-class ConcatHelper
+class ConcatHelperBase
{
public:
void Initialize(
@@ -875,17 +875,33 @@ class ConcatHelper
// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template
- ConcatHelper(const Info_t& info, const Shape_t& shape)
+ ConcatHelperBase(const Info_t& info, const Shape_t& shape, uint32_t firstInputIndex)
{
- Initialize(info, shape.GetInputTensorShape(0));
+ Initialize(info, shape.GetInputTensorShape(firstInputIndex));
}
- std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+ std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo, uint32_t firstInputIndex, uint32_t step) const;
protected:
int m_axis;
};
+class ConcatHelper: public ConcatHelperBase // Concat helper that initializes the base from input 0's shape
+{
+public:
+ template
+ ConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 0) {} // 0 = firstInputIndex
+ std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+};
+
+class QLinearConcatHelper: public ConcatHelperBase // Concat helper that initializes the base from input 2's shape
+{
+public:
+ template
+ QLinearConcatHelper(const Info_t& info, const Shape_t& shape) : ConcatHelperBase(info, shape, 2) {} // 2 = firstInputIndex
+ std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+};
+
class CropHelper
{
public:
@@ -1512,6 +1528,7 @@ using ShapeInferenceHelper_Split13 = VersionedOpsetHelper;
using ShapeInferenceHelper_Split18 = VersionedOpsetHelper;
using ShapeInferenceHelper_Transpose = TransposeHelper;
using ShapeInferenceHelper_Concat = ConcatHelper;
+using ShapeInferenceHelper_QLinearConcat = QLinearConcatHelper;
using ShapeInferenceHelper_Slice7 = VersionedOpsetHelper;
using ShapeInferenceHelper_Slice10 = VersionedOpsetHelper;
using ShapeInferenceHelper_Slice11 = VersionedOpsetHelper; // Note 11 and 10 are identical - no functional change.
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 3eb35faeba82f..996ea1ddcb52c 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -443,6 +443,7 @@ namespace OperatorHelper
static const int sc_sinceVer_BiasAdd = 1;
static const int sc_sinceVer_QuickGelu = 1;
static const int sc_sinceVer_GroupNorm = 1;
+ static const int sc_sinceVer_QLinearConcat = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
} // namespace MsftOperatorSet1
From cb7f28a16ab78601363c2e694679d2dada149dd1 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Fri, 3 Nov 2023 09:43:49 -0700
Subject: [PATCH 009/677] Register Resize for INT8 and UINT8 (#18252)
### Description
### Motivation and Context
Co-authored-by: Adrian Tsai
---
.../DmlExecutionProvider/src/Operators/OperatorRegistration.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index fa2750a22425f..d7910a6c6849f 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -589,7 +589,7 @@ constexpr static std::array supportedTypeListLogica
constexpr static std::array supportedTypeListLogicalComparison9 = /* A&B,C */ { SupportedTensorDataTypes::Float16to32|SupportedTensorDataTypes::Ints8to64, SupportedTensorDataTypes::Bool };
constexpr static std::array supportedTypeListSigned = { SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int64 | SupportedTensorDataTypes::Int32 | SupportedTensorDataTypes::Int16 | SupportedTensorDataTypes::Int8 };
constexpr static std::array supportedTypeListRange = {SupportedTensorDataTypes::Int16|SupportedTensorDataTypes::Int32|SupportedTensorDataTypes::Int64|SupportedTensorDataTypes::Float32};
-constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */};
+constexpr static std::array supportedTypeListResize11 = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Int8 | SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Float16to32 /* ROI read by CPU */};
constexpr static std::array supportedTypeListResize13 = supportedTypeListResize11;
constexpr static std::array supportedTypeListInteger = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8, SupportedTensorDataTypes::Int32 };
constexpr static std::array supportedTypeListInteger8 = {SupportedTensorDataTypes::Int8|SupportedTensorDataTypes::UInt8 };
From dcfff10f57fbbdf81ea9ebcae008be712be42700 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Mon, 6 Nov 2023 09:09:11 -0800
Subject: [PATCH 010/677] Enable QLinearAveragePooling DML EP (#17384) (#18240)
[Cherry Pick Reviewed]
DML EP Implementation for
[QLinearAveragePool](https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAveragePool)
```
Note: Google Test filter = *QLinear*Pool*
[==========] Running 72 tests from 2 test suites.
[----------] Global test environment set-up.
[----------] 36 tests from QLinearGlobalAveragePool
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32
[ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32 (410 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1
[ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1 (641 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8 (156 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256 (134 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7 (160 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255 (145 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8 (148 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255 (129 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7 (134 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256 (131 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8 (159 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256 (168 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7 (139 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255 (170 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8 (155 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255 (156 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7 (133 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256 (149 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x1x32x32_S8 (131 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x32x32x1_S8 (127 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x8x8_S8 (153 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x256_S8 (129 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x7x7_S8 (133 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x255_S8 (135 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x255x8x8_S8 (129 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x8x8x255_S8 (152 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_1x256x7x7_S8 (140 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_1x7x7x256_S8 (133 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x8x8_S8 (135 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x256_S8 (147 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x7x7_S8 (156 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x255_S8 (155 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x255x8x8_S8 (138 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x8x8x255_S8 (155 ms)
[ RUN ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8
[ OK ] QLinearGlobalAveragePool.Nhwc_3x256x7x7_S8 (144 ms)
[ RUN ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8
[ OK ] QLinearGlobalAveragePool.Nchw_3x7x7x256_S8 (139 ms)
[----------] 36 tests from QLinearGlobalAveragePool (5968 ms total)
[----------] 36 tests from QLinearPoolTest
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel (480 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel (481 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel (512 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel (455 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel (463 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel (448 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel (458 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc (171 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc (169 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc (152 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc (660 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc (150 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc (145 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc (146 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage
[ OK ] QLinearPoolTest.AveragePool2D_BigImage (505 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc (161 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global
[ OK ] QLinearPoolTest.AveragePool2D_Global (481 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc
[ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc (152 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_S8 (461 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_S8 (448 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_S8 (471 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_S8 (473 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_S8
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_S8 (1507 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_S8 (477 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_S8 (493 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool1D_ExcludePadPixel_nhwc_S8 (158 ms)
[ RUN ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool1D_IncludePadPixel_nhwc_S8 (146 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_ExcludePadPixel_nhwc_S8 (146 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_IncludePadPixel_nhwc_S8 (158 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_MultiChannel_nhwc_S8 (157 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool3D_ExcludePadPixel_nhwc_S8 (145 ms)
[ RUN ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool3D_IncludePadPixel_nhwc_S8 (147 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage_S8
[ OK ] QLinearPoolTest.AveragePool2D_BigImage_S8 (537 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_BigImage_nhwc_S8 (173 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global_S8
[ OK ] QLinearPoolTest.AveragePool2D_Global_S8 (457 ms)
[ RUN ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8
[ OK ] QLinearPoolTest.AveragePool2D_Global_nhwc_S8 (150 ms)
[----------] 36 tests from QLinearPoolTest (12914 ms total)
[----------] Global test environment tear-down
[==========] 72 tests from 2 test suites ran. (18885 ms total)
[ PASSED ] 72 tests.
memleakdbg:
----- No memory leaks detected -----
```
### Description
### Motivation and Context
---
.../src/External/DirectMLHelpers/ApiTraits.h | 20 ++-
.../External/DirectMLHelpers/DirectMLSchema.h | 25 +++
.../DirectMLHelpers/GeneratedSchemaHelpers.h | 26 +++
.../DmlOperatorQLinearAveragePooling.cpp | 150 ++++++++++++++++++
.../src/Operators/OperatorRegistration.cpp | 8 +
.../DmlExecutionProvider/src/TensorDesc.cpp | 36 +++++
.../dml/DmlExecutionProvider/src/TensorDesc.h | 3 +
.../dml/OperatorAuthorHelper/Attributes.h | 2 +-
.../OperatorAuthorHelper/OperatorHelper.cpp | 42 ++++-
.../dml/OperatorAuthorHelper/OperatorHelper.h | 28 +++-
.../OperatorAuthorHelper/OperatorVersions.h | 2 +
.../qlinear_global_average_pool_test.cc | 3 +
12 files changed, 339 insertions(+), 6 deletions(-)
create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
index 94f2220fcc168..a5415ba85f3d3 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h
@@ -24,7 +24,7 @@ struct EnumTraits
template <>
struct EnumTraits
{
- static constexpr auto ValueCount = 160;
+ static constexpr auto ValueCount = 161;
static constexpr size_t ActivationFunctionCount = 24;
};
@@ -495,6 +495,12 @@ struct OperatorDescTraits
static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING;
};
+template <>
+struct OperatorDescTraits
+{
+ static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING;
+};
+
template <>
struct OperatorDescTraits
{
@@ -1496,6 +1502,12 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING>
using DescType = DML_ROI_POOLING_OPERATOR_DESC;
};
+template <>
+struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING>
+{
+ using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC;
+};
+
template <>
struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE>
{
@@ -2522,6 +2534,12 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args
case DML_OPERATOR_ACTIVATION_GELU:
return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...);
+#pragma warning(push)
+#pragma warning(disable: 4063)
+ case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+ return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...);
+#pragma warning(pop)
+
default:
ORT_THROW_HR(E_INVALIDARG);
return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...);
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
index 9eae1c1fe8158..2a82c12872a72 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h
@@ -869,6 +869,31 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA {
DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS,
};
+
+constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] {
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false },
+ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false },
+};
+
+constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA {
+ "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING",
+ static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING),
+ DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE,
+ 13,
+ DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS,
+};
+
constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] {
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false },
DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false },
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
index ad4cceb85cfd2..99218c135f058 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h
@@ -502,6 +502,24 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC&
OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))),
};
}
+inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc)
+{
+ return {
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)),
+ OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))),
+ };
+}
inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc)
{
return {
@@ -2509,6 +2527,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc)
return AbstractOperatorDesc(
&DML_ACTIVATION_GELU_OPERATOR_SCHEMA,
GetFields(*static_cast(opDesc.Desc)));
+#pragma warning(push)
+#pragma warning(disable: 4063)
+ case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING:
+ return AbstractOperatorDesc(
+ &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA,
+ GetFields(*static_cast(opDesc.Desc)));
+#pragma warning(pop)
+
default:
ORT_THROW_HR(E_INVALIDARG);
return AbstractOperatorDesc(
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
new file mode 100644
index 0000000000000..0fccedfe311c1
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearAveragePooling.cpp
@@ -0,0 +1,150 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+
+class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelperBase
+{
+ // For QLinearAveragePool, ORT and DML use the same input indexing order
+ enum OrtInputTensors : uint32_t
+ {
+ ortInput,
+ ortInputScale,
+ ortInputZeroPoint,
+ ortOutputScale,
+ ortOutputZeroPoint,
+ ortInputCount
+ };
+
+public:
+ using Self = DmlOperatorQLinearAveragePooling;
+
+ DmlOperatorQLinearAveragePooling(
+ const MLOperatorKernelCreationContext& kernelInfo,
+ bool useGlobalPooling
+ )
+ : DmlOperator(kernelInfo),
+ PoolingHelperBase(kernelInfo, kernelInfo.GetTensorShapeDescription(), useGlobalPooling)
+ {
+ DmlOperator::Initialize(kernelInfo);
+
+ bool isNhwc = m_kernel.channelsLast;
+ std::vector inputShape = kernelInfo.GetTensorShapeDescription().GetInputTensorShape(OrtInputTensors::ortInput);
+ std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
+
+ uint32_t dmlDimSize = m_inputTensorDescs[OrtInputTensors::ortInput].GetDimensionCount();
+ ML_CHECK_VALID_ARGUMENT(dmlDimSize >= 2);
+
+ // DML requires that DimensionCount be equal to Input.dmlDimSize - 2 for Pooling
+ uint32_t expectedSpatialDimCount = m_inputTensorDescs[0].GetDimensionCount() - 2;
+ if (m_kernel.spatialDimensionCount < expectedSpatialDimCount)
+ {
+ size_t shift = expectedSpatialDimCount - m_kernel.spatialDimensionCount;
+
+ for (int i = gsl::narrow_cast(m_kernel.spatialDimensionCount) - 1; i >= 0; i--)
+ {
+ m_kernel.windowSize[i + shift] = m_kernel.windowSize[i];
+ m_kernel.windowSize[i] = 1;
+
+ m_kernel.strides[i + shift] = m_kernel.strides[i];
+ m_kernel.strides[i] = 1;
+
+ m_kernel.startPadding[i + shift] = m_kernel.startPadding[i];
+ m_kernel.startPadding[i] = 0;
+
+ m_kernel.endPadding[i + shift] = m_kernel.endPadding[i];
+ m_kernel.endPadding[i] = 0;
+
+ m_kernel.dilations[i + shift] = m_kernel.dilations[i];
+ m_kernel.dilations[i] = 1;
+ }
+
+ m_kernel.spatialDimensionCount = expectedSpatialDimCount;
+ }
+
+ // Initialize dimensionMapping for NCHW or NHWC layout
+ std::vector dimensionMapping = {0u, dmlDimSize - 1u};
+ dimensionMapping.resize(dmlDimSize);
+ if (isNhwc)
+ {
+ // Form a remapping for dimensions so C is moved before the spatial dimensions.
+ // e.g. NWC -> {0,2,1} -> NCW
+ // NHWC -> {0,3,1,2} -> NCHW
+ // NDHWC -> {0,4,1,2,3} -> NCDHW
+ std::iota(dimensionMapping.begin() + 2, dimensionMapping.end(), 1u);
+ }
+ else
+ {
+ // Use NCHW {0,1,2,3} format with increasing order of indices
+ std::iota(dimensionMapping.begin() + 1, dimensionMapping.end(), 1u);
+ }
+ m_inputTensorDescs[OrtInputTensors::ortInput].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ // Reshape the Input Scale to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ m_inputTensorDescs[OrtInputTensors::ortInputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ // Reshape the Input ZeroPoint to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ if (kernelInfo.IsInputValid(OrtInputTensors::ortInputZeroPoint))
+ {
+ m_inputTensorDescs[OrtInputTensors::ortInputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+ }
+
+ // Reshape the Output Scale to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ m_inputTensorDescs[OrtInputTensors::ortOutputScale].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ // Reshape the Output ZeroPoint to be the same dimension as the input tensor.
+ // The 1D tensor needs to be moved to the H channel.
+ if (kernelInfo.IsInputValid(OrtInputTensors::ortOutputZeroPoint))
+ {
+ m_inputTensorDescs[OrtInputTensors::ortOutputZeroPoint].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+ }
+
+ // Initialize the output description while overriding the shape
+ m_outputTensorDescs[0].PermuteDimensions(dimensionMapping, TensorAxis::LeftAligned);
+
+ assert(m_kernel.spatialDimensionCount <= ARRAYSIZE(m_kernel.windowSize));
+
+ std::vector inputDescs = GetDmlInputDescs();
+ std::vector outputDescs = GetDmlOutputDescs();
+
+ DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC qLinearAvgPooldesc = {};
+
+ qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
+ qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
+ qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
+ qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
+ qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
+ qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
+ qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
+ qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
+ qLinearAvgPooldesc.Strides = m_kernel.strides;
+ qLinearAvgPooldesc.StartPadding = m_kernel.startPadding;
+ qLinearAvgPooldesc.EndPadding = m_kernel.endPadding;
+ qLinearAvgPooldesc.Dilations = m_kernel.dilations;
+ qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute(AttrName::CountIncludePad, false);
+
+ DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
+ SetDmlOperatorDesc(opDesc, kernelInfo);
+ }
+};
+
+template
+class DmlOperatorQuantizedPoolingTemplate : public DmlOperatorQLinearAveragePooling
+{
+public:
+ DmlOperatorQuantizedPoolingTemplate(const MLOperatorKernelCreationContext& kernelInfo)
+ : DmlOperatorQLinearAveragePooling(kernelInfo, UseGlobalPooling)
+ {
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearAveragePool, DmlOperatorQuantizedPoolingTemplate);
+DML_OP_DEFINE_CREATION_FUNCTION(QLinearGlobalAveragePool, DmlOperatorQuantizedPoolingTemplate);
+
+} // namespace Dml
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index d7910a6c6849f..0234bb6b7ec1e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -320,6 +320,8 @@ DML_OP_EXTERN_CREATION_FUNCTION(GlobalMaxPool);
DML_OP_EXTERN_CREATION_FUNCTION(LpPool);
DML_OP_EXTERN_CREATION_FUNCTION(GlobalLpPool);
DML_OP_EXTERN_CREATION_FUNCTION(MaxRoiPool);
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearAveragePool);
+DML_OP_EXTERN_CREATION_FUNCTION(QLinearGlobalAveragePool);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign10);
DML_OP_EXTERN_CREATION_FUNCTION(RoiAlign16);
DML_OP_EXTERN_CREATION_FUNCTION(InstanceNormalization);
@@ -634,6 +636,10 @@ constexpr static std::array supportedTypeListQLinea
SupportedTensorDataTypes::Ints8Bit|SupportedTensorDataTypes::Float32,
};
+constexpr static std::array supportedTypeListQLinearAveragePool = {
+ SupportedTensorDataTypes::Ints8Bit
+};
+
template
constexpr auto requiredConstantCpuInputs(Args... args)
{
@@ -1040,6 +1046,8 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 9, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))},
{REG_INFO( 11, MaxUnpool, typeNameListTwo, supportedTypeListMaxUnpool, DmlGraphSupport::Supported, requiredConstantCpuInputs(2))}, // 11 is identical to 9.
+ {REG_INFO_MS( 1, QLinearAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
+ {REG_INFO_MS( 1, QLinearGlobalAveragePool, typeNameListDefault, supportedTypeListQLinearAveragePool, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QLinearAdd, typeNameListDefault, supportedTypeListInteger8, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearConv, typeNameListFour, supportedTypeListQLinearConv, DmlGraphSupport::Supported)},
{REG_INFO( 10, QLinearMatMul, typeNameListThree, supportedTypeListQLinearMatMul, DmlGraphSupport::Supported)},
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp
index 067a320dd8000..a2183aab52eed 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.cpp
@@ -315,3 +315,39 @@ void TensorDesc::SetDimensionCount(uint32_t newDimensionCount, TensorAxis alignm
}
m_bufferTensorDesc.DimensionCount = newDimensionCount;
}
+
+// Uses dimensionMapping to reorder m_sizes and m_strides to match specific Tensor layout
+void TensorDesc::PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment)
+{
+ EnsureStridesExist();
+ SetDimensionCount(static_cast(dimensionMapping.size()), alignment);
+
+ // Shuffle m_sizes and m_strides according to the indexes pointed by dimensionMapping
+ std::vector tempSizes{m_sizes, m_sizes + MaximumDimensionCount};
+ std::vector tempStrides{m_strides, m_strides + MaximumDimensionCount};
+
+ for (size_t i = 0; i < dimensionMapping.size(); i++)
+ {
+ m_sizes[i] = tempSizes[dimensionMapping[i]];
+ m_strides[i] = tempStrides[dimensionMapping[i]];
+ }
+
+ m_bufferTensorDesc.Sizes = m_sizes;
+ m_bufferTensorDesc.Strides = m_strides;
+}
+
+void TensorDesc::EnsureStridesExist()
+{
+ if (m_bufferTensorDesc.Strides != nullptr)
+ {
+ // Strides are populated
+ return;
+ }
+
+ uint32_t stride = 1;
+ for (uint32_t i = m_bufferTensorDesc.DimensionCount; i-- > 0;)
+ {
+ m_strides[i] = stride;
+ stride *= m_sizes[i];
+ }
+}
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h
index ff70dec5b8871..909e2084d0163 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/TensorDesc.h
@@ -44,6 +44,7 @@ namespace Dml
gsl::span GetSizes() const { return { m_sizes, m_sizes + m_bufferTensorDesc.DimensionCount }; }
gsl::span GetStrides() const;
void SetStrides(gsl::span strides);
+ void PermuteDimensions(gsl::span dimensionMapping, const TensorAxis alignment);
inline uint64_t GetBufferSizeInBytes() const
{
@@ -90,6 +91,8 @@ namespace Dml
uint32_t m_sizes[MaximumDimensionCount] = {};
uint32_t m_strides[MaximumDimensionCount] = {};
DML_BUFFER_TENSOR_DESC m_bufferTensorDesc = {};
+
+ void EnsureStridesExist();
};
class TensorDescBuilder
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
index e9591cfce6870..85333aa77b686 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
@@ -23,8 +23,8 @@ namespace AttrName
static constexpr const char* BlockSize = "blocksize";
static constexpr const char* Border = "border";
static constexpr const char* Broadcast = "broadcast";
- static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* CeilMode = "ceil_mode";
+ static constexpr const char* ChannelsLast = "channels_last";
static constexpr const char* Clip = "clip";
static constexpr const char* CoordinateTransformationMode = "coordinate_transformation_mode";
static constexpr const char* CountIncludePad = "count_include_pad";
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 4d59964dcc664..1fcd3b04300f4 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -365,13 +365,20 @@ namespace OperatorHelper
}
// Creates a kernel that spans the entire spatial dimensions of the input.
- KernelArgs InitializeGlobalKernel(gsl::span inputDimensions)
+ KernelArgs InitializeGlobalKernel(
+ const MLOperatorAttributes& kernelInfo,
+ gsl::span inputDimensions)
{
ML_CHECK_VALID_ARGUMENT(inputDimensions.size() > NonspatialDimensionCount); // Must be at least 1D convolution (in 3D tensor)
uint32_t spatialDimensionCount = gsl::narrow_cast(inputDimensions.size()) - NonspatialDimensionCount;
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount); // Support up to 3D convolution (in 5D tensor).
KernelArgs args(spatialDimensionCount);
+ args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0);
+ args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0);
+ // For global pooling, the kernel size equals the spatial dimensions of the input tensor.
+ // The NHWC layout needs an offset of one dimension to account for the channel being placed at the end.
+ int dimOffset = args.channelsLast ? 1 : 0;
for (size_t dim = 0; dim < spatialDimensionCount; ++dim)
{
@@ -379,7 +386,7 @@ namespace OperatorHelper
args.dilations[dim] = 1;
args.startPadding[dim] = 0;
args.endPadding[dim] = 0;
- args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim]);
+ args.windowSize[dim] = gsl::narrow_cast(inputDimensions[inputDimensions.size() - spatialDimensionCount + dim - dimOffset]);
}
return args;
@@ -495,6 +502,7 @@ namespace OperatorHelper
}
args.useCeilingOutputShape = kernelInfo.GetOptionalAttribute(AttrName::CeilMode, 0);
+ args.channelsLast = kernelInfo.GetOptionalAttribute(AttrName::ChannelsLast, 0);
return args;
}
@@ -2012,7 +2020,37 @@ namespace OperatorHelper
}
return outputShapes;
}
+
+ std::vector QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ auto inputShape = shapeInfo.GetInputTensorShape(0);
+ std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
+
+ const uint32_t outputCount = shapeInfo.GetOutputCount();
+
+ std::vector outputShapes;
+ for (uint32_t i = 0; i < outputCount; ++i)
+ {
+ outputShapes.push_back(outputDimensions);
+ }
+ return outputShapes;
+ }
+
+ std::vector QLinearGlobalAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ auto inputShape = shapeInfo.GetInputTensorShape(0);
+ std::vector outputDimensions = InitializeKernelOutputDimensions(inputShape, m_kernel, m_kernel.channelsLast);
+ const uint32_t outputCount = shapeInfo.GetOutputCount();
+
+ std::vector outputShapes;
+ for (uint32_t i = 0; i < outputCount; ++i)
+ {
+ outputShapes.push_back(outputDimensions);
+ }
+ return outputShapes;
+ }
+
std::vector RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 55a01c59ee4b5..d8d09efd8d6e8 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -160,6 +160,7 @@ struct KernelArgs
bool autoPad = false;
bool autoPadSameUpper = false;
bool useCeilingOutputShape = false;
+ bool channelsLast = false;
uint32_t spatialDimensionCount = 0;
KernelArgs(uint32_t spatialDimensionCount) : spatialDimensionCount(spatialDimensionCount)
@@ -188,6 +189,7 @@ struct KernelArgs
KernelArgs(KernelArgs const& kernelArgs, uint32_t minimumDimensionCount)
: autoPad(kernelArgs.autoPad),
autoPadSameUpper(kernelArgs.autoPadSameUpper),
+ channelsLast(kernelArgs.channelsLast),
spatialDimensionCount(std::max(kernelArgs.spatialDimensionCount, minimumDimensionCount))
{
ML_CHECK_VALID_ARGUMENT(spatialDimensionCount <= NcdhwSpatialDimensionCount);
@@ -211,7 +213,9 @@ std::vector InitializeKernelOutputDimsTranspose(
gsl::span inputDimensions,
const KernelArgs& args);
-KernelArgs InitializeGlobalKernel(gsl::span inputDimensions);
+KernelArgs InitializeGlobalKernel(
+ const MLOperatorAttributes& kernelInfo,
+ gsl::span inputDimensions);
KernelArgs InitializeKernel(
const MLOperatorAttributes& kernelInfo,
@@ -1059,7 +1063,7 @@ class PoolingHelperBase
bool useGlobalPooling
)
: m_kernel(useGlobalPooling
- ? InitializeGlobalKernel(shape.GetInputTensorShape(0))
+ ? InitializeGlobalKernel(info, shape.GetInputTensorShape(0))
: InitializeKernel(info, static_cast(shape.GetInputTensorShape(0).size()), gsl::span()))
{
if (!useGlobalPooling)
@@ -1161,6 +1165,24 @@ class RoiAlignHelper : public RoiPoolingHelperBase
std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
};
+class QLinearAveragePoolingHelper : public PoolingHelperBase
+{
+public:
+ template
+ QLinearAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, false) {}
+ std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+};
+
+class QLinearGlobalAveragePoolingHelper : public PoolingHelperBase
+{
+public:
+ template
+ QLinearGlobalAveragePoolingHelper(const Info_t& info, const Shape_t& shape) : PoolingHelperBase(info, shape, true) {}
+ std::vector GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+};
+
class SqueezeHelper
{
public:
@@ -1490,6 +1512,8 @@ using ShapeInferenceHelper_MaxUnpool = UnpoolingHelper;
using ShapeInferenceHelper_LpPool = PoolingHelper;
using ShapeInferenceHelper_GlobalLpPool = GlobalPoolingHelper;
using ShapeInferenceHelper_MaxRoiPool = RoiPoolingHelper;
+using ShapeInferenceHelper_QLinearAveragePool = QLinearAveragePoolingHelper;
+using ShapeInferenceHelper_QLinearGlobalAveragePool = QLinearGlobalAveragePoolingHelper;
using ShapeInferenceHelper_RoiAlign10 = VersionedOpsetHelper;
using ShapeInferenceHelper_RoiAlign16 = VersionedOpsetHelper;
using ShapeInferenceHelper_InstanceNormalization = GetOutputShapeAsInputShapeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 996ea1ddcb52c..e9d88adf3e221 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -445,6 +445,8 @@ namespace OperatorHelper
static const int sc_sinceVer_GroupNorm = 1;
static const int sc_sinceVer_QLinearConcat = 1;
static const int sc_sinceVer_RotaryEmbedding = 1;
+ static const int sc_sinceVer_QLinearAveragePool = 1;
+ static const int sc_sinceVer_QLinearGlobalAveragePool = 1;
} // namespace MsftOperatorSet1
} // namespace OperatorHelper
diff --git a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc
index 8fb245819fd26..71b6f27b5391f 100644
--- a/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc
+++ b/onnxruntime/test/contrib_ops/qlinear_global_average_pool_test.cc
@@ -66,6 +66,9 @@ void RunQLinearGlobalAveragePool(
test.AddInput("y_scale", {}, {y_scale});
test.AddInput("y_zero_point", {}, {y_zero_point});
test.AddOutput("Y", y_dims, y_data);
+ if (channels_last) {
+ test.AddAttribute("channels_last", (int64_t)1LL);
+ }
auto q8checker = [&](const std::vector& fetches, const std::string& provider_type) {
const OrtValue& ort_value = fetches[0];
From d5f3aae3fd30a79d9bbeed26a70907618ac84835 Mon Sep 17 00:00:00 2001
From: raoanag <127366241+raoanag@users.noreply.github.com>
Date: Fri, 17 Nov 2023 16:43:09 -0800
Subject: [PATCH 011/677] Utilize DML constant input graph node (#18267)
### Description
This PR also includes,
8b0a55e7cc DML constant pow operator
7520974970 Enable custom heaps based on query-
### Motivation and Context
---------
Co-authored-by: Jeff Bloomfield
---
.../src/DmlGraphFusionHelper.cpp | 27 ++++-
.../src/ExecutionProvider.cpp | 31 ++++++
.../src/ExecutionProvider.h | 2 +
.../src/GraphDescBuilder.cpp | 104 ++++++++++++++----
.../src/GraphDescBuilder.h | 5 +-
.../src/IExecutionProvider.h | 1 +
.../src/MLOperatorAuthorImpl.cpp | 29 ++++-
.../src/MLOperatorAuthorImpl.h | 7 ++
.../src/Operators/DmlOperatorElementWise.cpp | 38 +++++--
.../MLOperatorAuthorHelper.h | 13 +++
.../MLOperatorAuthorPrivate.h | 10 ++
11 files changed, 229 insertions(+), 38 deletions(-)
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
index 4f7ec188140b5..18cdc5d1bf86e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp
@@ -226,8 +226,7 @@ namespace DmlGraphFusionHelper
{
ComPtr initializeInputBuffer;
- // D3D_FEATURE_LEVEL_1_0_CORE doesn't support Custom heaps
- if (providerImpl->IsMcdmDevice())
+ if (!providerImpl->CustomHeapsSupported())
{
initializeInputBuffer = CreateResource(providerImpl, tensorPtr, tensorByteSize);
}
@@ -294,6 +293,7 @@ namespace DmlGraphFusionHelper
const uint32_t inputCount,
const uint32_t outputCount,
_Inout_ std::vector& dmlOperatorGraphNodes,
+ _Inout_ std::vector& dmlConstantGraphNodes,
_Inout_ std::vector& dmlGraphNodes,
_Inout_ std::vector& dmlInputEdges,
_Inout_ std::vector& dmlOutputEdges,
@@ -302,8 +302,24 @@ namespace DmlGraphFusionHelper
for (size_t i = 0; i < graphDesc.nodes.size(); ++i)
{
auto& nodeInfo = graphDesc.nodes[i];
- dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{nodeInfo.op.Get(), nodeInfo.name.data()};
- dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
+
+ if (std::holds_alternative>(nodeInfo.nodeDef))
+ {
+ dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()};
+ dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
+ }
+ else
+ {
+ auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef);
+ dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{
+ nodeDefinitionData.data(),
+ nodeDefinitionData.size(),
+ nodeInfo.name.data()
+ };
+
+ // TODO: Change as new header is ingested
+ dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]};
+ }
}
for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i)
@@ -392,6 +408,8 @@ namespace DmlGraphFusionHelper
// convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator
DML_GRAPH_DESC dmlGraphDesc = {};
std::vector dmlOperatorGraphNodes(graphDesc.nodes.size());
+ std::vector dmlConstantGraphNodes(graphDesc.nodes.size());
+
std::vector dmlGraphNodes(graphDesc.nodes.size());
std::vector dmlInputEdges(graphDesc.inputEdges.size());
std::vector dmlOutputEdges(graphDesc.outputEdges.size());
@@ -402,6 +420,7 @@ namespace DmlGraphFusionHelper
fusedNodeInputCount,
fusedNodeOutputCount,
dmlOperatorGraphNodes,
+ dmlConstantGraphNodes,
dmlGraphNodes,
dmlInputEdges,
dmlOutputEdges,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
index 8644b8d56a426..49a64c4810252 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp
@@ -182,6 +182,32 @@ namespace Dml
}
m_isMcdmDevice = (featureLevels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE);
+ m_areCustomHeapsSupported = !m_isMcdmDevice;
+
+ if (m_isMcdmDevice)
+ {
+
+ // TODO: Ingest updated header file
+ typedef struct D3D12_FEATURE_DATA_D3D12_OPTIONS19
+ {
+ BOOL MismatchingOutputDimensionsSupported;
+ UINT SupportedSampleCountsWithNoOutputs;
+ BOOL PointSamplingAddressesNeverRoundUp;
+ BOOL RasterizerDesc2Supported;
+ BOOL NarrowQuadrilateralLinesSupported;
+ BOOL AnisoFilterWithPointMipSupported;
+ UINT MaxSamplerDescriptorHeapSize;
+ UINT MaxSamplerDescriptorHeapSizeWithStaticSamplers;
+ UINT MaxViewDescriptorHeapSize;
+ _Out_ BOOL ComputeOnlyCustomHeapSupported;
+ } D3D12_FEATURE_DATA_D3D12_OPTIONS19;
+
+ D3D12_FEATURE_DATA_D3D12_OPTIONS19 options19 = {};
+
+ // The call may fail, in which case the default value is false
+ d3d12Device->CheckFeatureSupport(static_cast(48) /*D3D12_FEATURE_D3D12_OPTIONS19*/, &options19, sizeof(options19));
+ m_areCustomHeapsSupported = options19.ComputeOnlyCustomHeapSupported;
+ }
m_context = std::make_shared(m_d3d12Device.Get(), m_dmlDevice.Get(), queue);
@@ -1089,6 +1115,11 @@ namespace Dml
return m_isMcdmDevice;
}
+ bool __stdcall ExecutionProviderImpl::CustomHeapsSupported() const noexcept
+ {
+ return m_areCustomHeapsSupported;
+ }
+
bool __stdcall ExecutionProviderImpl::MetacommandsEnabled() const noexcept
{
return m_areMetacommandsEnabled;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
index 3aaa11cdee479..ab932fb8a4367 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h
@@ -150,6 +150,7 @@ namespace Dml
}
STDMETHOD_(bool, IsMcdmDevice)() const noexcept final;
+ STDMETHOD_(bool, CustomHeapsSupported)() const noexcept final;
STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final;
bool DynamicGraphFusionEnabled() const noexcept;
@@ -186,6 +187,7 @@ namespace Dml
ComPtr m_d3d12Device;
ComPtr m_dmlDevice;
bool m_isMcdmDevice = false;
+ bool m_areCustomHeapsSupported = false;
bool m_areMetacommandsEnabled = true;
bool m_dynamicGraphFusionEnabled = false;
bool m_native16BitShaderOpsSupported = false;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 3fc8f415e5a58..ba022533a1e94 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -149,7 +149,7 @@ namespace Dml::GraphDescBuilder
const std::unordered_map>& isInitializerTransferable,
const std::unordered_map& graphNodePropertyMap,
IDMLDevice* device,
- const void* executionHandle,
+ const ExecutionProviderImpl* executionHandle,
const onnxruntime::Path& modelPath,
gsl::span subgraphNodes,
gsl::span