From 1c9c4a4476ed59f8ec6ad786aff817aaa8338705 Mon Sep 17 00:00:00 2001 From: xadupre Date: Thu, 6 Jun 2024 10:04:42 +0000 Subject: [PATCH] draf --- operators/cuda/roatry_impl.cuh | 15 +++++++ operators/cuda/rotary.h | 77 +++++++++++++++++++++++++++++++++ operators/cuda/rotary_impl.cu | 79 ++++++++++++++++++++++++++++++++++ 3 files changed, 171 insertions(+) create mode 100644 operators/cuda/roatry_impl.cuh create mode 100644 operators/cuda/rotary.h create mode 100644 operators/cuda/rotary_impl.cu diff --git a/operators/cuda/roatry_impl.cuh b/operators/cuda/roatry_impl.cuh new file mode 100644 index 000000000..9d50b5313 --- /dev/null +++ b/operators/cuda/roatry_impl.cuh @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include + +enum class RotarySide : int { + LEFT = 1, + RIGHT = 2, +}; + +template +cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const T* input, const int64_t* split_data, T* output, RotarySide side); diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h new file mode 100644 index 000000000..2ece1bf0e --- /dev/null +++ b/operators/cuda/rotary.h @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "ocos.h" +#include "rotary_impl.cuh" +#include "ortx_common.h" + +namespace contrib { + +template +struct Rotary { + template + OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { + std::string side; + auto status = OrtW::GetOpAttribute(info, "side", side); + if (!status) { + return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."}; + } + if (side == "left") { + side_ = RotarySide::LEFT; + } + else if (side == "right") { + side_ = RotarySide::RIGHT; + } + else { + return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."}; + } + + return {}; + } + OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, + const ortc::Tensor& input, + const ortc::Tensor& split, + ortc::Tensor& output) const { + const T* input_data = input.Data(); + auto input_shape = input.Shape(); + T* output_data = output.Allocate(input_shape); + auto input_length = input.NumberOfElement(); + if (0 == input_length) { + return {}; + } + + auto shape_split = split.Shape(); + if (shape_split.size() != 1 || shape_split[0] != 2) { + return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."}; + } + if (shape_split[0] != shape_split[1]) { + return {kOrtxErrorInvalidArgument, "Only equal split are allowed."}; + } + if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) { + return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."}; + } + + const int64_t* split_data = split.Data(); + + LaunchRotaryKernel(reinterpret_cast(ctx->GetCudaStream()), + input_length, + static_cast(input_shape[input_shape.size()-1]), + input_data, + split_data, + output_data, + side_); + return {}; + } + + static OrtMemType GetInputMemoryType(size_t input_index) { + if (input_index == 1) // split + return OrtMemType::OrtMemTypeCPUInput; + return OrtMemType::OrtMemTypeDefault; + } + + private: + RotarySide side_; +}; + +} // namespace contrib \ No newline at end of file diff --git a/operators/cuda/rotary_impl.cu b/operators/cuda/rotary_impl.cu new file mode 100644 index 000000000..d8928b2d4 --- /dev/null +++ b/operators/cuda/rotary_impl.cu @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "device_prop.cuh" +#include "utils.cuh" +#include "Rotary_impl.cuh" +#include "cuda_type.h" + +using namespace Ort::Custom; + +template __device__ __inline__ T _neg(const T x) { return -x; } + +#if __CUDA_ARCH__ < 700 +template <> __device__ __inline__ half _neg(const half x) { + return __float2half(-__half2float(x)); +} +#endif + +template +__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) { + CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; + if (id >= half_N) + return; + CUDA_LONG last = id % half_stride; + id = (id - last) * 2 + last; + if (side == RotarySide::RIGHT) { + output_data[id + half_stride] = input_data[id]; + output_data[id] = _neg(input_data[id + half_stride]); + } else { + output_data[id + half_stride] = _neg(input_data[id]); + output_data[id] = input_data[id + half_stride]; + } +} + +template +cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const T* input, const int64_t* split_data, T* output, RotarySide side) { + constexpr int blockSize = 256; + const int gridSize = (input_length + blockSize - 1) / blockSize; + if (input_length == 0) + return; + using TT = typename contrib::CudaT::MappedType; + + CUDA_LONG N = static_cast(count); + CUDA_LONG stride = static_cast(last_dim); + + const int num_threads_per_block = GridDim::maxThreadsPerBlock; + const int num_elements_per_thread = + (N / 2 + num_threads_per_block - 1) / num_threads_per_block; + + switch (side) { + case RotarySide::LEFT: + RotaryKernel + <<>>(output_data, input_data, + N / 2, stride / 2); + break; + case RotarySide::RIGHT: + RotaryKernel + <<>>(output_data, input_data, + N / 2, stride / 2); + break; + } + + RotaryKernel<<>>(reinterpret_cast(output), reinterpret_cast(input), input_length); + return cudaGetLastError(); +} + +template <> +cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const float* input, const int64_t* split_data, float* output, RotarySide side) { + return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); +} + +template <> +cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, + const ortc::MFloat16* input, const int64_t* split_data, + ortc::MFloat16* output, RotarySide side) { + return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); +}