diff --git a/cmake/onnxruntime_providers_coreml.cmake b/cmake/onnxruntime_providers_coreml.cmake
index b8ebc4ca53239..0aa25a221bf27 100644
--- a/cmake/onnxruntime_providers_coreml.cmake
+++ b/cmake/onnxruntime_providers_coreml.cmake
@@ -126,10 +126,12 @@ endif()
if (APPLE)
file(GLOB
onnxruntime_providers_coreml_objcc_srcs CONFIGURE_DEPENDS
- "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h"
- "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm"
"${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.h"
"${ONNXRUNTIME_ROOT}/core/providers/coreml/model/host_utils.mm"
+ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.h"
+ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/model.mm"
+ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/objc_str_utils.h"
+ "${ONNXRUNTIME_ROOT}/core/providers/coreml/model/objc_str_utils.mm"
)
else()
# add the Model implementation that uses the protobuf types but excludes any actual CoreML dependencies
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 8fa67ee172733..635f65696eae2 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -774,7 +774,9 @@ Do not modify directly.*
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
-|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|ScatterND|*in* data:**T**
*in* indices:**tensor(int64)**
*in* updates:**T**
*out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Selu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|SequenceAt|*in* input_sequence:**S**
*in* position:**I**
*out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index e8dc702d6b3b7..db9bb73e394c7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -264,7 +264,7 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
let local_offset = local_idx * uniforms.elements_per_thread;
let offset = workgroup_id.x * uniforms.d_comp + local_offset;
- var thread_max_vector = ${inputHelper.type.value}(-3.402823e+38f);
+ var thread_max_vector = ${f32Type}(-3.402823e+38f);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
thread_max_vector = max(${f32Type}(x[offset + i]), thread_max_vector);
}
@@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
})()};
workgroupBarrier();
- var max_value: f32 = -3.402823e+38f;
+ var max_value = -3.402823e+38f;
for (var i = 0u; i < ${WG}; i++) {
max_value = max(thread_max[i], max_value);
}
- var sum_vector = ${inputHelper.type.value}(${0});
+ var sum_vector = ${f32Type}(${0});
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
}
@@ -333,9 +333,9 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
const createAttentionProbsProgramInfo =
(_context: ComputeContext, q: TensorView, key: TensorView, relativePositionBias: TensorView|undefined,
- parameters: AttentionParameters, attributes: AttentionAttrs) => {
- const probsShape =
- [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, parameters.totalSequenceLength];
+ parameters: AttentionParameters, attributes: AttentionAttrs, pastSequenceLength: number) => {
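+      // totalSequenceLength is derived from the pastSequenceLength supplied by the caller
+      // (0 when present key/value outputs are not produced) rather than from parameters.totalSequenceLength.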
+ const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
+ const probsShape = [parameters.batchSize, parameters.numHeads, parameters.sequenceLength, totalSequenceLength];
// TODO: handle mask
@@ -344,14 +344,13 @@ const createAttentionProbsProgramInfo =
const vectorizedHeadSize = parameters.headSize / components;
const TILE_SIZE = 12;
const dispatch = {
- x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
+ x: Math.ceil(totalSequenceLength / TILE_SIZE),
y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
z: parameters.batchSize * parameters.numHeads
};
const programUniforms: ProgramUniform[] = [
{type: DataType.uint32, data: parameters.sequenceLength}, {type: DataType.uint32, data: vectorizedHeadSize},
- {type: DataType.uint32, data: parameters.totalSequenceLength},
- {type: DataType.uint32, data: parameters.numHeads}, {type: DataType.uint32, data: parameters.kvSequenceLength},
+ {type: DataType.uint32, data: totalSequenceLength}, {type: DataType.uint32, data: parameters.numHeads},
{type: q.dataType, data: alpha}
];
@@ -376,8 +375,7 @@ const createAttentionProbsProgramInfo =
const uniforms: UniformsArrayType = [
{name: 'M', type: 'u32'}, {name: 'K', type: 'u32'}, {name: 'N', type: 'u32'},
- {name: 'num_heads', type: 'u32'}, {name: 'kv_sequence_length', type: 'u32'},
- {name: 'alpha', type: dataType as UniformDataElementType}
+ {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
];
return `
const beta: ${dataType} = 1.0;
@@ -394,7 +392,7 @@ const createAttentionProbsProgramInfo =
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
let qOffset = uniforms.M * uniforms.K * headIdx + m * uniforms.K;
- let kOffset = uniforms.kv_sequence_length * uniforms.K * headIdx + n * uniforms.K;
+ let kOffset = uniforms.N * uniforms.K * headIdx + n * uniforms.K;
var value = ${qInput.type.value}(0);
for (var w: u32 = 0u; w < uniforms.K; w += TILE_SIZE) {
@@ -456,7 +454,9 @@ const createAttentionProbsProgramInfo =
const createVxAttentionScoreProgramInfo =
- (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
+ (_context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters,
+ pastSequenceLength: number) => {
+ const totalSequenceLength = pastSequenceLength + params.kvSequenceLength;
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
const TILE_SIZE = 12;
const dispatch = {
@@ -465,7 +465,7 @@ const createVxAttentionScoreProgramInfo =
z: params.batchSize * params.numHeads
};
const programUniforms: ProgramUniform[] = [
- {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: params.totalSequenceLength},
+ {type: DataType.uint32, data: params.sequenceLength}, {type: DataType.uint32, data: totalSequenceLength},
{type: DataType.uint32, data: params.vHeadSize}, {type: DataType.uint32, data: params.numHeads},
{type: DataType.uint32, data: params.vHiddenSize}
];
@@ -537,24 +537,25 @@ export const applyAttention =
(context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
_past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
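+      // Past key/value are only concatenated into present key/value when those outputs are requested;
+      // otherwise the effective past sequence length is 0.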
+ const outputPresentKey = context.outputCount > 1;
+ const outputPresentValue = context.outputCount > 2;
+ const pastSequenceLength = (outputPresentKey && outputPresentValue) ? parameters.pastSequenceLength : 0;
+ const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
// Concatinate pastKey and K to produce presentKey.
- const presentKeyShape =
- [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
+ const presentKeyShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
const concatKeyInputs = pastKey ? [pastKey, k] : [k];
- const key = (context.outputCount > 1 || pastKey) ?
- context.compute(
- createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
- {inputs: concatKeyInputs, outputs: [context.outputCount > 1 ? 1 : -1]})[0] :
- k;
+ const key = outputPresentKey ? context.compute(
+ createConcatProgramInfo(concatKeyInputs, 2, presentKeyShape, k.dataType),
+ {inputs: concatKeyInputs, outputs: [1]})[0] :
+ k;
// Concatinate pastValue and V to produce presentValue.
- const presentValueShape =
- [parameters.batchSize, parameters.numHeads, parameters.totalSequenceLength, parameters.headSize];
+ const presentValueShape = [parameters.batchSize, parameters.numHeads, totalSequenceLength, parameters.headSize];
const concatValueInputs = pastValue ? [pastValue, v] : [v];
- const value = (context.outputCount > 2 || pastValue) ?
+ const value = outputPresentValue ?
context.compute(
createConcatProgramInfo(concatValueInputs, 2, presentValueShape, v.dataType),
- {inputs: concatValueInputs, outputs: [context.outputCount > 2 ? 2 : -1]})[0] :
+ {inputs: concatValueInputs, outputs: [2]})[0] :
v;
const inputsK = [q, key];
if (relativePositionBias) {
@@ -563,20 +564,22 @@ export const applyAttention =
// Run AttentionProbs
const probs = context.compute(
- createAttentionProbsProgramInfo(context, q, key, relativePositionBias, parameters, attributes),
+ createAttentionProbsProgramInfo(
+ context, q, key, relativePositionBias, parameters, attributes, pastSequenceLength),
{inputs: inputsK, outputs: [-1]})[0];
// Run Softmax
context.compute(
createInPlaceSoftmaxProgramInfo(
context, probs, parameters.batchSize * parameters.numHeads * parameters.sequenceLength,
- parameters.totalSequenceLength),
+ totalSequenceLength),
{inputs: [probs], outputs: []});
// Run AttrionScore
const inputsV = [probs, value];
context.compute(
- createVxAttentionScoreProgramInfo(context, probs, value, parameters), {inputs: inputsV, outputs: [0]});
+ createVxAttentionScoreProgramInfo(context, probs, value, parameters, pastSequenceLength),
+ {inputs: inputsV, outputs: [0]});
};
const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index 5e27e79087730..ec2831a3cca04 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -313,7 +313,7 @@ export const castToF32 = (dataType: string, components: number, value: string) =
return `f32(${value})`;
}
- return `vec${components}f32(${value})`;
+ return `vec${components}(${value})`;
};
/**
diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts
index 03d637b35bc7c..b67e173a22793 100644
--- a/js/web/script/test-runner-cli.ts
+++ b/js/web/script/test-runner-cli.ts
@@ -558,7 +558,9 @@ async function main() {
if (args.noSandbox) {
karmaArgs.push('--no-sandbox');
}
- if (webgpu || webnn) {
+
+  // When using BrowserStack with Safari, we must not use 'localhost' as the hostname.
+ if (!(browser.startsWith('BS_') && browser.includes('Safari'))) {
karmaArgs.push('--force-localhost');
}
if (webgpu) {
diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc
index 05687bd482e24..0bed30747bca9 100644
--- a/js/web/test/data/ops/multihead-attention.jsonc
+++ b/js/web/test/data/ops/multihead-attention.jsonc
@@ -190,5 +190,656 @@
]
}
]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [2],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [3],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // PastKey
+ {
+ "data": [4],
+ "dims": [1, 1, 1, 1],
+ "type": "float32"
+ },
+ // PastValue
+ {
+ "data": [5],
+ "dims": [1, 1, 1, 1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [3],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [5, 6, 7, 8],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // PastKey
+ {
+ "data": [13, 14, 15, 16],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ },
+ // PastValue
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [2],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [3],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // PastKey
+ {
+ "data": [4],
+ "dims": [1, 1, 1, 1],
+ "type": "float32"
+ },
+ // PastValue
+ {
+ "data": [5],
+ "dims": [1, 1, 1, 1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [4.761593818664551],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ {
+ "data": [4, 2],
+ "dims": [1, 1, 2, 1],
+ "type": "float32"
+ },
+ {
+ "data": [5, 3],
+ "dims": [1, 1, 2, 1],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [5, 6, 7, 8],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Past Key
+ {
+ "data": [13, 14, 15, 16],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ },
+ // Past Value
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Present key
+ {
+ "data": [13, 14, 15, 16, 5, 6, 7, 8],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ },
+ // Present value
+ {
+ "data": [17, 18, 19, 20, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=1 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [2],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [3],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // PastKey
+ {
+ "data": [4],
+ "dims": [1, 1, 1, 1],
+ "type": "float32"
+ },
+ // PastValue
+ {
+ "data": [5],
+ "dims": [1, 1, 1, 1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [3],
+ "dims": [1, 1, 1],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [5, 6, 7, 8],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // PastKey
+ {
+ "data": [13, 14, 15, 16],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ },
+ // PastValue
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [5, 6, 7, 8],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // PastKey
+ {
+ "data": [13, 14, 15, 16],
+ "dims": [1, 4, 1, 1],
+ "type": "float32"
+ },
+ // PastValue
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 4, 1, 1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [16.997316360473633, 18, 19, 20],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ {
+ "data": [13, 5, 14, 6, 15, 7, 16, 8],
+ "dims": [1, 4, 2, 1],
+ "type": "float32"
+ },
+ {
+ "data": [17, 9, 18, 10, 19, 11, 20, 12],
+ "dims": [1, 4, 2, 1],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, 4 heads and head-size=4 with pastKey and pastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [2, 4, 8, 16, 1, 3, 9, 27, 1, 2, 4, 8, 16, 32, 64, 128],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Past Key
+ {
+ "data": [13, 14, 15, 16, 5, 6, 7, 8, 1, 2, 3, 4, 9, 10, 11, 12],
+ "dims": [1, 4, 1, 4],
+ "type": "float32"
+ },
+ // Past Value
+ {
+ "data": [17, 18, 19, 20, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [1, 4, 1, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 16.899608612060547, 17.906301498413086, 18.926380157470703, 19.973230361938477, 1, 3, 9, 27, 1, 2, 4, 8,
+ 5, 6, 7, 8
+ ],
+ "dims": [1, 1, 16],
+ "type": "float32"
+ },
+ // Present key
+ {
+ "data": [
+ 13, 14, 15, 16, 16, 15, 14, 13, 5, 6, 7, 8, 12, 11, 10, 9, 1, 2, 3, 4, 8, 7, 6, 5, 9, 10, 11, 12, 4, 3, 2,
+ 1
+ ],
+ "dims": [1, 4, 2, 4],
+ "type": "float32"
+ },
+ // Present value
+ {
+ "data": [
+ 17, 18, 19, 20, 2, 4, 8, 16, 9, 10, 11, 12, 1, 3, 9, 27, 1, 2, 3, 4, 1, 2, 4, 8, 5, 6, 7, 8, 16, 32, 64,
+ 128
+ ],
+ "dims": [1, 4, 2, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue",
+ "operator": "MultiHeadAttention",
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ // Q
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // K
+ {
+ "data": [5, 6, 7, 8],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // V
+ {
+ "data": [9, 10, 11, 12],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Bias
+ {
+ "data": null,
+ "type": "float32"
+ },
+ // Mask
+ {
+ "data": null,
+ "type": "int32"
+ },
+ // RelativePositionBias
+ {
+ "data": [10, 20],
+ "dims": [1, 1, 1, 2],
+ "type": "float32"
+ },
+ // Past Key
+ {
+ "data": [13, 14, 15, 16],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ },
+ // Past Value
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 1, 1, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [17, 18, 19, 20],
+ "dims": [1, 1, 4],
+ "type": "float32"
+ },
+ // Present key
+ {
+ "data": [13, 14, 15, 16, 5, 6, 7, 8],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ },
+ // Present value
+ {
+ "data": [17, 18, 19, 20, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
}
]
diff --git a/objectivec/ort_checkpoint.mm b/objectivec/ort_checkpoint.mm
index 12386457fadf1..2c7c9e417b52c 100644
--- a/objectivec/ort_checkpoint.mm
+++ b/objectivec/ort_checkpoint.mm
@@ -8,6 +8,7 @@
#include
#import "cxx_api.h"
+#import "cxx_utils.h"
#import "error_utils.h"
NS_ASSUME_NONNULL_BEGIN
@@ -73,7 +74,7 @@ - (nullable NSString*)getStringPropertyWithName:(NSString*)name error:(NSError**
try {
Ort::Property value = [self CXXAPIOrtCheckpoint].GetProperty(name.UTF8String);
if (std::string* str = std::get_if(&value)) {
- return [NSString stringWithUTF8String:str->c_str()];
+ return utils::toNSString(str->c_str());
}
ORT_CXX_API_THROW("Property is not a string.", ORT_INVALID_ARGUMENT);
}
diff --git a/objectivec/ort_session.mm b/objectivec/ort_session.mm
index 87288bd1e9dc7..3dcc88f1ebd5b 100644
--- a/objectivec/ort_session.mm
+++ b/objectivec/ort_session.mm
@@ -7,6 +7,7 @@
#include
#import "cxx_api.h"
+#import "cxx_utils.h"
#import "error_utils.h"
#import "ort_enums_internal.h"
#import "ort_env_internal.h"
@@ -198,8 +199,7 @@ - (BOOL)runWithInputs:(NSDictionary*)inputs
for (size_t i = 0; i < nameCount; ++i) {
auto name = getName(i, allocator);
- NSString* nameNsstr = [NSString stringWithUTF8String:name.get()];
- NSAssert(nameNsstr != nil, @"nameNsstr must not be nil");
+ NSString* nameNsstr = utils::toNSString(name.get());
[result addObject:nameNsstr];
}
diff --git a/onnxruntime/core/providers/coreml/model/host_utils.mm b/onnxruntime/core/providers/coreml/model/host_utils.mm
index 5487ea35388f5..70052f50ae1c2 100644
--- a/onnxruntime/core/providers/coreml/model/host_utils.mm
+++ b/onnxruntime/core/providers/coreml/model/host_utils.mm
@@ -3,6 +3,7 @@
#include "core/platform/env.h"
#include "core/providers/coreml/model/host_utils.h"
+#include "core/providers/coreml/model/objc_str_utils.h"
#import
@@ -36,7 +37,7 @@ int32_t CoreMLVersion() {
#if !defined(NDEBUG)
std::string path_override = Env::Default().GetEnvironmentVar(kOverrideModelOutputDirectoryEnvVar);
if (!path_override.empty()) {
- NSString* ns_path_override = [NSString stringWithUTF8String:path_override.c_str()];
+ NSString* ns_path_override = Utf8StringToNSString(path_override.c_str());
temporary_directory_url = [NSURL fileURLWithPath:ns_path_override isDirectory:YES];
}
#endif
diff --git a/onnxruntime/core/providers/coreml/model/model.mm b/onnxruntime/core/providers/coreml/model/model.mm
index 1434043e064f4..3edcdb3f95e46 100644
--- a/onnxruntime/core/providers/coreml/model/model.mm
+++ b/onnxruntime/core/providers/coreml/model/model.mm
@@ -23,6 +23,7 @@
#include "core/providers/coreml/builders/helper.h"
#include "core/providers/coreml/coreml_provider_factory.h"
#include "core/providers/coreml/model/host_utils.h"
+#include "core/providers/coreml/model/objc_str_utils.h"
#include "core/providers/coreml/shape_utils.h"
// force the linker to create a dependency on the CoreML framework so that in MAUI usage we don't need
@@ -33,13 +34,6 @@
using namespace onnxruntime::coreml;
namespace {
-// Converts a UTF8 const char* to an NSString. Throws on failure.
-NSString* _Nonnull Utf8StringToNSString(const char* utf8_str) {
- NSString* result = [NSString stringWithUTF8String:utf8_str];
- ORT_ENFORCE(result != nil, "NSString conversion failed.");
- return result;
-}
-
/**
* Computes the static output shape used to allocate the output tensor.
* `inferred_shape` is the inferred shape known at model compile time. It may contain dynamic dimensions (-1).
@@ -165,7 +159,7 @@ Status CreateInputFeatureProvider(const std::unordered_map&)inputs
for (const auto& [output_name, output_tensor_info] : outputs) {
MLFeatureValue* output_value =
- [output_features featureValueForName:Utf8StringToNSString(output_name.c_str())];
+ [output_features featureValueForName:util::Utf8StringToNSString(output_name.c_str())];
if (output_value == nil) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "output_features has no value for ", output_name);
diff --git a/onnxruntime/core/providers/coreml/model/objc_str_utils.h b/onnxruntime/core/providers/coreml/model/objc_str_utils.h
new file mode 100644
index 0000000000000..38006c39bc588
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/model/objc_str_utils.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#import
+
+namespace onnxruntime::coreml::util {
+
+// Converts a UTF8 const char* to an NSString. Throws on failure.
+// Prefer this to directly calling [NSString stringWithUTF8String:] as that may return nil.
+NSString* _Nonnull Utf8StringToNSString(const char* _Nonnull utf8_str);
+
+} // namespace onnxruntime::coreml::util
diff --git a/onnxruntime/core/providers/coreml/model/objc_str_utils.mm b/onnxruntime/core/providers/coreml/model/objc_str_utils.mm
new file mode 100644
index 0000000000000..0d280ce0b7f8f
--- /dev/null
+++ b/onnxruntime/core/providers/coreml/model/objc_str_utils.mm
@@ -0,0 +1,16 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/coreml/model/objc_str_utils.h"
+
+#include "core/common/common.h"
+
+namespace onnxruntime::coreml::util {
+
+NSString* _Nonnull Utf8StringToNSString(const char* _Nonnull utf8_str) {
+ NSString* result = [NSString stringWithUTF8String:utf8_str];
+ ORT_ENFORCE(result != nil, "NSString conversion failed.");
+ return result;
+}
+
+} // namespace onnxruntime::coreml::util
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index ba0d2fe5d4174..c553458f7fa45 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1157,7 +1157,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
@@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
// Opset 17
@@ -1312,6 +1313,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
@@ -2071,7 +2073,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2202,6 +2204,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
// Opset 17
@@ -2225,6 +2228,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
index 42a9f50001103..bfe385af49dc4 100755
--- a/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
@@ -133,7 +133,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
} else if (reduction_ == "max") {
args.operation = GatherScatterElementsArgs::Operation::MAX;
} else {
- ORT_THROW("Unsupported reduction type");
+ ORT_THROW("Unsupported reduction type for ScatterElements.");
}
// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
index 6191715f79188..a270249da2b7f 100644
--- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.cc
@@ -3,6 +3,7 @@
#include "core/providers/cuda/tensor/scatter_nd.h"
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
+#include "core/providers/cuda/tensor/scatter_nd_common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cpu/tensor/utils.h"
@@ -16,18 +17,61 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
- ScatterND);
+ ScatterNDDisjointAndNoReduction);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
+ kOnnxDomain,
+ 13, 15,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+ .MayInplace(0, 0),
+ ScatterNDWithAtomicReduction);
+
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
+ kOnnxDomain,
+ 16, 17,
+ kCudaExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
+ .MayInplace(0, 0),
+ ScatterNDWithAtomicReduction);
ONNX_OPERATOR_KERNEL_EX(ScatterND,
kOnnxDomain,
- 13,
+ 18,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
.MayInplace(0, 0),
- ScatterND);
+ ScatterNDWithAtomicReduction);
-Status ScatterND::ComputeInternal(OpKernelContext* context) const {
+static Status InitializeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape,
+ ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
+ CudaKernel::CudaAsyncBuffer& element_counts_and_input_dims_gpu,
+ onnxruntime::OpKernelContext* context) {
+ TensorPitches input_strides(input_shape);
+
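+  // The kernel receives ElementCountsAndInputDimsSpanOrGpu by value. When last_index_dimension < 6,
+  // the (element count, dim) pairs fit in the fixed-size stack array (2 * 6 = 12 entries) and no device
+  // copy is needed; otherwise they are staged in a CudaAsyncBuffer and read through gpu_ptr.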
+ if (last_index_dimension < 6) {
+ element_counts_and_input_dims.gpu_ptr = nullptr;
+ for (int64_t i = 0; i < last_index_dimension; ++i) {
+ element_counts_and_input_dims.stack_ptr[i] = input_strides[i];
+ element_counts_and_input_dims.stack_ptr[i + last_index_dimension] = input_shape[i];
+ }
+ } else {
+ element_counts_and_input_dims_gpu.AllocCpuPtr(last_index_dimension * 2);
+ memset(element_counts_and_input_dims_gpu.CpuPtr(), 0, sizeof(int64_t) * last_index_dimension * 2);
+ for (int64_t i = 0; i < last_index_dimension; ++i) {
+ element_counts_and_input_dims_gpu.CpuPtr()[i] = input_strides[i];
+ element_counts_and_input_dims_gpu.CpuPtr()[i + last_index_dimension] = input_shape[i];
+ }
+ ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
+ element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr();
+ }
+ return Status::OK();
+}
+
+Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context) const {
const auto* input_tensor = context->Input(0);
const auto* indices_tensor = context->Input(1);
const auto* updates_tensor = context->Input(2);
@@ -44,8 +88,6 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
const void* input_data = input_tensor->DataRaw();
void* output_data = output_tensor->MutableDataRaw();
- size_t element_size = input_tensor->DataType()->Size();
-
if (input_data != output_data) {
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
CUDA_RETURN_IF_ERROR(
@@ -58,18 +100,17 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
}
auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
+ size_t element_size = input_tensor->DataType()->Size();
// We need element counts for each dimension and the input dim value for each dimension
// for the range [0, last_index_dimension).
// To avoid multiple GPU data transfers, we combine this into one array and send it through
- TensorPitches input_strides(input_shape);
- std::vector element_counts_and_input_dims(last_index_dimension * 2, 0LL);
- for (int64_t i = 0; i < last_index_dimension; ++i) {
- element_counts_and_input_dims[i] = input_strides[i];
- element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
- }
- CudaAsyncBuffer element_counts_and_input_dims_gpu(this, element_counts_and_input_dims);
- ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
+ ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
+ CudaAsyncBuffer element_counts_and_input_dims_gpu(this);
+  ORT_RETURN_IF_ERROR(InitializeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
+ element_counts_and_input_dims,
+ element_counts_and_input_dims_gpu,
+ context));
ORT_RETURN_IF_ERROR(ScatterNDImpl(
Stream(context),
@@ -78,12 +119,89 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
indices_shape.Size() / static_cast(last_index_dimension),
indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec
last_index_dimension,
- element_counts_and_input_dims_gpu.GpuPtr(),
+ element_counts_and_input_dims,
updates_tensor->DataRaw(),
input_shape.SizeFromDimension(last_index_dimension)));
return Status::OK();
}
+Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) const {
+ const auto* input_tensor = context->Input(0);
+ const auto* indices_tensor = context->Input(1);
+ const auto* updates_tensor = context->Input(2);
+
+ const auto& input_shape = input_tensor->Shape();
+ const auto& indices_shape = indices_tensor->Shape();
+ const auto& updates_shape = updates_tensor->Shape();
+
+ // Validate input shapes
+ ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape));
+
+ auto* output_tensor = context->Output(0, input_shape);
+
+ const void* input_data = input_tensor->DataRaw();
+ void* output_data = output_tensor->MutableDataRaw();
+
+ if (input_data != output_data) {
+ // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will
+ // be faster than invoking cudaMemcpy ?
+ CUDA_RETURN_IF_ERROR(
+ cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(),
+ cudaMemcpyDeviceToDevice, Stream(context)));
+ }
+
+  // Bail out early if there are no indices to process
+ if (indices_shape.Size() == 0) {
+ return Status::OK();
+ }
+
+ auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
+ ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
+ CudaAsyncBuffer element_counts_and_input_dims_gpu(this);
+  ORT_RETURN_IF_ERROR(InitializeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
+ element_counts_and_input_dims,
+ element_counts_and_input_dims_gpu,
+ context));
+
+ switch (reduction_) {
+ case ScatterNDReduction::None: {
+ size_t element_size = input_tensor->DataType()->Size();
+ ORT_RETURN_IF_ERROR(ScatterNDImpl(
+ Stream(context),
+ output_data,
+ element_size,
+ indices_shape.Size() / static_cast(last_index_dimension),
+ indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec
+ last_index_dimension,
+ element_counts_and_input_dims,
+ updates_tensor->DataRaw(),
+ input_shape.SizeFromDimension(last_index_dimension)));
+ } break;
+ case ScatterNDReduction::Add:
+ case ScatterNDReduction::Min:
+ case ScatterNDReduction::Max:
+ case ScatterNDReduction::Mul: {
+ auto element_type = input_tensor->DataType()->AsPrimitiveDataType()->GetDataType();
+ ORT_RETURN_IF_ERROR(ScatterNDImplReduction(
+ Stream(context),
+ output_data,
+ element_type,
+ indices_shape.Size() / static_cast(last_index_dimension),
+ indices_tensor->Data(), // only int64_t is supported for indices as per the onnx spec
+ last_index_dimension,
+ element_counts_and_input_dims,
+ updates_tensor->DataRaw(),
+ input_shape.SizeFromDimension(last_index_dimension),
+ reduction_));
+ } break;
+ default:
+      ORT_THROW("ScatterND only supports 'none', 'add', 'mul', 'min' and 'max' reductions.");
+ break;
+ }
+
+ return Status::OK();
+}
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h
index 07df5ab552c3c..6d8bbe6f463fd 100644
--- a/onnxruntime/core/providers/cuda/tensor/scatter_nd.h
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd.h
@@ -3,18 +3,63 @@
#pragma once
+#include
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/cuda_kernel.h"
+#include "core/providers/cuda/tensor/scatter_nd_kind.h"
#include "core/providers/cpu/tensor/scatter_nd.h"
namespace onnxruntime {
namespace cuda {
-class ScatterND final : public CudaKernel {
+/**
+ * This implementation assumes the indices are disjoint (no duplicate entries)
+ * and that no reduction is needed. The code does not check that condition.
+ * If duplicate indices are present, the same output element could be written
+ * from different threads at the same time and the final value
+ * is unlikely to be correct.
+ */
+class ScatterNDDisjointAndNoReduction final : public CudaKernel {
public:
- explicit ScatterND(const OpKernelInfo& info) : CudaKernel(info) {}
+ explicit ScatterNDDisjointAndNoReduction(const OpKernelInfo& info) : CudaKernel(info) {}
Status ComputeInternal(OpKernelContext* context) const override;
};
+/**
+ * This implementation is derived from the first one.
+ * It uses atomic operations to handle conflicting writes.
+ * The result may still be incorrect if the reduction is 'none',
+ * as there is no guarantee that the final value will be the one
+ * corresponding to the highest visited index.
+ * TODO: change the implementation to avoid conflicts.
+ */
+class ScatterNDWithAtomicReduction final : public CudaKernel {
+ public:
+ explicit ScatterNDWithAtomicReduction(const OpKernelInfo& info) : CudaKernel(info) {
+ std::string reduction;
+
+ if (info.GetAttr("reduction", &reduction).IsOK()) {
+ if (reduction == "add") {
+ reduction_ = ScatterNDReduction::Add;
+ } else if (reduction == "mul") {
+ reduction_ = ScatterNDReduction::Mul;
+ } else if (reduction == "min") {
+ reduction_ = ScatterNDReduction::Min;
+ } else if (reduction == "max") {
+ reduction_ = ScatterNDReduction::Max;
+ } else if (reduction == "none") {
+        LOGS_DEFAULT(WARNING) << "ScatterND with reduction=='none' is only guaranteed "
+                              << "to be correct if indices are not duplicated.";
+ } else {
+        ORT_THROW("Reduction '", reduction, "' is not supported on CUDA for opset >= 13.");
+ }
+ }
+ }
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ ScatterNDReduction reduction_{ScatterNDReduction::None};
+};
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h
new file mode 100644
index 0000000000000..9f1465590c5e4
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_common.h
@@ -0,0 +1,15 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace onnxruntime {
+namespace cuda {
+
+struct ElementCountsAndInputDimsSpanOrGpu {
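+  // Packed [element_count_0..k-1, dim_0..k-1] for the first k = last_index_dimension axes.
+  // When k < 6 the values live in stack_ptr and gpu_ptr is nullptr; otherwise gpu_ptr points to a
+  // device-side copy and stack_ptr is unused.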
+ int64_t stack_ptr[12];
+ int64_t* gpu_ptr;
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu
index e9199b5e1b15b..47e7d103ce27b 100644
--- a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.cu
@@ -14,7 +14,7 @@ __global__ void _ScatterNDKernel(
const size_t num_indices,
const int64_t* indices_data,
const int64_t last_index_dimension,
- const int64_t* element_counts_and_input_dims,
+ ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims,
const T* updates_data,
const size_t num_updates_elements) {
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices);
@@ -27,8 +27,12 @@ __global__ void _ScatterNDKernel(
for (size_t i = indices_start; i < indices_end; ++i) {
int64_t index = indices_data[i];
- int64_t element_count_dim = element_counts_and_input_dims[i - indices_start];
- int64_t dim_value = element_counts_and_input_dims[i - indices_start + last_index_dimension];
+ int64_t element_count_dim = element_counts_and_input_dims.gpu_ptr == nullptr
+ ? element_counts_and_input_dims.stack_ptr[i - indices_start]
+ : element_counts_and_input_dims.gpu_ptr[i - indices_start];
+ int64_t dim_value = element_counts_and_input_dims.gpu_ptr == nullptr
+ ? element_counts_and_input_dims.stack_ptr[i - indices_start + last_index_dimension]
+ : element_counts_and_input_dims.gpu_ptr[i - indices_start + last_index_dimension];
// Clamp the index if out of range
// This would have been an error in the CPU kernel, but throwing in the CUDA EP
@@ -66,7 +70,7 @@ Status ScatterNDImpl(
const size_t num_indices,
const int64_t* indices_data,
const int64_t last_index_dimension,
- const int64_t* element_counts_and_input_dims,
+ const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
const void* updates_data,
const size_t num_updates_elements) {
if (num_indices == 0)
@@ -128,5 +132,197 @@ Status ScatterNDImpl(
return Status::OK();
}
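+// Each functor wraps the corresponding atomic primitive so that the reduction kernel
+// (_ScatterNDKernelReduction) can be instantiated once per reduction operation.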
+template
+struct FuncAdd {
+ __device__ __inline__ void operator()(T* start_addr, T value) const {
+ atomic_add(start_addr, value);
+ }
+};
+
+template
+struct FuncMul {
+ __device__ __inline__ void operator()(T* start_addr, T value) const {
+ atomic_mul(start_addr, value);
+ }
+};
+
+template
+struct FuncMax {
+ __device__ __inline__ void operator()(T* start_addr, T value) const {
+ atomic_max(start_addr, value);
+ }
+};
+
+template
+struct FuncMin {
+ __device__ __inline__ void operator()(T* start_addr, T value) const {
+ atomic_min(start_addr, value);
+ }
+};
+
+template
+__global__ void _ScatterNDKernelReduction(
+ T* output_data,
+ const size_t num_indices,
+ const int64_t* indices_data,
+ const int64_t last_index_dimension,
+ ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims,
+ const T* updates_data,
+ const size_t num_updates_elements,
+ const TFunc func) {
+ CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, num_indices);
+
+ // Compute the base offset into the output data
+ int64_t data_offset = 0;
+
+ size_t indices_start = last_index_dimension * id;
+ size_t indices_end = indices_start + last_index_dimension;
+ for (size_t i = indices_start; i < indices_end; ++i) {
+ int64_t index = indices_data[i];
+
+ int64_t element_count_dim = element_counts_and_input_dims.gpu_ptr == nullptr
+ ? element_counts_and_input_dims.stack_ptr[i - indices_start]
+ : element_counts_and_input_dims.gpu_ptr[i - indices_start];
+ int64_t dim_value = element_counts_and_input_dims.gpu_ptr == nullptr
+ ? element_counts_and_input_dims.stack_ptr[i - indices_start + last_index_dimension]
+ : element_counts_and_input_dims.gpu_ptr[i - indices_start + last_index_dimension];
+
+ // Clamp the index if out of range
+ // This would have been an error in the CPU kernel, but throwing in the CUDA EP
+ // is hard. This is the approach taken by other frameworks for out of bound indices
+ // in their corresponding GPU backends as well.
+ // index >= -dim_value && index < dim_value
+
+ if (index >= 0) {
+ if (index >= dim_value) {
+ index = dim_value - 1;
+ }
+ } else {
+ if (index < -dim_value) {
+ index = 0;
+ } else {
+ index += dim_value;
+ }
+ }
+
+ data_offset += (index * element_count_dim);
+ }
+
+ const T* updates_data_base = updates_data + num_updates_elements * id;
+ T* output_data_base = output_data + data_offset;
+
+ for (size_t i = 0; i < num_updates_elements; ++i) {
+ func(output_data_base + i, updates_data_base[i]);
+ }
+}
+
+template
+Status _ScatterNDType(
+ cudaStream_t stream,
+ T* output_data,
+ const size_t num_indices,
+ const int64_t* indices_data,
+ const int64_t last_index_dimension,
+ const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
+ const T* updates_data,
+ const size_t num_updates_elements,
+ ScatterNDReduction reduction) {
+ // Parallelize on number of indices
+ int blocksPerGrid = static_cast(ceil(static_cast(num_indices) / GridDim::maxThreadsPerBlock));
+
+ switch (reduction) {
+ case ScatterNDReduction::Add:
+ _ScatterNDKernelReduction<<>>(
+ output_data,
+ num_indices,
+ indices_data,
+ last_index_dimension,
+ element_counts_and_input_dims,
+ updates_data,
+ num_updates_elements,
+ FuncAdd());
+ break;
+ case ScatterNDReduction::Mul:
+ _ScatterNDKernelReduction<<>>(
+ output_data,
+ num_indices,
+ indices_data,
+ last_index_dimension,
+ element_counts_and_input_dims,
+ updates_data,
+ num_updates_elements,
+ FuncMul());
+ break;
+ case ScatterNDReduction::Min:
+ _ScatterNDKernelReduction<<>>(
+ output_data,
+ num_indices,
+ indices_data,
+ last_index_dimension,
+ element_counts_and_input_dims,
+ updates_data,
+ num_updates_elements,
+ FuncMin());
+ break;
+ case ScatterNDReduction::Max:
+ _ScatterNDKernelReduction<<>>(
+ output_data,
+ num_indices,
+ indices_data,
+ last_index_dimension,
+ element_counts_and_input_dims,
+ updates_data,
+ num_updates_elements,
+ FuncMax());
+ break;
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Reduction ", static_cast(reduction), " not implemented for ScatterND operator.");
+ }
+
+ return Status::OK();
+}
+
+Status ScatterNDImplReduction(
+ cudaStream_t stream,
+ void* output_data,
+ const int32_t element_type,
+ const size_t num_indices,
+ const int64_t* indices_data,
+ const int64_t last_index_dimension,
+ const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
+ const void* updates_data,
+ const size_t num_updates_elements,
+ ScatterNDReduction reduction) {
+ if (num_indices == 0)
+ return Status::OK();
+
+ switch (element_type) {
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+ return _ScatterNDType(
+ stream,
+ reinterpret_cast(output_data),
+ num_indices,
+ indices_data,
+ last_index_dimension,
+ element_counts_and_input_dims,
+ reinterpret_cast(updates_data),
+ num_updates_elements,
+ reduction);
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+ return _ScatterNDType(
+ stream,
+ reinterpret_cast(output_data),
+ num_indices,
+ indices_data,
+ last_index_dimension,
+ element_counts_and_input_dims,
+ reinterpret_cast(updates_data),
+ num_updates_elements,
+ reduction);
+ default:
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "element_type ", static_cast(element_type), " not implemented for ScatterND operator.");
+ }
+}
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h
index 874d275f94776..a3c8aab460043 100644
--- a/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_impl.h
@@ -4,6 +4,8 @@
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
+#include "core/providers/cuda/tensor/scatter_nd_kind.h"
+#include "core/providers/cuda/tensor/scatter_nd_common.h"
namespace onnxruntime {
namespace cuda {
@@ -15,9 +17,21 @@ Status ScatterNDImpl(
const size_t num_indices,
const int64_t* indices_data,
const int64_t last_index_dimension,
- const int64_t* element_counts_and_input_dims,
+ const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
const void* updates_data,
const size_t num_updates_elements);
+Status ScatterNDImplReduction(
+ cudaStream_t stream,
+ void* output_data,
+ const int32_t element_type,
+ const size_t num_indices,
+ const int64_t* indices_data,
+ const int64_t last_index_dimension,
+ const ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
+ const void* updates_data,
+ const size_t num_updates_elements,
+ ScatterNDReduction reduction);
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h b/onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h
new file mode 100644
index 0000000000000..d766cdd920955
--- /dev/null
+++ b/onnxruntime/core/providers/cuda/tensor/scatter_nd_kind.h
@@ -0,0 +1,18 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+namespace onnxruntime {
+namespace cuda {
+
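+// Reduction modes for the CUDA ScatterND kernels, mirroring the values of the ONNX 'reduction' attribute.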
+enum class ScatterNDReduction : int {
+ None = 0,
+ Add = 1,
+ Mul = 2,
+ Min = 3,
+ Max = 4,
+};
+
+} // namespace cuda
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc
index 7a665ef82f083..9a0b2fcc6cf06 100644
--- a/onnxruntime/core/providers/dml/dml_provider_factory.cc
+++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc
@@ -542,7 +542,11 @@ static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_devi
sizeof(feature_levels)
));
- auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE);
+  // Use a compute queue whenever possible on supported hardware to avoid TDRs and maintain UI QoS.
+  // Core and generic devices only have compute queues; DX11 has "immediate" submission, and DX12 has both.
+ auto use_compute_command_list = (feature_levels.MaxSupportedFeatureLevel <= D3D_FEATURE_LEVEL_1_0_CORE) ||
+ (feature_levels.MaxSupportedFeatureLevel >= D3D_FEATURE_LEVEL_12_0);
+
if (use_compute_command_list)
{
return D3D12_COMMAND_LIST_TYPE_COMPUTE;
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 704e0b5ab26c9..48e3ac493c962 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1161,7 +1161,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, LRN);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 13, Identity);
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
@@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1308,6 +1309,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, bool, Pad);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterElements);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, float, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, double, Resize);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, MLFloat16, Resize);
@@ -2115,7 +2117,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, ScatterND)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2249,6 +2251,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
// Opset 17
BuildKernelCreateInfo,
@@ -2262,6 +2265,7 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 18, ScatterND)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
+ test1.AddInput("indices", {3, 1}, {0, 1, 0});
+ // The linter complains if the line is split into multiple lines.
+ test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
+ test1.AddOutput("output", {2, 2, 3}, {8194.1f, 16388.1f, 32776.10f, 65552.10f, 131104.1f, 262208.1f, 128.1f, 256.1f, 512.1f, 1024.1f, 2048.1f, 4096.1f});
+ test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterNDOpTest, ScatterND_18_mul) {
+ OpTester test1("ScatterND", 18);
+ test1.AddAttribute("reduction", "mul");
+ test1.AddInput("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
+ test1.AddInput("indices", {3, 1}, {0, 1, 0});
+ // The linter complains if the line is split into multiple lines.
+ test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
+ test1.AddOutput("output", {2, 2, 3}, {1638.4f, 6553.6f, 26214.4f, 104857.6f, 419430.4f, 1677721.625f, 12.8f, 25.6f, 51.2f, 102.4f, 204.8f, 409.6f});
+ test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterNDOpTest, ScatterND_18_mul_long_shape) {
+ OpTester test1("ScatterND", 18);
+ test1.AddAttribute("reduction", "mul");
+ test1.AddInput("data", {2, 2, 3, 1, 1, 1, 1}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
+ test1.AddInput("indices", {3, 1}, {0, 1, 0});
+ // The linter complains if the line is split into multiple lines.
+ test1.AddInput("updates", {3, 2, 3, 1, 1, 1, 1}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
+ test1.AddOutput("output", {2, 2, 3, 1, 1, 1, 1}, {1638.4f, 6553.6f, 26214.4f, 104857.6f, 419430.4f, 1677721.625f, 12.8f, 25.6f, 51.2f, 102.4f, 204.8f, 409.6f});
+ test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterNDOpTest, ScatterND_18_min) {
+ OpTester test1("ScatterND", 18);
+ test1.AddAttribute("reduction", "min");
+ test1.AddInput("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
+ test1.AddInput("indices", {3, 1}, {0, 1, 0});
+ // The linter complains if the line is split into multiple lines.
+ test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
+ test1.AddOutput("output", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
+ test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+TEST(ScatterNDOpTest, ScatterND_18_max) {
+ OpTester test1("ScatterND", 18);
+ test1.AddAttribute("reduction", "max");
+ test1.AddInput("data", {2, 2, 3}, {0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f, 0.1f});
+ test1.AddInput("indices", {3, 1}, {0, 1, 0});
+ // The linter complains if the line is split into multiple lines.
+ test1.AddInput("updates", {3, 2, 3}, {2.0f, 4.0f, 8.0f, 16.0f, 32.0f, 64.0f, 128.0f, 256.0f, 512.0f, 1024.0f, 2048.0f, 4096.0f, 8192.0f, 16384.0f, 32768.0f, 65536.0f, 131072.0f, 262144.0f});
+ test1.AddOutput("output", {2, 2, 3}, {8192.0, 16384.0, 32768.0, 65536.0, 131072.0, 262144.0, 128.0, 256.0, 512.0, 1024.0, 2048.0, 4096.0});
+ test1.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
} // namespace test
} // namespace onnxruntime
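
Note on the expected values in the ScatterND_18_* tests above: indices is [[0], [1], [0]], so output slice 0 receives updates[0] and updates[2] while slice 1 receives updates[1]. A standalone sketch that recomputes two of the expected numbers:

    import numpy as np

    data = np.full((2, 2, 3), 0.1, dtype=np.float32)
    updates = (2.0 ** np.arange(1, 19)).astype(np.float32).reshape(3, 2, 3)
    indices = [0, 1, 0]  # slice 0 is written twice, slice 1 once

    add_out = data.copy()
    mul_out = data.copy()
    for i, idx in enumerate(indices):
        add_out[idx] += updates[i]
        mul_out[idx] *= updates[i]

    # add: 0.1 + 2 + 8192 = 8194.1, mul: 0.1 * 2 * 8192 = 1638.4
    print(add_out[0, 0, 0], mul_out[0, 0, 0])
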
diff --git a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
index 6b587be7d74eb..2a7a7158b5f62 100644
--- a/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/scatter_op_test.cc
@@ -308,13 +308,30 @@ TEST(ScatterElements, AddReduction) {
test.AddAttribute("axis", 0);
test.AddAttribute("reduction", "add");
- test.AddInput("data", {2, 3}, {-9.f, -4.f, -1.f, -7.f, -3.f, -6.f});
+ test.AddInput("data", {3, 3}, {1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
+ test.AddInput("indices", {2, 3}, {1, 0, 2, 0, 2, 1});
+ test.AddInput("updates", {2, 3}, {1.0f, 1.1f, 1.2f, 2.0f, 2.1f, 2.2f});
+ test.AddOutput("y", {3, 3}, {3.0f, 1.1f, 0.0f, 1.0f, 0.0f, 2.2f, 0.0f, 2.1f, 1.2f});
+
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
+#if defined(CUDA_VERSION)
+// Operation on float16 (MLFloat16) is not implemented on CPU.
+TEST(ScatterElements, AddReduction_MLFloat16) {
+ OpTester test("ScatterElements", 18);
+ test.AddAttribute("axis", 0);
+ test.AddAttribute("reduction", "add");
+
+ test.AddInput("data", {2, 3}, ToFloat16(std::vector({-9.f, -4.f, -1.f, -7.f, -3.f, -6.f})));
test.AddInput("indices", {4, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
- test.AddInput("updates", {4, 3}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f});
- test.AddOutput("y", {2, 3}, {-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)});
+ test.AddInput("updates", {4, 3}, ToFloat16(std::vector({1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f})));
+ test.AddOutput("y", {2, 3}, ToFloat16(std::vector({-9.f, -4.f, -1.f, -7.f + (1.f + 2.f + 3.f + 4.f), -3.f + (1.f + 2.f + 3.f + 4.f), -6.f + (1.f + 2.f + 3.f + 4.f)})));
+ // exclude CPU Execution Provider as MLFloat16 is not supported in CPU
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
}
+#endif
TEST(ScatterElements, AddReductionAxis1) {
OpTester test("ScatterElements", 18);
diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py
index 395315b2a2b0c..6eebc996fde9c 100644
--- a/onnxruntime/test/python/onnx_backend_test_series.py
+++ b/onnxruntime/test/python/onnx_backend_test_series.py
@@ -89,14 +89,16 @@ def apply_filters(filters, category):
def load_jsonc(basename: str):
"""Returns a deserialized object from the JSONC file in testdata/."""
- filename = os.path.join(
- os.path.dirname(os.path.realpath(__file__)),
- "testdata",
- basename,
- )
- if not os.path.exists(filename):
- raise FileNotFoundError(f"File not found {filename!r}.")
+ filenames = [
+ os.path.join(os.path.dirname(os.path.realpath(__file__)), "testdata", basename),
+ os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..", "test", "testdata", basename)),
+ ]
+
+ filtered = [f for f in filenames if os.path.exists(f)]
+ if not filtered:
+ raise FileNotFoundError(f"No file found in {filenames!r}.")
+ filename = filtered[0]
with open(filename, encoding="utf-8") as f: # pylint: disable=invalid-name
lines = f.readlines()
lines = [x.split("//")[0] for x in lines]
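
Note on load_jsonc above: the function now tries both candidate locations and then keeps its existing JSONC handling, stripping // comments line by line before parsing. The parsing idea in isolation, with parse_jsonc being a name made up for this sketch:

    import json

    def parse_jsonc(text):
        # Drop everything after // on each line, then parse as plain JSON.
        # (Like the real code, this would also truncate a // inside a string.)
        lines = [line.split("//")[0] for line in text.splitlines()]
        return json.loads("\n".join(lines))

    print(parse_jsonc('{"skip": ["test_a"]  // flaky on CI\n}'))  # {'skip': ['test_a']}
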
diff --git a/onnxruntime/test/python/onnxruntime_test_scatternd.py b/onnxruntime/test/python/onnxruntime_test_scatternd.py
new file mode 100644
index 0000000000000..2a5555bba37de
--- /dev/null
+++ b/onnxruntime/test/python/onnxruntime_test_scatternd.py
@@ -0,0 +1,329 @@
+import itertools
+import json
+import os
+import typing
+import unittest
+import warnings
+
+import numpy as np
+import onnx.helper as oh
+from onnx import TensorProto, load
+from onnx.numpy_helper import from_array
+from onnx.reference import ReferenceEvaluator
+
+import onnxruntime
+
+
+def has_cuda():
+ available_providers = [provider for provider in onnxruntime.get_available_providers()]
+ return "CUDAExecutionProvider" in available_providers
+
+
+def ignore_warnings(warns: typing.List[Warning]) -> typing.Callable:
+ def wrapper(fct):
+ if warns is None:
+ raise AssertionError(f"warns cannot be None for '{fct}'.")
+
+ def call_f(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", warns)
+ return fct(self)
+
+ return call_f
+
+ return wrapper
+
+
+class TestScatterPerProvider(unittest.TestCase):
+ def assert_exists(self, filename: str):
+ assert os.path.exists(filename), f"Unable to find {filename!r}."
+
+ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
+ from onnxruntime import InferenceSession, SessionOptions
+
+ op_type = "ScatterElements" if "ScatterElements" in expected_names else "ScatterND"
+ ndim = 2 if op_type == "ScatterElements" else 3
+
+ assert dtype in (np.float16, np.float32)
+ itype = TensorProto.FLOAT if dtype == np.float32 else TensorProto.FLOAT16
+ model = oh.make_model(
+ oh.make_graph(
+ [
+ oh.make_node("CastLike", ["X", "I"], ["data"]),
+ oh.make_node(
+ op_type,
+ inputs=["data", "indices", "updates"],
+ outputs=["sy"],
+ # axis=0,
+ reduction=reduction,
+ ),
+ oh.make_node("Sub", ["sy", "I"], ["Y"]),
+ ],
+ "name",
+ [
+ oh.make_tensor_value_info("X", TensorProto.FLOAT, [None] * ndim),
+ oh.make_tensor_value_info("indices", TensorProto.INT64, [None, None]),
+ oh.make_tensor_value_info("updates", itype, [None] * ndim),
+ ],
+ [oh.make_tensor_value_info("Y", itype, [None] * ndim)],
+ [from_array(np.array([0], dtype=dtype), name="I")],
+ ),
+ opset_imports=[oh.make_opsetid("", opset)],
+ ir_version=8 if opset <= 18 else 9,
+ )
+
+ if not os.path.exists("temp_dump"):
+ os.mkdir("temp_dump")
+ for name in os.listdir("temp_dump"):
+ os.remove(os.path.join("temp_dump", name))
+
+ filename = f"temp_dump/{op_type}_{providers[0]}_{itype}.onnx"
+ opts = SessionOptions()
+ opts.optimized_model_filepath = filename
+ sess = InferenceSession(model.SerializeToString(), opts, providers=providers)
+ self.assertTrue(sess is not None)
+ self.assert_exists(filename)
+ onx = load(filename)
+ names = [n.op_type for n in onx.graph.node]
+ self.assertEqual(expected_names, names)
+
+ sonx = str(onx).replace(" ", "").replace("\n", "|")
+ sexp = 'op_type:"Cast"|attribute{|name:"to"|type:INT|i:%d|}' % itype
+ sexp2 = 'op_type:"Cast"|attribute{|name:"to"|i:%d|type:INT|}' % itype
+ assert sexp in sonx or sexp2 in sonx, f"Unable to find a substring in {sonx!r}"
+ if providers == ["CPUExecutionProvider"]:
+ return
+
+ if op_type == "ScatterElements":
+ data = np.zeros((3, 3), dtype=np.float32)
+ data[0, 0] = 1
+ indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64)
+ updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=dtype)
+ else:
+ data = np.array(
+ [
+ [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
+ [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]],
+ [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
+ [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]],
+ ],
+ dtype=np.float32,
+ )
+ indices = np.array([[0], [2]], dtype=np.int64)
+ updates = np.array(
+ [
+ [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],
+ [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]],
+ ],
+ dtype=dtype,
+ )
+ opts = SessionOptions()
+ opts.enable_profiling = True
+ opts.optimized_model_filepath = filename
+ sess = InferenceSession(model.SerializeToString(), opts, providers=providers)
+ got = sess.run(None, {"X": data, "indices": indices, "updates": updates})[0]
+ self.assertEqual(got.dtype, updates.dtype)
+ prof = sess.end_profiling()
+
+ with open(prof, "r") as f: # noqa: UP015
+ content = f.read()
+ js = json.loads(content)
+
+ exe_providers = []
+ suffixes = ["_kernel_time", "_fence_before", "_fence_after"]
+ rows = []
+ for row in js:
+ if "args" in row and isinstance(row["args"], dict):
+ for k, v in row["args"].items():
+ row[f"args_{k}"] = v
+ del row["args"]
+ name = row["name"]
+ for suf in suffixes:
+ if name.endswith(suf):
+ changed = name[: -len(suf)]
+ row["op_name"] = changed
+ break
+ rows.append(row)
+ exe_providers.append((row.get("args_provider", None), row.get("args_op_name", None)))
+ short_list = [(a, b) for a, b in exe_providers if a is not None and b is not None]
+ self.assertEqual(short_list, [("CUDAExecutionProvider", o) for o in expected_names])
+
+ @unittest.skipIf(not has_cuda(), reason="cuda not available")
+ @ignore_warnings(DeprecationWarning)
+ def test_scatterels_cuda(self):
+ default_value = [
+ "Cast",
+ "ScatterElements",
+ "Sub",
+ ]
+ expected = {
+ (np.float32, "none"): default_value,
+ (np.float16, "none"): default_value,
+ (np.float32, "add"): default_value,
+ (np.float16, "add"): default_value,
+ (np.float32, "mul"): default_value,
+ (np.float16, "mul"): default_value,
+ (np.float32, "min"): default_value,
+ (np.float16, "min"): default_value,
+ (np.float32, "max"): default_value,
+ (np.float16, "max"): default_value,
+ }
+ for opset, dtype, reduction in itertools.product(
+ [16, 18], [np.float32, np.float16], ["none", "add", "mul", "min", "max"]
+ ):
+ with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
+ self.common_scatter(
+ opset,
+ ["CUDAExecutionProvider"],
+ dtype,
+ reduction,
+ expected[dtype, reduction],
+ )
+
+ @unittest.skipIf(not has_cuda(), reason="cuda not available")
+ @ignore_warnings(DeprecationWarning)
+ def test_scatternd_cuda(self):
+ default_value = [
+ "Cast",
+ "ScatterND",
+ "Sub",
+ ]
+ expected = {
+ (np.float32, "none"): default_value,
+ (np.float16, "none"): default_value,
+ (np.float32, "add"): default_value,
+ (np.float16, "add"): default_value,
+ (np.float32, "mul"): default_value,
+ (np.float16, "mul"): default_value,
+ (np.float32, "min"): default_value,
+ (np.float16, "min"): default_value,
+ (np.float32, "max"): default_value,
+ (np.float16, "max"): default_value,
+ }
+ for opset, dtype, reduction in itertools.product(
+ [16, 18], [np.float32, np.float16], ["none", "add", "mul", "min", "max"]
+ ):
+ with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
+ self.common_scatter(
+ opset,
+ ["CUDAExecutionProvider"],
+ dtype,
+ reduction,
+ expected[dtype, reduction],
+ )
+
+ @ignore_warnings(DeprecationWarning)
+ def test_scatterels_cpu(self):
+ default_value = [
+ "Cast",
+ "ScatterElements",
+ "Sub",
+ ]
+ expected = {
+ (np.float32, "none"): default_value,
+ (np.float16, "none"): default_value,
+ (np.float32, "add"): default_value,
+ (np.float16, "add"): default_value,
+ (np.float32, "mul"): default_value,
+ (np.float16, "mul"): default_value,
+ (np.float32, "min"): default_value,
+ (np.float16, "min"): default_value,
+ (np.float32, "max"): default_value,
+ (np.float16, "max"): default_value,
+ }
+ for opset, dtype, reduction in itertools.product([16, 18], [np.float32], ["none", "add", "mul", "min", "max"]):
+ with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
+ self.common_scatter(
+ opset,
+ ["CPUExecutionProvider"],
+ dtype,
+ reduction,
+ expected[dtype, reduction],
+ )
+
+ @ignore_warnings(DeprecationWarning)
+ def test_scatternd_cpu(self):
+ default_value = [
+ "Cast",
+ "ScatterND",
+ "Sub",
+ ]
+ expected = {
+ (np.float32, "none"): default_value,
+ (np.float16, "none"): default_value,
+ (np.float32, "add"): default_value,
+ (np.float16, "add"): default_value,
+ (np.float32, "mul"): default_value,
+ (np.float16, "mul"): default_value,
+ (np.float32, "min"): default_value,
+ (np.float16, "min"): default_value,
+ (np.float32, "max"): default_value,
+ (np.float16, "max"): default_value,
+ }
+ for opset, dtype, reduction in itertools.product([16, 18], [np.float32], ["none", "add", "mul", "min", "max"]):
+ with self.subTest(dtype=dtype, reduction=reduction, opset=opset):
+ self.common_scatter(
+ opset,
+ ["CPUExecutionProvider"],
+ dtype,
+ reduction,
+ expected[dtype, reduction],
+ )
+
+ def _scatternd_standalone_cuda(self, reduction, line):
+ model = oh.make_model(
+ oh.make_graph(
+ [
+ oh.make_node(
+ "ScatterND",
+ inputs=["data", "indices", "updates"],
+ outputs=["y"],
+ reduction=reduction,
+ )
+ ],
+ "nd",
+ [
+ oh.make_tensor_value_info("data", TensorProto.FLOAT, [None, None, None]),
+ oh.make_tensor_value_info("indices", TensorProto.INT64, [None, None]),
+ oh.make_tensor_value_info("updates", TensorProto.FLOAT, [None, None, None]),
+ ],
+ [oh.make_tensor_value_info("y", TensorProto.FLOAT, [None, None, None])],
+ ),
+ opset_imports=[oh.make_opsetid("", 18)],
+ ir_version=9,
+ )
+
+ data = np.full((2, 2, 3), 0.1, dtype=np.float32)
+ indices = np.array([[line], [1 - line], [line]], dtype=np.int64)
+ updates = (2 ** (np.arange(18) + 1).astype(np.float32).reshape((3, 2, 3))).astype(np.float32)
+
+ feeds = dict(data=data, indices=indices, updates=updates)
+ ref = ReferenceEvaluator(model)
+ expected = ref.run(None, feeds)[0]
+
+ providers = (
+ [
+ ["CUDAExecutionProvider"],
+ ["CPUExecutionProvider"],
+ ]
+ if has_cuda()
+ else [["CPUExecutionProvider"]]
+ )
+ for provider in providers:
+ sess = onnxruntime.InferenceSession(model.SerializeToString(), providers=provider)
+ got = sess.run(None, feeds)[0]
+ self.assertEqual(expected.tolist(), got.tolist())
+
+ def test_scatternd_standalone_cuda(self):
+ self._scatternd_standalone_cuda("add", 0)
+ self._scatternd_standalone_cuda("add", 1)
+ self._scatternd_standalone_cuda("mul", 0)
+ self._scatternd_standalone_cuda("mul", 1)
+ self._scatternd_standalone_cuda("min", 0)
+ self._scatternd_standalone_cuda("min", 1)
+ self._scatternd_standalone_cuda("max", 0)
+ self._scatternd_standalone_cuda("max", 1)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
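
Note on the profiling check in common_scatter above: the test loads the JSON profile written by end_profiling() and, for every *_kernel_time entry, reads the provider and op name out of that entry's args. A minimal standalone helper along the same lines (illustrative only):

    import json

    def providers_per_node(profile_path):
        # Return (provider, op_name) pairs for each kernel-time entry.
        with open(profile_path, encoding="utf-8") as f:
            rows = json.load(f)
        return [
            (row["args"].get("provider"), row["args"].get("op_name"))
            for row in rows
            if isinstance(row.get("args"), dict) and row["name"].endswith("_kernel_time")
        ]
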
diff --git a/onnxruntime/test/python/quantization/test_quant_issues.py b/onnxruntime/test/python/quantization/test_quant_issues.py
new file mode 100644
index 0000000000000..66960978748ad
--- /dev/null
+++ b/onnxruntime/test/python/quantization/test_quant_issues.py
@@ -0,0 +1,72 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+import os
+import tempfile
+import unittest
+import warnings
+
+
+def ignore_warnings(warns):
+ """
+ Catches warnings.
+
+ :param warns: warnings to ignore
+ """
+
+ def wrapper(fct):
+ if warns is None:
+ raise AssertionError(f"warns cannot be None for '{fct}'.")
+
+ def call_f(self):
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", warns)
+ return fct(self)
+
+ return call_f
+
+ return wrapper
+
+
+class TestQuantIssues(unittest.TestCase):
+ @ignore_warnings(DeprecationWarning)
+ def test_minimal_model(self):
+ folder = os.path.join(os.path.dirname(__file__), "..", "..", "testdata")
+ onnx_path = os.path.join(folder, "qdq_minimal_model.onnx")
+ if not os.path.exists(onnx_path):
+ # The file does not seem to be in the same location in every CI job.
+ raise unittest.SkipTest(f"unable to find {onnx_path!r}")
+
+ import numpy as np
+
+ import onnxruntime.quantization as oq
+
+ class Mock:
+ def __init__(self):
+ self.i = 0
+
+ def get_next(self):
+ if self.i > 10:
+ return None
+ self.i += 1
+ return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)}
+
+ with tempfile.TemporaryDirectory() as temp:
+ preprocessed_path = os.path.join(temp, "preprocessed.onnx")
+ quantized_path = os.path.join(temp, "quantized.onnx")
+ oq.quant_pre_process(onnx_path, preprocessed_path, skip_symbolic_shape=True)
+ oq.quantize_static(
+ preprocessed_path,
+ quantized_path,
+ Mock(),
+ calibrate_method=oq.CalibrationMethod.Percentile,
+ op_types_to_quantize=["Conv", "Mul", "Gemm"],
+ )
+ assert os.path.exists(preprocessed_path), f"missing output {preprocessed_path!r}"
+ assert os.path.exists(quantized_path), f"missing output {quantized_path!r}"
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
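
Note on the Mock reader above: the calibrator only calls get_next(), so the bare class works here. An equivalent explicit variant would subclass CalibrationDataReader; RandomUint8Reader is a name made up for this sketch:

    import numpy as np
    from onnxruntime.quantization import CalibrationDataReader

    class RandomUint8Reader(CalibrationDataReader):
        # Yields a handful of random uint8 batches for the model's "input" tensor.
        def __init__(self, n_batches=10):
            self.n_batches = n_batches
            self.i = 0

        def get_next(self):
            if self.i >= self.n_batches:
                return None
            self.i += 1
            return {"input": np.random.randint(0, 255, size=(1, 3, 32, 32), dtype=np.uint8)}
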
diff --git a/onnxruntime/test/python/quantization/test_subgraph.py b/onnxruntime/test/python/quantization/test_subgraph.py
index c425bf956f976..fbf95767b3fdf 100644
--- a/onnxruntime/test/python/quantization/test_subgraph.py
+++ b/onnxruntime/test/python/quantization/test_subgraph.py
@@ -19,9 +19,13 @@ def test_dynamic_quantization_subgraph(self):
with tempfile.TemporaryDirectory() as tmpdir:
onnx_path = os.path.join(tmpdir, "decoder_model_merged.onnx")
quantized_onnx_path = os.path.join(tmpdir, "decoder_model_merged_quantized.onnx")
- urllib.request.urlretrieve(
- "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx", onnx_path
- )
+ url = "https://huggingface.co/fxmarty/t5-tiny-onnx-testing/resolve/main/decoder_model_merged.onnx"
+ try:
+ urllib.request.urlretrieve(url, onnx_path)
+ except urllib.request.HTTPError as e:
+ # The unit test should not fail for this kind of issue.
+ # TODO: use another way to retrieve the model.
+ raise unittest.SkipTest(f"Unable to fetch {url!r} due to {e}") # noqa: B904
quantize_dynamic(
model_input=onnx_path,
@@ -62,3 +66,7 @@ def test_dynamic_quantization_subgraph(self):
if attr.type == onnx.AttributeProto.GRAPH:
for initializer in attr.g.initializer:
self.assertTrue("shared.weight" not in initializer.name)
+
+
+if __name__ == "__main__":
+ unittest.main(verbosity=2)
diff --git a/onnxruntime/test/testdata/qdq_minimal_model.onnx b/onnxruntime/test/testdata/qdq_minimal_model.onnx
new file mode 100644
index 0000000000000..04e71789e6356
Binary files /dev/null and b/onnxruntime/test/testdata/qdq_minimal_model.onnx differ
diff --git a/tools/ci_build/github/apple/upload_pod_archive_and_update_podspec.sh b/tools/ci_build/github/apple/upload_pod_archive_and_update_podspec.sh
index 311519985f3ad..30c47d7c26359 100755
--- a/tools/ci_build/github/apple/upload_pod_archive_and_update_podspec.sh
+++ b/tools/ci_build/github/apple/upload_pod_archive_and_update_podspec.sh
@@ -33,7 +33,10 @@ POD_ARCHIVE_BASENAME=$(basename "${POD_ARCHIVE_PATH}")
STORAGE_ACCOUNT_NAME="onnxruntimepackages"
STORAGE_ACCOUNT_CONTAINER_NAME="\$web"
-STORAGE_URL_PREFIX=$(az storage account show --name ${STORAGE_ACCOUNT_NAME} --query "primaryEndpoints.web" --output tsv)
+
+# TODO: See if there's a way to get the new storage account AFD URL using the Azure CLI
+#STORAGE_URL_PREFIX=$(az storage account show --name ${STORAGE_ACCOUNT_NAME} --query "primaryEndpoints.web" --output tsv)
+STORAGE_URL_PREFIX="https://onnxruntimepackages.azureedge.net/"
# upload the pod archive and set the podspec source to the pod archive URL
az storage blob upload \