diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index bd886abc98a89..5fb1e54b38c2b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -554,8 +554,12 @@ Do not modify directly.*
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br> **T1** = tensor(bool)|
-|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
-|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
|AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
|||10|**T** = tensor(double), tensor(float), tensor(float16)|
|||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 497d0014795ec..8396e2629d2bf 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -963,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);
+
// OpSet 13
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
@@ -1199,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);
+
// OpSet 14
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
@@ -1640,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1822,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 10,
                                                                19, IsInf)>,
// opset 11
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1916,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,
+
// OpSet 13
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2150,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,
+
// OpSet 14
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2566,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
return false;
}
+static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
+ // Opset 12 introduced the attribute "select_last_index"
+ if (node.SinceVersion() >= 12) {
+ const auto& node_attributes = node.GetAttributes();
+
+ for (auto& attr : node_attributes) {
+ auto& attr_name = attr.first;
+ auto& attr_value = attr.second;
+
+ // CuDNN doesn't support picking the last index when duplicate max values
+ // are encountered.
+ // CuDNN's API doc doesn't specify what happens when duplicates are
+ // encountered, but testing suggests a "stable" implementation (i.e.
+ // relative ordering is preserved), which is the expected behavior when the
+ // attribute takes its default value (the most common use-case for this operator).
+ if ("select_last_index" == attr_name) {
+ if (attr_value.i() != 0) {
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+}
+
std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransfer() const {
  return std::make_unique<GPUDataTransfer>();
}
@@ -2615,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
} else if ("ConvTranspose" == node.OpType()) {
not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
force_inside = !not_supported;
+ } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
+ not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
+ force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
index 860bea67dc719..4f8e6605ce151 100644
--- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.cc
@@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace cuda {
-#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
+#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
- 1, end, \
+ begin, end, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name);
-#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
+#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@@ -37,8 +37,13 @@ namespace cuda {
name);
#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
- REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
- REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
+ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
+ REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
+
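+// ArgMax/ArgMin span three registration ranges: opsets [1, 11] predate the
+// "select_last_index" attribute, opset 12 introduced it, and opset 13 bumped
+// the operator version again, so each range needs its own kernel registration.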
+#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
+ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
+ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
+ REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)
// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
@@ -829,14 +834,13 @@ template std::unique_ptr ReduceCompute
diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h
--- a/onnxruntime/core/providers/cuda/reduction/reduction_ops.h
+++ b/onnxruntime/core/providers/cuda/reduction/reduction_ops.h
@@ -88,7 +88,15 @@
template <typename T>
class ArgMax final : public ReduceKernel<false> {
public:
- ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
+ ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
+ // The following is just a safety check.
+ // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax
+ // nodes with select_last_index == 1 to the CUDA EP.
+ int64_t select_last_index = 0;
+ if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
+ ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
+ }
+ }
Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T>(ctx, CUDNN_REDUCE_TENSOR_MAX);
@@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel<false> {
template <typename T>
class ArgMin final : public ReduceKernel<false> {
public:
- ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
+ ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {
+ // The following is just a safety check.
+ // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin
+ // nodes with select_last_index == 1 to the CUDA EP.
+ int64_t select_last_index = 0;
+ if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
+ ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
+ }
+ }
Status ComputeInternal(OpKernelContext* ctx) const override {
return ComputeImpl<T>(ctx, CUDNN_REDUCE_TENSOR_MIN);
diff --git a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
index 1340c49c38ded..d8b7e26d17b65 100644
--- a/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
+++ b/onnxruntime/core/providers/rocm/reduction/reduction_ops.cc
@@ -16,17 +16,17 @@ using namespace onnxruntime::common;
namespace onnxruntime {
namespace rocm {
-#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end) \
+#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
- 1, end, \
+ begin, end, \
T, \
kRocmExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
name);
-#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version) \
+#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
name, \
kOnnxDomain, \
@@ -37,8 +37,13 @@ namespace rocm {
name);
#define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
- REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last) \
- REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
+ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last) \
+ REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
+
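+// Mirrors the CUDA registration split: opsets [1, 11], [12, 12], and 13+
+// each get their own kernel registration.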
+#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T) \
+ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11) \
+ REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
+ REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)
// TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
template <bool allow_multi_axes>
@@ -830,14 +835,13 @@ template std::unique_ptr ReduceCompute
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1640,6 +1654,10 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
// BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
+ // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
+
BuildKernelCreateInfo,
// BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1785,9 +1802,6 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 10,
                                                                19, IsInf)>,
// opset 11
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
- // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
- BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -1879,6 +1893,13 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
+ // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,
+ // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,
+
// OpSet 13
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2112,6 +2133,12 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
+ // BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
+ // BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,
// OpSet 14
BuildKernelCreateInfo,
@@ -2387,6 +2414,26 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
return false;
}
+static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
+ // Opset 12 introduced the attribute "select_last_index"
+ if (node.SinceVersion() >= 12) {
+ const auto& node_attributes = node.GetAttributes();
+
+ for (auto& attr : node_attributes) {
+ auto& attr_name = attr.first;
+ auto& attr_value = attr.second;
+
+ // Picking the last index when duplicate max values are encountered is not supported.
+ if ("select_last_index" == attr_name) {
+ if (attr_value.i() != 0) {
+ return true;
+ }
+ }
+ }
+ }
+
+ return false;
+}
std::unique_ptr<onnxruntime::IDataTransfer> ROCMExecutionProvider::GetDataTransfer() const {
  return std::make_unique<GPUDataTransfer>();
}
@@ -2425,6 +2472,9 @@ ROCMExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
"GRU" == node.OpType()) {
not_supported = true;
force_inside = !not_supported;
+ } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
+ not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
+ force_inside = !not_supported;
} else if ("Cast" == node.OpType()) {
not_supported = CastNeedFallbackToCPU(node);
// cast is not compute heavy, and may be placed outside
diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
index bb6d732fccb8f..c1c049ae5f967 100644
--- a/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/reduction/reduction_ops_test.cc
@@ -3,6 +3,7 @@
#include
#include
+#include <random>
#include
#include "gtest/gtest.h"
#include "test/common/dnnl_op_test_utils.h"
@@ -3337,6 +3338,41 @@ TEST(ReductionOpTest, ArgMax_int32_last_index_dups) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});
}
+TEST(ReductionOpTest, ArgMax_float_first_index_random) {
+ OpTester test("ArgMax", 12);
+ test.AddAttribute("axis", static_cast(0));
+ test.AddAttribute("keepdims", static_cast(1));
+
+ // Since select_last_index is 0 by default, this test should run on both CPU and CUDA
+ test.AddAttribute("select_last_index", static_cast(0));
+
+ constexpr size_t vector_size = 64 * 1024;
+ constexpr float max_value = std::numeric_limits<float>::infinity();
+
+ std::random_device rd;
+ std::mt19937 generator(rd());
+ std::uniform_int_distribution<int> distribution(0, static_cast<int>(vector_size) - 1);
+
+ std::vector<float> data_vec(vector_size, 0.0f);
+
+ int min_index = -1;
+
+ // Replace up to 8 elements with max_value. It is fine if some replacements hit the same index.
+ for (int i = 0; i < 8; ++i) {
+ int index = distribution(generator);
+ data_vec[index] = max_value;
+ if (i == 0 || index < min_index) {
+ min_index = index;
+ }
+ }
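+ // With select_last_index = 0, ArgMax must return the first occurrence of
+ // max_value, i.e. the smallest replaced index (min_index).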
+
+ test.AddInput("data", {vector_size}, data_vec);
+ test.AddOutput("reduced", {1}, {min_index});
+
+ // Exclude OpenVINO since it fails to handle this case.
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
TEST(ReductionOpTest, ArgMax_int32_neg_axis) {
OpTester test("ArgMax");
test.AddAttribute("axis", (int64_t)(-2));
@@ -3655,6 +3691,41 @@ TEST(ReductionOpTest, ArgMin_int32_neg_axis) {
test.Run();
}
+TEST(ReductionOpTest, ArgMin_float_first_index_random) {
+ OpTester test("ArgMin", 13);
+ test.AddAttribute("axis", static_cast(0));
+ test.AddAttribute("keepdims", static_cast(1));
+
+ // Since select_last_index is 0 by default, this test should run on both CPU and CUDA
+ test.AddAttribute("select_last_index", static_cast(0));
+
+ constexpr size_t vector_size = 64 * 1024;
+ constexpr float min_value = -std::numeric_limits<float>::infinity();
+
+ std::random_device rd;
+ std::mt19937 generator(rd());
+ std::uniform_int_distribution<int> distribution(0, static_cast<int>(vector_size) - 1);
+
+ std::vector<float> data_vec(vector_size, 0.0f);
+
+ int min_index = -1;
+
+ // Replace up to 8 elements with min_value. It is fine if some replacements hit the same index.
+ for (int i = 0; i < 8; ++i) {
+ int index = distribution(generator);
+ data_vec[index] = min_value;
+ if (i == 0 || index < min_index) {
+ min_index = index;
+ }
+ }
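+ // With select_last_index = 0, ArgMin must return the first occurrence of
+ // min_value, i.e. the smallest replaced index (min_index).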
+
+ test.AddInput("data", {vector_size}, data_vec);
+ test.AddOutput("reduced", {1}, {min_index});
+
+ // Exclude OpenVINO since it fails to handle this case.
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+}
+
TEST(ReductionOpTest, OptimizeShapeForFastReduce_ReduceDimWithZero1) {
FastReduceKind fast_kind;
TensorShapeVector fast_shape, fast_output_shape, fast_axes;