@@ -11,10 +11,13 @@ namespace contrib {
11
11
namespace rocm {
12
12
// Forward declarations of contrib-op kernel classes. Each class name is generated by an
// ONNX_OPERATOR_[TYPED_]KERNEL_CLASS_NAME macro from (provider, domain, version[, type], op);
// all entries here use kRocmExecutionProvider, kMSDomain and version 1.
// FastGelu and Gelu are each declared for float, double, MLFloat16 and BFloat16.
// NOTE(review): presumably every declaration needs a matching BuildKernelCreateInfo<> entry
// in the registration table and a kernel definition elsewhere — confirm when adding a type.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , GridSample);
13
13
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , FastGelu);
14
+ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , double , FastGelu);
14
15
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, FastGelu);
16
+ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, FastGelu);
15
17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , Gelu);
16
18
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , double , Gelu);
17
19
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, Gelu);
20
+ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, Gelu);
18
21
class ONNX_OPERATOR_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BiasGelu);
19
22
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, BiasSplitGelu);
20
23
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , BiasSplitGelu);
@@ -126,7 +129,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1
126
129
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16_int8_t, QAttention);
127
130
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , FusedConv);
128
131
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, FusedConv);
129
- class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, FastGelu);
130
132
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, TransposeMatMul); // backward compatibility
131
133
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, FusedMatMul);
132
134
// class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, QOrderedMatMul);
@@ -173,10 +175,13 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
173
175
BuildKernelCreateInfo<void >, // default entry to avoid the list become empty after ops-reducing
174
176
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , GridSample)>,
175
177
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , FastGelu)>,
178
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , double , FastGelu)>,
176
179
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, FastGelu)>,
180
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, FastGelu)>,
177
181
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , Gelu)>,
178
182
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , double , Gelu)>,
179
183
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, Gelu)>,
184
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, Gelu)>,
180
185
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BiasGelu)>,
181
186
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , MLFloat16, BiasSplitGelu)>,
182
187
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , float , BiasSplitGelu)>,
@@ -287,7 +292,6 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) {
287
292
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
288
293
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>,
289
294
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , Trilu)>,
290
- BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, FastGelu)>,
291
295
// TransposeMatMul is still registered here for backward compatibility
292
296
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, TransposeMatMul)>, // backward compatibility
293
297
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME (kRocmExecutionProvider , kMSDomain , 1 , BFloat16, FusedMatMul)>,
0 commit comments