[CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (#22713)

### Description
This change is based on #9700 and extends it to ArgMin as well.

This pull request introduces several enhancements and fixes related to
the `ArgMax` and `ArgMin` operators in the CUDA execution provider. The
changes ensure proper handling of these operators across opset versions
and improve the kernel-registration and CPU-fallback mechanisms.

Key changes include:

#### Enhancements to `ArgMax` and `ArgMin` Operators:

* Added new kernel class registrations for `ArgMax` and `ArgMin` for
different data types and versions in
`onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
[[1]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R966-R972)
[[2]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1209-R1215)
[[3]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1657-R1659)
[[4]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285L1825-L1827)
[[5]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R1933-R1939)
[[6]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2174-R2180)

* Introduced an `ArgMaxOrArgMinNeedFallbackToCPU` function that falls
back to the CPU when the `select_last_index` attribute is set to 1,
since the CUDA implementation does not support that value (a simplified
sketch of the check follows the links below).
[[1]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2597-R2622)
[[2]](diffhunk://#diff-57ba769b54dce57acd89df47140ede5f29ea670d61176096076701912d573285R2672-R2674)
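
In essence, the fallback decision reduces to the following check. This is a
simplified sketch (the helper name `NeedsCpuFallback` is illustrative, not
from the PR); the actual function appears in the `cuda_execution_provider.cc`
hunk below:

```cpp
// Simplified sketch: an ArgMax/ArgMin node must fall back to the CPU EP only
// when it is opset 12+ (where select_last_index exists) and the attribute is
// set to a nonzero value.
bool NeedsCpuFallback(const onnxruntime::Node& node) {
  if (node.SinceVersion() < 12) return false;  // attribute introduced in opset 12
  const auto& attrs = node.GetAttributes();    // map of attribute name -> AttributeProto
  auto it = attrs.find("select_last_index");
  return it != attrs.end() && it->second.i() != 0;
}
```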

#### Macro and Kernel Registration Improvements:

* Replaced `REGISTER_KERNEL_UNTIL_VERSIONED_TYPED` with
`REGISTER_KERNEL_VERSIONED_RANGE_TYPED` and
`REGISTER_KERNEL_VERSIONED_SINCE_TYPED` macros for better version
handling.
[[1]](diffhunk://#diff-ee5316fc3898058f70e942d9a84de36be4c7da09f144633a2504236430d5d033L19-R29)
[[2]](diffhunk://#diff-ee5316fc3898058f70e942d9a84de36be4c7da09f144633a2504236430d5d033L40-R46)

* Updated kernel registration for `ArgMax` and `ArgMin` to use the new
macros, ensuring proper version handling and support for different data
types (see the illustrative expansion below).
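
For instance, a single `REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)` now
expands into three registrations, mirroring the macro definitions in the diff
below (shown here for illustration only):

```cpp
// Illustrative expansion: one registration per opset range, so opset 12
// (which introduces select_last_index) is handled separately from [1, 11]
// and from 13 and above.
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ArgMax, float, 1, 11)   // opsets 1-11
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(ArgMax, float, 12, 12)  // opset 12 only
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(ArgMax, float, 13)      // opset 13+
```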

#### Safety Checks:

* Added safety checks in the `ArgMax` and `ArgMin` classes to ensure
`select_last_index` is not set to 1, as it is not supported on CUDA.
[[1]](diffhunk://#diff-8ab09fef1f4a12cbf3b3432e509f8f1ef561e83c72778a0e047780060aeef6efL91-R99)
[[2]](diffhunk://#diff-8ab09fef1f4a12cbf3b3432e509f8f1ef561e83c72778a0e047780060aeef6efL101-R117)

#### Testing Enhancements:

* Added new tests for the `ArgMax` and `ArgMin` operators to verify
behavior when `select_last_index` is set to 0, ensuring consistent
results across the CPU and CUDA execution providers (a sketch of such a
test follows the links below).
[[1]](diffhunk://#diff-77affe1b70d1a9d38c2485f7c6b16ef2b6b541ed94dd727bc9b286f068f1481aR3340-R3360)
[[2]](diffhunk://#diff-77affe1b70d1a9d38c2485f7c6b16ef2b6b541ed94dd727bc9b286f068f1481aR3679-R3699)
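
A minimal sketch of such a test in the style of ORT's `OpTester` (the test
name and data here are illustrative; the actual tests are in the linked
hunks):

```cpp
// Sketch: with select_last_index = 0 (the default), ArgMax returns the first
// index of the maximum along the axis, and the node can stay on the CUDA EP.
TEST(ReductionOpTest, ArgMax_SelectLastIndexZero_Sketch) {
  OpTester test("ArgMax", 13);
  test.AddAttribute<int64_t>("axis", 0);
  test.AddAttribute<int64_t>("select_last_index", 0);
  test.AddInput<float>("data", {2, 2},
                       {1.0f, 2.0f,
                        2.0f, 1.0f});
  // keepdims defaults to 1, so the reduced axis is kept with size 1.
  test.AddOutput<int64_t>("reduced", {1, 2}, {1, 0});
  test.Run();
}
```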

### Motivation and Context
Improve CUDA kernel coverage for the Stable Diffusion model and thereby
improve its performance on CUDA.
tianleiwu authored Nov 6, 2024
1 parent d993ec3 commit ba22d78
Showing 7 changed files with 240 additions and 34 deletions.
8 changes: 6 additions & 2 deletions docs/OperatorKernels.md
@@ -554,8 +554,12 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||10|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
63 changes: 60 additions & 3 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -963,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);

class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);

// OpSet 13
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
@@ -1199,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);

class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);

// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
@@ -1640,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1)>,
@@ -1822,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
19, IsInf)>,

// opset 11
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
@@ -1916,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,

BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,

// OpSet 13
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
@@ -2150,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,

BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,

// OpSet 14
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
@@ -2566,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
  return false;
}

static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
  // Opset 12 introduced the attribute "select_last_index".
  if (node.SinceVersion() >= 12) {
    const auto& node_attributes = node.GetAttributes();

    for (auto& attr : node_attributes) {
      auto& attr_name = attr.first;
      auto& attr_value = attr.second;

      // cuDNN doesn't support picking the last index when duplicate max values
      // are encountered.
      // cuDNN's API doc doesn't mention what happens when duplicates are encountered,
      // but based on testing, the results seem to indicate a "stable" implementation
      // (i.e., relative ordering is preserved), which is the expected behavior when the
      // attribute takes its default value (the most common use case for this operator).
      if ("select_last_index" == attr_name) {
        if (attr_value.i() != 0) {
          return true;
        }
      }
    }
  }

  return false;
}

std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransfer() const {
  return std::make_unique<onnxruntime::GPUDataTransfer>();
}
@@ -2615,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
force_inside = !not_supported;
} else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {

#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, end, \
begin, end, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@@ -37,8 +37,13 @@ namespace cuda {
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)

#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)

// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
@@ -829,14 +834,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO

} // namespace ReductionOps

// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
20 changes: 18 additions & 2 deletions onnxruntime/core/providers/cuda/reduction/reduction_ops.h
@@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
template <typename T>
class ArgMax final : public ReduceKernel<false> {
 public:
  ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
  ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
    // The following is just a safety check.
    // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax
    // nodes with select_last_index == 1 to the CUDA EP.
    int64_t select_last_index = 0;
    if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
      ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
    }
  }

  Status ComputeInternal(OpKernelContext* ctx) const override {
    return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);
@@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel<false> {
template <typename T>
class ArgMin final : public ReduceKernel<false> {
 public:
  ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
  ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {
    // The following is just a safety check.
    // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin
    // nodes with select_last_index == 1 to the CUDA EP.
    int64_t select_last_index = 0;
    if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
      ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
    }
  }

  Status ComputeInternal(OpKernelContext* ctx) const override {
    return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MIN);
28 changes: 16 additions & 12 deletions onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
@@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {

#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
1, end, \
begin, end, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@@ -37,8 +37,13 @@ namespace rocm {
name<T>);

#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)

#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)

// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
@@ -830,14 +835,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N

} // namespace ReductionOps

// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)

REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)

REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
(Diffs for the remaining two changed files are not shown.)