Skip to content

Commit

Permalink
Redefine GetElementType and add strided tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Louly committed May 24, 2024
1 parent 8d85429 commit 29ddaed
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 30 deletions.
28 changes: 14 additions & 14 deletions onnxruntime/core/providers/cuda/tensor/gather_elements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,20 @@ void CoalesceDimensions(TensorShapeVector& input_shape, TensorShapeVector& indic

// GatherElementsGrad needs atomic_add which supports float types only, so use half, float and double for 16, 32, and 64
// bits data respectively.
// Maps an ORT runtime element type to the corresponding ONNX TensorProto
// data-type enum value for the types the Gather/ScatterElements CUDA kernels
// dispatch on (int8, fp16, fp32, fp64, bf16).
// Returns TensorProto_DataType_UNDEFINED for any other type; callers check
// for UNDEFINED and throw, so this helper never needs to.
ONNX_NAMESPACE::TensorProto_DataType GetElementType(const DataTypeImpl* dtype) {
  if (dtype == DataTypeImpl::GetType<int8_t>()) {
    return ONNX_NAMESPACE::TensorProto_DataType_INT8;
  } else if (dtype == DataTypeImpl::GetType<MLFloat16>()) {
    return ONNX_NAMESPACE::TensorProto_DataType_FLOAT16;
  } else if (dtype == DataTypeImpl::GetType<float>()) {
    return ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
  } else if (dtype == DataTypeImpl::GetType<double>()) {
    return ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
  } else if (dtype == DataTypeImpl::GetType<BFloat16>()) {
    return ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16;
  } else {
    // Should not be reached: the Compute methods validate that only the
    // relevant types arrive here.
    return ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED;
  }
}

Expand Down Expand Up @@ -183,8 +184,7 @@ Status GatherElements::ComputeInternal(OpKernelContext* context) const {
#endif
CoalesceDimensions(input_shape_vec, indices_shape_vec, p_indices_strides_vec, axis, args);

// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
int dtype = GetElementType(input_tensor->DataType()->Size());
int dtype = GetElementType(input_tensor->DataType());
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
ORT_THROW("Unsupported element size by the GatherElements CUDA kernel");
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/tensor/gather_elements.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ struct GatherScatterElementsArgs;
// dim-1 and dim-2 is contiguous (20==4*5), but dim-0 and dim-1 is not contiguous (0!=3*20).
void CoalesceDimensions(TensorShapeVector& input_shape, TensorShapeVector& indices_shape,
TensorShapeVector* p_indices_strides, int64_t axis, GatherScatterElementsArgs& args);
ONNX_NAMESPACE::TensorProto_DataType GetElementType(size_t element_size);
ONNX_NAMESPACE::TensorProto_DataType GetElementType(const DataTypeImpl* dtype);

class GatherElements final : public CudaKernel {
public:
Expand Down
11 changes: 0 additions & 11 deletions onnxruntime/core/providers/cuda/tensor/gather_elements_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -321,17 +321,6 @@ struct FuncAtomicAdd {
const size_t numel_;
};

// Specialization of the atomic-add functor for BFloat16 elements.
// NOTE(review): presumably the primary template's operator() does not handle
// BFloat16's atomic_add overload directly — confirm against the primary
// template definition above this hunk.
template <>
struct FuncAtomicAdd<BFloat16> {
// Element count of the output buffer; kept for interface parity with the
// primary template (not read by operator() in this specialization).
const size_t numel_;

FuncAtomicAdd(const size_t numel) : numel_(numel) {}

// Atomically accumulates `value` into output element `index`.
__device__ __inline__ void operator()(BFloat16* start_addr, size_t index, BFloat16 value) const {
atomic_add(start_addr + index, value);
}
};

template <typename T, typename TIndex>
Status GatherElementsGradNonDeterministicImpl(cudaStream_t stream, const TIndex* indices_data, const T* updates_data,
T* output_data,
Expand Down
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/cuda/tensor/scatter_elements.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
ORT_THROW("Unsupported reduction type for ScatterElements.");
}

// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
int dtype = GetElementType(input_tensor->DataType()->Size());
int dtype = GetElementType(input_tensor->DataType());
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) {
ORT_THROW("Unsupported element size by the ScatterElements CUDA kernel");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,18 @@ TEST(GatherElementsGrad, Strided_float) { RunKernelComputeTestWrapper<float>();
// Exercises the GatherElementsGrad strided path with double elements.
TEST(GatherElementsGrad, Strided_double) {
  RunKernelComputeTestWrapper<double>();
}

// Exercises the GatherElementsGrad strided path with MLFloat16 elements.
TEST(GatherElementsGrad, Strided_MLFloat16) {
  RunKernelComputeTestWrapper<MLFloat16>();
}

// Exercises the GatherElementsGrad strided path with BFloat16 elements.
// On CUDA, BFloat16 needs compute capability >= 5.3 (530), so the test is
// skipped with a warning instead of failing on older hardware.
TEST(GatherElementsGrad, Strided_BFloat16) {
#ifdef USE_CUDA
  int min_cuda_architecture = 530;
  if (!onnxruntime::test::HasCudaEnvironment(min_cuda_architecture)) {
    // Fixed log message: "BFP16" is not a standard name for this type.
    LOGS_DEFAULT(WARNING) << "Hardware does not support BFloat16";
    return;
  }
#endif
  RunKernelComputeTestWrapper<BFloat16>();
}

#endif

} // namespace test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,7 @@ Status GatherElementsGrad::ComputeInternal(OpKernelContext* context) const {
#endif
CoalesceDimensions(data_shape_vec, indices_shape_vec, p_indices_strides_vec, axis, args);

// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.
int dtype = GetElementType(dY->DataType()->Size());
int dtype = GetElementType(dY->DataType());
// GatherElementsGrad supports half, bfloat16, float and double only for now, so its element type will never be INT8.
if (dtype == ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED || dtype == ONNX_NAMESPACE::TensorProto_DataType_INT8) {
ORT_THROW("Unsupported element size by the GatherElementsGrad CUDA kernel");
Expand Down

0 comments on commit 29ddaed

Please sign in to comment.