2 changes: 1 addition & 1 deletion onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc
100644 → 100755
@@ -748,7 +748,7 @@ std::vector<int64_t> ChannelLastToFirstPerm(size_t rank) {
}

std::vector<int64_t> p(rank);
p[0] = 0;
p[0] = 0; // This is usually the batch dimension (hence preserve this position)
p[1] = rank - 1;
for (size_t i = 2; i < rank; ++i) {
p[i] = i - 1;
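For context, an illustrative sanity check of the permutation this hunk documents (a standalone sketch; the re-stated function, the main function, and the asserts below are not part of the diff):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Re-statement of the loop above: keep the batch dimension at position 0,
// move channels from last to position 1, shift spatial dimensions right by one.
std::vector<int64_t> ChannelLastToFirstPerm(size_t rank) {
  std::vector<int64_t> p(rank);
  p[0] = 0;                               // batch dimension stays in place
  p[1] = static_cast<int64_t>(rank) - 1;  // channels move to position 1
  for (size_t i = 2; i < rank; ++i) {
    p[i] = static_cast<int64_t>(i) - 1;   // spatial dimensions shift right
  }
  return p;
}

int main() {
  // NHWC (rank 4) -> NCHW: {0, 3, 1, 2}
  assert((ChannelLastToFirstPerm(4) == std::vector<int64_t>{0, 3, 1, 2}));
  // NDHWC (rank 5) -> NCDHW: {0, 4, 1, 2, 3}
  assert((ChannelLastToFirstPerm(5) == std::vector<int64_t>{0, 4, 1, 2, 3}));
  return 0;
}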
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
100644 → 100755
@@ -1421,7 +1421,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 19, float, GridSample);

// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -1510,6 +1510,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, 21, float, GridSample);

// Opset 21.
// TODO(fajin): support other quantized types
@@ -1583,6 +1584,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GridSample);

// Opset 23.
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention);
@@ -2485,7 +2487,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 19, float, GridSample)>,

// Opset 17
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization)>,
@@ -2582,6 +2584,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, IsNaN)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 20, 21, float, GridSample)>,

// Opset 21
// TODO(fajin): support other quantized types
@@ -2654,6 +2657,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, double, HardSwish)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, MLFloat16, HardSwish)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, BFloat16, HardSwish)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 22, float, GridSample)>,

// Opset 23
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 23, float, Attention)>,
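A note on the split into (16, 19), (20, 21), and 22+ above: ONNX Runtime dispatches to the kernel whose registered version range covers the node's opset. A minimal, hypothetical model of that selection rule (the KernelEntry struct and SelectKernel helper below are illustrative only, not the actual KernelRegistry code):

#include <iostream>
#include <optional>
#include <vector>

// Hypothetical, simplified model of versioned kernel registration:
// each entry covers [since_version, end_version]; end_version == -1 means open-ended.
struct KernelEntry {
  int since_version;
  int end_version;
  const char* name;
};

std::optional<KernelEntry> SelectKernel(const std::vector<KernelEntry>& entries, int node_opset) {
  for (const auto& e : entries) {
    const bool upper_ok = (e.end_version < 0) || (node_opset <= e.end_version);
    if (node_opset >= e.since_version && upper_ok) return e;
  }
  return std::nullopt;
}

int main() {
  // Mirrors the GridSample registrations in this file: 16-19, 20-21, 22+.
  const std::vector<KernelEntry> grid_sample = {
      {16, 19, "GridSample 16-19 (4-D only)"},
      {20, 21, "GridSample 20-21 (adds 5-D support)"},
      {22, -1, "GridSample 22+"},
  };
  for (int opset : {16, 19, 20, 22}) {
    const auto k = SelectKernel(grid_sample, opset);
    std::cout << "opset " << opset << " -> " << (k ? k->name : "none") << "\n";
  }
  return 0;
}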
9 changes: 7 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_nhwc_kernels.cc
100644 → 100755
@@ -164,12 +164,17 @@ Status RegisterCudaNhwcKernels(KernelRegistry& kernel_registry) {
#ifndef DISABLE_CONTRIB_OPS
namespace onnxruntime::contrib::cuda {

class CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(16, 19, float, GridSample);
class CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(20, 21, float, GridSample);
class CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, GridSample);

onnxruntime::common::Status RegisterCudaNhwcContribKernels(onnxruntime::KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn nhwc_function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(16, float, GridSample)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(16, 19, float, GridSample)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_VERSIONED_TYPED_CLASS_NAME(20, 21, float, GridSample)>,
BuildKernelCreateInfo<CUDA_NHWC_OP_TYPED_CLASS_NAME(22, float, GridSample)>,

};

for (auto& function_table_entry : nhwc_function_table) {
159 changes: 128 additions & 31 deletions onnxruntime/core/providers/cuda/tensor/grid_sample.cc
100644 → 100755
@@ -21,28 +21,66 @@
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);

#define REGISTER_KERNEL_VERSIONED_TYPED(T, FROM_VERSION, TO_VERSION, LAYOUT, DOMAIN) \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
GridSample, \
DOMAIN, \
FROM_VERSION, \
TO_VERSION, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<T>()), \
onnxruntime::contrib::cuda::GridSample<T, LAYOUT>);

REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)

#ifdef ENABLE_CUDA_NHWC_OPS
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
// Op was introduced in opset 16
REGISTER_KERNEL_VERSIONED_TYPED(float, 16, 19, LAYOUT_NHWC, kMSInternalNHWCDomain)

// Op was modified to support multiple spatial dimensions in opset 20
REGISTER_KERNEL_VERSIONED_TYPED(float, 20, 21, LAYOUT_NHWC, kMSInternalNHWCDomain)

// Op spec introduced BFloat16 support in opset 22
REGISTER_KERNEL_TYPED(float, 22, LAYOUT_NHWC, kMSInternalNHWCDomain)
#endif

template <typename T, bool IsNHWC>
GridSample<T, IsNHWC>::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
opset_start_version_ = info.node().SinceVersion();

std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault<std::string>("padding_mode", "zeros");
align_corners_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("align_corners", 0));
ORT_ENFORCE(mode_str == "bilinear" || mode_str == "nearest" || mode_str == "bicubic",
"mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
"padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
if (mode_str == "bicubic") {
mode_i_ = 2;
} else if (mode_str == "nearest") {
mode_i_ = 1;

if (opset_start_version_ >= 20) {
std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "linear");
if (mode_str == "cubic") {
Comment on lines 32 to +60 (Copilot AI, Jan 29, 2026):
mode_str is read with default "bilinear" and then immediately shadowed by a new mode_str inside both the opset>=20 and opset<20 branches, leaving the outer variable unused (can trigger -Wunused-variable). Consider removing the outer mode_str and only reading the attribute within the version-specific branch (or reuse the existing variable instead of shadowing).

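One possible way to address this, sketched as a reply (a hypothetical restructuring, not what the PR currently does): read the attribute once with a version-dependent default so nothing is shadowed.

// Hypothetical sketch only: pick the default and accepted spellings from the opset,
// then parse a single mode_str instead of shadowing it per branch.
const bool is_opset_20_plus = opset_start_version_ >= 20;
const std::string mode_str =
    info.GetAttrOrDefault<std::string>("mode", is_opset_20_plus ? "linear" : "bilinear");
if (mode_str == (is_opset_20_plus ? "cubic" : "bicubic")) {
  mode_i_ = 2;
} else if (mode_str == "nearest") {
  mode_i_ = 1;
} else if (mode_str == (is_opset_20_plus ? "linear" : "bilinear")) {
  mode_i_ = 0;
} else {
  ORT_THROW("mode \"", mode_str, "\" not supported");
}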
mode_i_ = 2;
} else if (mode_str == "nearest") {
mode_i_ = 1;
} else if (mode_str == "linear") {
mode_i_ = 0;
} else {
ORT_THROW("mode \"", mode_str, "\" not supported, expect linear, nearest or cubic");
}
} else {
mode_i_ = 0;
std::string mode_str = info.GetAttrOrDefault<std::string>("mode", "bilinear");

Check warning on line 70 in onnxruntime/core/providers/cuda/tensor/grid_sample.cc (GitHub Actions / Optional Lint C++): [cpplint] Add #include <string> for string [build/include_what_you_use] [4]
if (mode_str == "bicubic") {
mode_i_ = 2;
} else if (mode_str == "nearest") {
mode_i_ = 1;
} else if (mode_str == "bilinear") {
mode_i_ = 0;
} else {
ORT_THROW("mode \"", mode_str, "\" not supported, expect bilinear, nearest or bicubic");
}
}

ORT_ENFORCE(padding_mode_str == "zeros" || padding_mode_str == "border" || padding_mode_str == "reflection",
"padding_mode \"", padding_mode_str, "\" not supported, expect zeros, border or reflection");
if (padding_mode_str == "reflection") {
padding_mode_i_ = 2;
} else if (padding_mode_str == "border") {
@@ -59,44 +59,103 @@
const Tensor* Grid = context->Input<Tensor>(1);
const auto& dims_grid = Grid->Shape().GetDims();

if (dims_input.size() != 4 || dims_grid.size() != 4) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Only 4-D tensor is supported");
if (dims_input.size() != dims_grid.size()) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Input and grid must have the same number of dimensions");
}

if (opset_start_version_ < 20 && dims_input.size() != 4) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Opset 16-19 versions of this op only supports 4-D input tensors");
}

if (dims_input[0] != dims_grid[0]) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Grid batch size does not match input batch size ");
}

if ((dims_input.size() == 4 && dims_grid[3] != 2) || (dims_input.size() == 5 && dims_grid[4] != 3)) {
return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Last dimension of grid input must match the number of "
"spatial dimensions in the input (2 for 2D, 3 for 3D).");
}

if (dims_input.size() != 4 && dims_input.size() != 5) {
return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Only 4-D and 5-D input tensors are supported");
}


if (dims_input.size() == 5 && mode_i_ == 2) {
// Neither the CPU nor the CUDA implementation supports cubic mode for 5-D input,
// so this won't break CUDA users who were previously falling back to the CPU version of the op.
return Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED, "Cubic mode is only supported in 4-D cases.");
}
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");

using Ch = Channels<IsNHWC>;

TensorShapeVector dims_output(4);
dims_output[Ch::N] = dims_input[Ch::N];
dims_output[Ch::C] = dims_input[Ch::C];
dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
TensorShapeVector dims_output(dims_input.size());
if (dims_input.size() == 4) {
dims_output[Ch::N] = dims_input[Ch::N];
dims_output[Ch::C] = dims_input[Ch::C];
dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
} else {
// 5D input - deal with both NCHW and NHWC layouts
dims_output[0] = dims_input[0];
dims_output[1] = !IsNHWC ? dims_input[1] : dims_grid[1];
dims_output[2] = !IsNHWC ? dims_grid[1] : dims_grid[2];
dims_output[3] = !IsNHWC ? dims_grid[2] : dims_grid[3];
dims_output[4] = !IsNHWC ? dims_grid[3] : dims_input[4];
}
Tensor* Y = context->Output(0, dims_output);

// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
return Status::OK();
}

typedef typename ToCudaType<T>::MappedType CudaT;
CudaT* Y_data = reinterpret_cast<CudaT*>(Y->MutableData<T>());
GridSampleImpl<CudaT, IsNHWC>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(Grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
dims_input.data(),
dims_grid[1],
dims_grid[2],
Y_data);

if (dims_input.size() == 4) {
// sample 2d
GridSampleImpl<CudaT, IsNHWC>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(Grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
dims_input.data(),
dims_grid[1],
dims_grid[2],
Y_data);
} else {
// sample 3d
GridSampleImpl3D<CudaT, IsNHWC>(
Stream(context),
reinterpret_cast<const CudaT*>(X->Data<T>()),
reinterpret_cast<const CudaT*>(Grid->Data<T>()),
mode_i_,
padding_mode_i_,
align_corners_,
dims_input.data(),
dims_grid[1],
dims_grid[2],
dims_grid[3],
Y_data);
}


return Status::OK();
}
} // namespace cuda
} // namespace contrib

namespace cuda {
REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)
// Op was introduced in opset 16
REGISTER_KERNEL_VERSIONED_TYPED(float, 16, 19, LAYOUT_NCHW, kOnnxDomain)

// Op was modified to support multiple spatial dimensions in opset 20
REGISTER_KERNEL_VERSIONED_TYPED(float, 20, 21, LAYOUT_NCHW, kOnnxDomain)

// Op spec introduced BFloat16 support in opset 22
REGISTER_KERNEL_TYPED(float, 22, LAYOUT_NCHW, kOnnxDomain)
} // namespace cuda
} // namespace onnxruntime
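For reference, the output-shape logic in ComputeInternal above reduces to [N, C, grid spatial dims...] for NCHW and [N, grid spatial dims..., C] for NHWC. A standalone check of that expectation (the GridSampleOutputShape helper and the main function below are illustrative only, not part of the diff):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the output-shape computation above.
// input: NCHW/NCDHW or NHWC/NDHWC; grid: [N, H_out, W_out, 2] or [N, D_out, H_out, W_out, 3].
std::vector<int64_t> GridSampleOutputShape(const std::vector<int64_t>& input,
                                           const std::vector<int64_t>& grid,
                                           bool is_nhwc) {
  const size_t rank = input.size();
  std::vector<int64_t> out(rank);
  out[0] = input[0];  // batch
  if (!is_nhwc) {
    out[1] = input[1];                                        // channels stay at position 1
    for (size_t i = 2; i < rank; ++i) out[i] = grid[i - 1];   // spatial dims come from the grid
  } else {
    for (size_t i = 1; i + 1 < rank; ++i) out[i] = grid[i];   // spatial dims come from the grid
    out[rank - 1] = input[rank - 1];                          // channels stay last
  }
  return out;
}

int main() {
  // 4-D NCHW: input [2, 3, 8, 8], grid [2, 5, 6, 2] -> output [2, 3, 5, 6]
  assert((GridSampleOutputShape({2, 3, 8, 8}, {2, 5, 6, 2}, false) ==
          std::vector<int64_t>{2, 3, 5, 6}));
  // 5-D NDHWC: input [2, 4, 8, 8, 3], grid [2, 5, 6, 7, 3] -> output [2, 5, 6, 7, 3]
  assert((GridSampleOutputShape({2, 4, 8, 8, 3}, {2, 5, 6, 7, 3}, true) ==
          std::vector<int64_t>{2, 5, 6, 7, 3}));
  return 0;
}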
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/tensor/grid_sample.h
100644 → 100755
@@ -22,6 +22,7 @@ class GridSample final : public CudaKernel {
int64_t mode_i_; // 0: bilinear (default), 1: nearest 2: bicubic
int64_t padding_mode_i_; // 0:'zeros', 1: 'border', 2:'reflection'
int64_t align_corners_;
int opset_start_version_;
};

} // namespace cuda