Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Feb 23, 2024
2 parents 0c3ea39 + efbe2b8 commit f50bb0c
Show file tree
Hide file tree
Showing 93 changed files with 1,770 additions and 829 deletions.
3 changes: 1 addition & 2 deletions CITATION.cff
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ title: ONNX Runtime
message: "Please use this information to cite ONNX Runtime in
research or other publications."
authors:
- affiliation: Microsoft Corporation
given-names: ONNX Runtime developers
- name: ONNX Runtime developers
date-released: 2018-11-29
url: "https://onnxruntime.ai"
repository-code: "https://github.com/microsoft/onnxruntime"
Expand Down
9 changes: 2 additions & 7 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,8 @@ if (onnxruntime_MINIMAL_BUILD)
endif()
endif()

# Enable stream for all the non-minimal build, except for DML. There's currently a bug
# in the allocation planner when reusing buffers and more than one streams are used that
# make it possible (although rarely) to reach a reference count of 0 for a buffer that is
# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
# safest option for now.
# https://github.com/microsoft/onnxruntime/issues/19480
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
# Enable stream for all the non-minimal build
if (NOT onnxruntime_MINIMAL_BUILD)
add_compile_definitions(ORT_ENABLE_STREAM)
endif()

Expand Down
4 changes: 0 additions & 4 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,6 @@ set(contrib_ops_excluded_files
"bert/fastertransformer_decoder_attention/*"
"bert/multihead_attention.cc"
"bert/multihead_attention.h"
"bert/fast_gelu_impl.cu"
"bert/fast_gelu_impl.h"
"bert/fast_gelu.cc"
"bert/fast_gelu.h"
"bert/relative_attn_bias.cc"
"bert/relative_attn_bias.h"
"bert/relative_attn_bias_impl.cu"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ private void TestCUDAProviderOptions()
private void CanRunInferenceOnAModelWithTensorRT()
{
string modelPath = Path.Combine(Directory.GetCurrentDirectory(), "squeezenet.onnx");

int deviceId = 0;
string deviceIdStr = System.Environment.GetEnvironmentVariable("ONNXRUNTIME_TEST_GPU_DEVICE_ID");
if (!string.IsNullOrEmpty(deviceIdStr) && int.TryParse(deviceIdStr, out int parsedValue) && parsedValue >= 0)
Expand Down
3 changes: 3 additions & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(float)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[9, 10]|**T** = tensor(double), tensor(float)|
Expand Down Expand Up @@ -606,6 +607,7 @@ Do not modify directly.*
|GatherND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||12|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|||11|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int64)<br/> **indices** = tensor(int64)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|20+|**T** = tensor(double), tensor(float), tensor(float16)|
|Gemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|||[9, 10]|**T** = tensor(double), tensor(float), tensor(float16)|
Expand All @@ -617,6 +619,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|GridSample|*in* X:**T1**<br> *in* grid:**T2**<br> *out* Y:**T1**|16+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|HardSigmoid|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**<br> *out* output:**T**<br><br>or<br><br>*in* input:**V**<br> *out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
Expand Down
10 changes: 8 additions & 2 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Node;
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"

struct OrtRunOptions;

namespace onnxruntime {

/**
Expand All @@ -51,6 +53,8 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};

using RunOptions = OrtRunOptions;

enum class DataLayout {
NCHW,
NHWC,
Expand Down Expand Up @@ -184,15 +188,17 @@ class IExecutionProvider {
Run may not be finished on device This function should be regarded as the
point after which a new Run would start to submit commands from CPU
*/
virtual common::Status OnRunStart() { return Status::OK(); }
virtual common::Status OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { return Status::OK(); }

/**
Called when InferenceSession::Run ended
NOTE that due to async execution in provider, the actual work of this Run
may not be finished on device This function should be regarded as the point
that all commands of current Run has been submmited by CPU
*/
virtual common::Status OnRunEnd(bool /*sync_stream*/) { return Status::OK(); }
virtual common::Status OnRunEnd(bool /*sync_stream*/, const onnxruntime::RunOptions& /*run_options*/) {
return Status::OK();
}

/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
Expand Down
2 changes: 1 addition & 1 deletion include/onnxruntime/core/providers/cuda/cuda_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ enum CudaResource : int {
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
use_tf32_t,
};
};
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,15 @@ static const char* const kOrtRunOptionsConfigEnableMemoryArenaShrinkage = "memor
// Per default it will be set to '0'
// Taking CUDA EP as an example, it omit triggering cudaStreamSynchronize on the compute stream.
static const char* const kOrtRunOptionsConfigDisableSynchronizeExecutionProviders = "disable_synchronize_execution_providers";

// Set HTP performance mode for QNN HTP backend before session run.
// options for HTP performance mode: "burst", "balanced", "default", "high_performance",
// "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver",
// "sustained_high_performance". Default to "default".
static const char* const kOrtRunOptionsConfigQnnPerfMode = "qnn.htp_perf_mode";

// Set HTP performance mode for QNN HTP backend post session run.
static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_mode_post_run";

// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";
4 changes: 2 additions & 2 deletions js/node/lib/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class OnnxruntimeSessionHandler implements InferenceSessionHandler {
async run(feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, options: InferenceSession.RunOptions):
Promise<SessionHandler.ReturnType> {
return new Promise((resolve, reject) => {
process.nextTick(() => {
setImmediate(() => {
try {
resolve(this.#inferenceSession.run(feeds, fetches, options));
} catch (e) {
Expand All @@ -56,7 +56,7 @@ class OnnxruntimeBackend implements Backend {
async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions):
Promise<InferenceSessionHandler> {
return new Promise((resolve, reject) => {
process.nextTick(() => {
setImmediate(() => {
try {
resolve(new OnnxruntimeSessionHandler(pathOrBuffer, options || {}));
} catch (e) {
Expand Down
6 changes: 3 additions & 3 deletions js/react_native/e2e/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3351,9 +3351,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"

ip@^1.1.5:
version "1.1.8"
resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
version "1.1.9"
resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==

is-accessor-descriptor@^0.1.6:
version "0.1.6"
Expand Down
6 changes: 3 additions & 3 deletions js/react_native/yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3701,9 +3701,9 @@ invariant@^2.2.4:
loose-envify "^1.0.0"

ip@^1.1.5:
version "1.1.8"
resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.8.tgz#ae05948f6b075435ed3307acce04629da8cdbf48"
integrity sha512-PuExPYUiu6qMBQb4l06ecm6T6ujzhmh+MeJcW9wa89PoAz5pvd4zPgN5WJV104mb6S2T1AwNIAaB70JNrLQWhg==
version "1.1.9"
resolved "https://registry.yarnpkg.com/ip/-/ip-1.1.9.tgz#8dfbcc99a754d07f425310b86a99546b1151e396"
integrity sha512-cyRxvOEpNHNtchU3Ln9KC/auJgup87llfQpQ+t5ghoC/UhL16SWzbueiCsdTnWmqAWl7LadfuwhlqmtOaqMHdQ==

is-absolute@^1.0.0:
version "1.0.0"
Expand Down
4 changes: 2 additions & 2 deletions js/web/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The [Open Neural Network Exchange](http://onnx.ai/) (ONNX) is an open standard f

With ONNX Runtime Web, web developers can score models directly on browsers with various benefits including reducing server-client communication and protecting user privacy, as well as offering install-free and cross-platform in-browser ML experience.

ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web complies the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.
ONNX Runtime Web can run on both CPU and GPU. On CPU side, [WebAssembly](https://developer.mozilla.org/en-US/docs/WebAssembly) is adopted to execute the model at near-native speed. ONNX Runtime Web compiles the native ONNX Runtime CPU engine into WebAssembly backend by using Emscripten, so it supports most functionalities native ONNX Runtime offers, including full ONNX operator coverage, multi-threading, [ONNX Runtime Quantization](https://www.onnxruntime.ai/docs/how-to/quantization.html) as well as [ONNX Runtime Mobile](https://onnxruntime.ai/docs/tutorials/mobile/). For performance acceleration with GPUs, ONNX Runtime Web leverages WebGL, a popular standard for accessing GPU capabilities. We are keeping improving op coverage and optimizing performance in WebGL backend.

See [Compatibility](#Compatibility) and [Operators Supported](#Operators) for a list of platforms and operators ONNX Runtime Web currently supports.

Expand All @@ -22,7 +22,7 @@ Refer to [ONNX Runtime JavaScript examples](https://github.com/microsoft/onnxrun

## Documents

### Developement
### Development

Refer to the following links for development information:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ import {DataType} from '../../../../wasm-common';
import {LOG_DEBUG} from '../../../log';
import {TensorView} from '../../../tensor-view';
import {ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../../types';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from '../common';
import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType, UniformsArrayType} from '../common';
import {ConvTransposeAttributes} from '../conv-transpose';
import {appendActivationUniforms, appendActivationUniformsData, getActivationSnippet} from '../fuse-utils';

import {biasSnippet, typeSnippet} from './activation_util';
import {biasSnippet} from './activation_util';
import {utilFunctions} from './conv_util';
import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu';

const conv2dTransposeCommonSnippet =
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, innerElementSize = 4): string => {
const type = typeSnippet(innerElementSize, 'f32');
(isChannelsLast: boolean, addBias = false, attributes: ConvTransposeAttributes, type: string,
innerElementSize = 4): string => {
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
Expand All @@ -47,7 +47,7 @@ const conv2dTransposeCommonSnippet =
let v1 = w[getIndexFromCoords4D(coord1, vec4<i32>(uniforms.w_shape))];
let v2 = w[getIndexFromCoords4D(coord2, vec4<i32>(uniforms.w_shape))];
let v3 = w[getIndexFromCoords4D(coord3, vec4<i32>(uniforms.w_shape))];
return vec4<f32>(v0, v1, v2, v3);
return ${type}(v0, v1, v2, v3);
`;
default:
throw new Error(`innerElementSize ${innerElementSize} is not supported.`);
Expand Down Expand Up @@ -224,7 +224,7 @@ export const createConv2DTransposeMatMulProgramInfo =
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, components);
inputVariables.push(bias);
declareFunctions += `
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
fn getBiasByOutputCoords(coords : vec4<i32>) -> ${bias.type.value} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
Expand All @@ -236,16 +236,20 @@ export const createConv2DTransposeMatMulProgramInfo =
{name: 'pads', type: 'i32', length: pads.length}
];
appendActivationUniforms(attributes, uniforms);
const elemType = tensorTypeToWsglStorageType(inputs[0].dataType, 1);
if (elemType !== 'f16' && elemType !== 'f32') {
throw new Error(`elemType ${elemType} is not supported.`);
}
return `
${utilFunctions('uniforms.result_strides')}
${shaderHelper.registerUniforms(uniforms).declareVariables(...inputVariables, output)};
${declareFunctions}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, innerElementSize)}
${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, attributes, x.type.value, innerElementSize)}
${
isVec4 ? makeMatMulPackedVec4Source(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) :
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner) :
makeMatMulPackedSource(
elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false,
elementsPerThread, workGroupSize, elemType, undefined, !isChannelsLast, tileInner, false,
undefined, sequentialAccessByThreads)}`;
};

Expand Down
3 changes: 2 additions & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/where.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ const createWhereOpProgramShader =
const expressionA = `a_data[index_a${x}][component_a${x}]`;
const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
const expressionC = `bool(c_data[index_c${x}] & (0xffu << (component_c${x} * 8)))`;
return `
let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
Expand All @@ -38,6 +38,7 @@ const createWhereOpProgramShader =
let index_c${x} = offset_c${x} / 4u;
let component_a${x} = offset_a${x} % 4u;
let component_b${x} = offset_b${x} % 4u;
let component_c${x} = offset_c${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
Expand Down
34 changes: 34 additions & 0 deletions js/web/test/data/ops/where.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -168,5 +168,39 @@
]
}
]
},
{
"name": "Where with no attributes",
"operator": "Where",
"attributes": [],
"cases": [
{
"name": "T[1 1 2 1] T[1 4] T[1 1 2 4] float32 broadcast 1",
"inputs": [
{
"data": [true, false],
"dims": [1, 1, 2, 1],
"type": "bool"
},
{
"data": [1, 2, 3, 4],
"dims": [1, 4],
"type": "float32"
},
{
"data": [5, 6, 7, 8, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [1, 2, 3, 4, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
]
}
]
}
]
4 changes: 2 additions & 2 deletions js/web/test/test-runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -627,8 +627,8 @@ export async function runModelTestSet(
try {
const feeds: Record<string, ort.Tensor> = {};
const outputsMetaInfo: Record<string, ort.Tensor> = {};
testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor);
testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor);
testCase.inputs!.forEach((tensor) => feeds[tensor.name] = tensor);
testCase.outputs!.forEach((tensor) => outputsMetaInfo[tensor.name] = tensor);
const [start, end, outputs] =
await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding});
if (context.perfData.count === 0) {
Expand Down
10 changes: 1 addition & 9 deletions onnxruntime/contrib_ops/cpu/activations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// Licensed under the MIT License.

#include "core/providers/cpu/activation/activations.h"
#include "activations.h"
#include "contrib_ops/cpu/activations.h"

namespace onnxruntime {
namespace contrib {
Expand All @@ -26,14 +26,6 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
ThresholdedRelu<float>);

ONNX_OPERATOR_KERNEL_EX(
Gelu,
kMSDomain,
1,
kCpuExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
Gelu<float>);

ONNX_OPERATOR_KERNEL_EX(
QuickGelu,
kMSDomain,
Expand Down
Loading

0 comments on commit f50bb0c

Please sign in to comment.