add MaskedScatterNdOfShape
xadupre committed Jun 7, 2024
1 parent f0da179 commit 2605674
Showing 5 changed files with 361 additions and 25 deletions.
2 changes: 2 additions & 0 deletions operators/cuda/cuda_ops.cc
@@ -15,12 +15,14 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
#ifdef USE_CUDA
,
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>)
#endif
70 changes: 70 additions & 0 deletions operators/cuda/scatter_nd_of_shape.h
@@ -70,4 +70,74 @@ struct ScatterNDOfShape {
ScatterReduction reduction_;
};


template <typename T>
struct MaskedScatterNDOfShape {
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string value;
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
if (status != nullptr)
return status;

if (value == "add")
reduction_ = ScatterReduction::Add;
else if (value == "mul")
reduction_ = ScatterReduction::Mul;
else if (value == "min")
reduction_ = ScatterReduction::Min;
else if (value == "max")
reduction_ = ScatterReduction::Max;
else
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);

status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
if (status != nullptr)
return status;

return nullptr;
}

OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<int64_t>& output_shape,
const ortc::Tensor<int64_t>& indices,
const ortc::Tensor<T>& updates,
ortc::Tensor<T>& output) const {
auto& output_shape_shape = output_shape.Shape();
auto& indices_shape = indices.Shape();
auto& updates_shape = updates.Shape();

if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
}
if (indices_shape[indices_shape.size() - 1] != 1) {
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
}

const int64_t* shape_data = output_shape.Data(); // CPU pointer
const int64_t* indices_data = indices.Data(); // GPU pointer
const T* updates_data = updates.Data(); // GPU pointer
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
T* output_data = output.Allocate(voutput_shape); // GPU pointer
LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
voutput_shape,
indices_shape,
indices_data,
updates_data,
output_data,
reduction_,
masked_value_);
return nullptr;
}

static OrtMemType GetInputMemoryType(size_t input_index) {
if (input_index == 0) // shape
return OrtMemType::OrtMemTypeCPUInput;
return OrtMemType::OrtMemTypeDefault;
}

private:
ScatterReduction reduction_;
int64_t masked_value_;
};

} // namespace contrib
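For reference, the operator zero-fills the output and then scatter-adds each row of updates into the output row selected by indices, skipping every row whose index equals the maskedValue attribute (only the add reduction is supported). A minimal CPU-side sketch of these semantics, not part of the commit and with illustrative names:

#include <cstdint>
#include <vector>

// Reference semantics only (illustrative, not the CUDA implementation):
// output has shape [n_rows, stride]; indices holds one target row per update row.
template <typename T>
void masked_scatter_add_reference(std::vector<T>& output, int64_t n_rows, int64_t stride,
                                  const std::vector<int64_t>& indices,
                                  const std::vector<T>& updates, int64_t masked_value) {
  output.assign(static_cast<size_t>(n_rows * stride), T(0));  // the kernel zero-initializes the output
  for (size_t i = 0; i < indices.size(); ++i) {
    if (indices[i] == masked_value)
      continue;                                               // masked rows contribute nothing
    for (int64_t j = 0; j < stride; ++j)
      output[static_cast<size_t>(indices[i] * stride + j)] += updates[i * stride + j];
  }
}

The CUDA kernels in scatter_nd_of_shape_impl.cu parallelize this over the last axis (stride) rather than looping over j on a single thread.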
155 changes: 145 additions & 10 deletions operators/cuda/scatter_nd_of_shape_impl.cu
@@ -52,15 +52,65 @@ addition_inplace_kernel(T* __restrict__ output_data, const int64_t* __restrict__
}

template <typename T>
cudaError_t _ComputeNoAtomic(cudaStream_t stream, T* output_data,
const int64_t* indices_data, const T* updates_data,
int threads_per_block, int blocks_per_grid, size_t indice_size, size_t nrows, size_t stride) {
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
using TT = typename CudaT<T>::MappedType;
addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>((TT*)output_data, indices_data,
(TT*)updates_data, indice_size, nrows, stride);
return cudaGetLastError();
__global__ void masked_addition_inplace_kernel(T *__restrict__ output_data,
const int64_t *__restrict__ indices_data,
const T *__restrict__ updates_data,
const CUDA_LONG indice_size,
const CUDA_LONG nrows, const CUDA_LONG stride,
const int64_t masked_value) {
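// Simple variant: one thread per position along the last axis (stride); each thread
// zero-fills its column and then accumulates every non-masked update row serially.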
auto id = blockDim.x * blockIdx.x + threadIdx.x;
if (id >= stride)
return;

for (size_t i = 0; i < nrows; ++i) {
output_data[i * stride + id] = 0;
}

for (size_t i = 0; i < indice_size; ++i) {
if (indices_data[i] == masked_value)
continue;
_add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
}
}

template <typename T, int NTHREAD>
__global__ void masked_addition_inplace_kernelN(T *__restrict__ output_data,
const int64_t *__restrict__ indices_data,
const T *__restrict__ updates_data,
const CUDA_LONG indice_size,
const CUDA_LONG nrows, const CUDA_LONG stride,
const int64_t masked_value) {
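// Tiled variant: each iteration stages NTHREAD indices into shared memory so every
// thread in the block can reread them cheaply while accumulating its column.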
__shared__ int64_t shared_indices[NTHREAD];

CUDA_LONG tid = threadIdx.x;
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;

for (size_t i = 0; i < nrows; ++i) {
output_data[i * stride + id] = 0;
}

int begin = 0;
int end = std::min(begin + NTHREAD, indice_size);
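// Process the indices in full tiles of NTHREAD; any final partial tile is handled after this loop.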
while (begin < end && (end == begin + NTHREAD)) {
shared_indices[tid] = indices_data[tid + begin];
__syncthreads();

for (size_t i = begin; i < end; ++i) {
if (shared_indices[i - begin] == masked_value)
continue;
_add_inplace(output_data[shared_indices[i - begin] * stride + id],
updates_data[i * stride + id]);
}

begin = end;
end = std::min(begin + NTHREAD, indice_size);
}

for (size_t i = begin; i < indice_size; ++i) {
if (indices_data[i] == masked_value)
continue;
_add_inplace(output_data[indices_data[i] * stride + id], updates_data[i * stride + id]);
}
}

template <class NTYPE>
@@ -89,7 +139,54 @@ cudaError_t ScatterNDOfShapeKernel(cudaStream_t stream,

int threads_per_block = 256;
int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
return _ComputeNoAtomic(stream, output_data, indices_data, updates_data, threads_per_block, blocks_per_grid, indice_size, nrows, stride);

dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);
using TT = typename CudaT<T>::MappedType;
addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(reinterpret_cast<TT*>(output_data), indices_data,
reinterpret_cast<const TT*>(updates_data),
indice_size, nrows, stride);
return cudaGetLastError();
}

template <typename T>
cudaError_t MaskedScatterNDOfShapeKernel(cudaStream_t stream, const std::vector<int64_t> &input_shape,
const std::vector<int64_t> &indices_shape,
const int64_t *indices_data, const T *updates_data,
T *output_data,
ScatterReduction reduction, int64_t masked_value) {
if (reduction != ScatterReduction::Add)
ORTX_CXX_API_THROW("Only reduction 'add' is implemented.", ORT_RUNTIME_EXCEPTION);
size_t indice_size = static_cast<size_t>(flattened_dimension(indices_shape));
size_t input_size = static_cast<size_t>(flattened_dimension(input_shape));
size_t stride = input_shape[input_shape.size() - 1];
size_t nrows = input_size / stride;

std::vector<size_t> next_batch(indice_size);
std::vector<uint8_t> processed(input_shape[0], 0);
std::vector<uint8_t> processed_once(input_shape[0], 0);

int threads_per_block = 256;
bool split = stride / threads_per_block <= 32;
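// Heuristic: use the shared-memory (tiled) kernel only for moderate strides
// (between 256 and about 32 * 256 elements); otherwise fall back to the simple kernel.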

int blocks_per_grid = (stride + threads_per_block - 1) / threads_per_block;
dim3 threads(threads_per_block);
dim3 blocks(blocks_per_grid);

using TT = typename CudaT<T>::MappedType;

if (split && stride >= 256 && threads_per_block == 256) {
masked_addition_inplace_kernelN<TT, 256><<<blocks, threads, 0, stream>>>(
reinterpret_cast<TT*>(output_data), indices_data,
reinterpret_cast<const TT*>(updates_data),
indice_size, nrows, stride, masked_value);
} else {
masked_addition_inplace_kernel<TT><<<blocks, threads, 0, stream>>>(
reinterpret_cast<TT*>(output_data), indices_data,
reinterpret_cast<const TT*>(updates_data),
indice_size, nrows, stride, masked_value);
}
return cudaGetLastError();
}

template <>
Expand Down Expand Up @@ -126,4 +223,42 @@ cudaError_t LaunchScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
reduction);
}

template <>
cudaError_t LaunchMaskedScatterNDOfShapeKernel<float>(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const float* updates,
float* output,
ScatterReduction reduction,
int64_t masked_value) {
return MaskedScatterNDOfShapeKernel(stream,
output_shape,
indices_shape,
indices,
updates,
output,
reduction,
masked_value);
}

template <>
cudaError_t LaunchMaskedScatterNDOfShapeKernel<ortc::MFloat16>(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const ortc::MFloat16* updates,
ortc::MFloat16* output,
ScatterReduction reduction,
int64_t masked_value) {
return MaskedScatterNDOfShapeKernel(stream,
output_shape,
indices_shape,
indices,
updates,
output,
reduction,
masked_value);
}

} // namespace contrib
10 changes: 10 additions & 0 deletions operators/cuda/scatter_nd_of_shape_impl.cuh
@@ -24,4 +24,14 @@ cudaError_t LaunchScatterNDOfShapeKernel(cudaStream_t stream,
T* output,
ScatterReduction reduction);

template <typename T>
cudaError_t LaunchMaskedScatterNDOfShapeKernel(cudaStream_t stream,
const std::vector<int64_t>& output_shape,
const std::vector<int64_t>& indices_shape,
const int64_t* indices,
const T* updates,
T* output,
ScatterReduction reduction,
int64_t masked_value);

} // namespace contrib