Skip to content

Commit

Permalink
draft
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 6, 2024
1 parent 1e8c121 commit 1c9c4a4
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 0 deletions.
15 changes: 15 additions & 0 deletions operators/cuda/rotary_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

// Selects which half of each row is negated by the rotary kernel.
// With every row of the last dimension viewed as two halves (x1, x2):
//   LEFT  produces (x2, -x1)
//   RIGHT produces (-x2, x1)
enum class RotarySide : int {
  LEFT = 1,
  RIGHT = 2,
};

// Host-side entry point that enqueues the rotary kernel on `stream`.
// Only the declaration lives here; explicit specializations for `float` and
// `ortc::MFloat16` are defined in rotary_impl.cu.
//
//   stream        CUDA stream the kernel is launched on.
//   input_length  total number of elements in `input` (and `output`).
//   last_dim      size of the last dimension of the input tensor.
//   input         source buffer read by the kernel.
//   split_data    data pointer of the `split` tensor; the operator requests CPU
//                 memory for that input, so presumably a host pointer -- TODO confirm.
//   output        destination buffer written by the kernel.
//   side          rotation variant, see RotarySide.
//
// Returns the CUDA status of the launch.
template <typename T>
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
const T* input, const int64_t* split_data, T* output, RotarySide side);
77 changes: 77 additions & 0 deletions operators/cuda/rotary.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <limits>
#include <string>

#include "ocos.h"
#include "rotary_impl.cuh"
#include "ortx_common.h"

namespace contrib {

// CUDA custom operator `Rotary`.
//
// Every row of the last dimension of `input` is viewed as two equally sized
// halves (x1, x2). Depending on the "side" attribute the output row is
//   "left"  -> (x2, -x1)
//   "right" -> (-x2, x1)
// The second input, `split`, is a 1-D int64 tensor holding the two half sizes;
// it is requested in CPU memory so it can be validated on the host.
template <typename T>
struct Rotary {
  // Reads and validates the "side" attribute when the kernel is created.
  // Returns kOrtxErrorInvalidArgument when the attribute is missing or is
  // neither "left" nor "right".
  OrtxStatus OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    std::string side;
    auto status = OrtW::GetOpAttribute(info, "side", side);
    // NOTE(review): ONNX Runtime status pointers are conventionally null on
    // success. If OrtW::GetOpAttribute follows that convention this check is
    // inverted -- confirm against its declaration before relying on it.
    if (!status) {
      return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."};
    }
    if (side == "left") {
      side_ = RotarySide::LEFT;
    } else if (side == "right") {
      side_ = RotarySide::RIGHT;
    } else {
      return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."};
    }

    return {};
  }

  // Validates the inputs, allocates `output` with the shape of `input` and
  // enqueues the rotary kernel on the context's CUDA stream.
  // Returns an error status for malformed `split` / `input` tensors or when
  // the kernel launch itself fails.
  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
                     const ortc::Tensor<T>& input,
                     const ortc::Tensor<int64_t>& split,
                     ortc::Tensor<T>& output) const {
    const T* input_data = input.Data();
    auto input_shape = input.Shape();
    T* output_data = output.Allocate(input_shape);
    auto input_length = input.NumberOfElement();
    if (0 == input_length) {
      return {};
    }
    // A rank-0 tensor has no last dimension to split; guard before indexing with size()-1.
    if (input_shape.size() == 0) {
      return {kOrtxErrorInvalidArgument, "Rotary requires an input with at least one dimension."};
    }
    // The launcher takes `int` sizes; refuse inputs that would be truncated.
    if (static_cast<int64_t>(input_length) > static_cast<int64_t>(std::numeric_limits<int>::max())) {
      return {kOrtxErrorInvalidArgument, "Rotary input is too large."};
    }

    auto shape_split = split.Shape();
    if (shape_split.size() != 1 || shape_split[0] != 2) {
      return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
    }

    // `split` is a CPU input (see GetInputMemoryType), so its two values can be
    // read on the host. The checks below are about those values, not about the
    // shape of the split tensor, which only has a single dimension.
    const int64_t* split_data = split.Data();
    const int64_t last_dim = input_shape[input_shape.size() - 1];
    if (split_data[0] != split_data[1]) {
      return {kOrtxErrorInvalidArgument, "Only equal split are allowed."};
    }
    if (split_data[0] + split_data[1] != last_dim) {
      return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."};
    }

    const cudaError_t launch_status =
        LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                              static_cast<int>(input_length),
                              static_cast<int>(last_dim),
                              input_data,
                              split_data,
                              output_data,
                              side_);
    if (launch_status != cudaSuccess) {
      // Surface launch failures instead of silently returning success.
      return {kOrtxErrorInternal, cudaGetErrorString(launch_status)};
    }
    return {};
  }

  // Input 1 (`split`) must live in CPU memory; every other input uses the default placement.
  static OrtMemType GetInputMemoryType(size_t input_index) {
    if (input_index == 1)  // split
      return OrtMemType::OrtMemTypeCPUInput;
    return OrtMemType::OrtMemTypeDefault;
  }

 private:
  // Overwritten by OnModelAttach; initialized so the member never holds an indeterminate value.
  RotarySide side_ = RotarySide::LEFT;
};

} // namespace contrib
79 changes: 79 additions & 0 deletions operators/cuda/rotary_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "Rotary_impl.cuh"
#include "cuda_type.h"

using namespace Ort::Custom;

// Device-side negation helper: returns -value for any type with unary minus.
template <typename T>
__device__ __inline__ T _neg(const T value) {
  return -value;
}

#if __CUDA_ARCH__ < 700
// Specialization for half, compiled only when __CUDA_ARCH__ < 700:
// negates through a float round trip instead of unary minus on half.
template <>
__device__ __inline__ half _neg(const half value) {
  const float as_float = __half2float(value);
  return __float2half(-as_float);
}
#endif

// Swaps the two halves of every row and negates one of them.
//
// Each thread handles one (first-half, second-half) element pair:
//   half_N       number of pairs, i.e. total element count / 2
//   half_stride  number of pairs per row, i.e. last dimension / 2
// Threads whose flat index is >= half_N do nothing.
// For a row (x1, x2): RIGHT writes (-x2, x1); any other side writes (x2, -x1).
template <typename T, RotarySide side>
__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) {
  const CUDA_LONG pair = blockDim.x * blockIdx.x + threadIdx.x;
  if (pair < half_N) {
    // Column of the pair inside its row, then the flat offsets of both elements.
    const CUDA_LONG column = pair % half_stride;
    const CUDA_LONG first = (pair - column) * 2 + column;
    const CUDA_LONG second = first + half_stride;
    if (side != RotarySide::RIGHT) {
      output_data[second] = _neg(input_data[first]);
      output_data[first] = input_data[second];
    } else {
      output_data[second] = input_data[first];
      output_data[first] = _neg(input_data[second]);
    }
  }
}

// Shared launcher behind the LaunchRotaryKernel specializations.
//
//   stream        CUDA stream to launch on.
//   input_length  total number of elements in `input` / `output`.
//   last_dim      size of the last dimension; must be even and >= 2.
//   input/output  buffers, reinterpreted as the mapped CUDA type of T.
//   split_data    unused here; the caller validates the split on the host.
//   side          rotation variant, selects the kernel instantiation.
//
// Returns cudaSuccess for an empty input, cudaErrorInvalidValue for sizes the
// kernel cannot handle or an unknown `side`, otherwise the launch status
// reported by cudaGetLastError().
template <typename T>
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
                                const T* input, const int64_t* split_data, T* output, RotarySide side) {
  (void)split_data;  // intentionally unused, kept for interface compatibility

  if (input_length == 0)
    return cudaSuccess;
  // The kernel pairs element i of the first half with element i of the second
  // half; an odd or too small last dimension would make half_stride truncate
  // (or become zero, a modulo by zero inside the kernel).
  if (input_length < 0 || last_dim < 2 || (last_dim % 2) != 0 || (input_length % last_dim) != 0)
    return cudaErrorInvalidValue;

  // Device code works on the CUDA-native type (e.g. half), not the host wrapper type.
  using TT = typename contrib::CudaT<T>::MappedType;
  TT* output_data = reinterpret_cast<TT*>(output);
  const TT* input_data = reinterpret_cast<const TT*>(input);

  const CUDA_LONG half_N = static_cast<CUDA_LONG>(input_length / 2);
  const CUDA_LONG half_stride = static_cast<CUDA_LONG>(last_dim / 2);

  // One thread per element pair.
  const int num_threads_per_block = GridDim::maxThreadsPerBlock;
  const int num_blocks = (half_N + num_threads_per_block - 1) / num_threads_per_block;

  switch (side) {
    case RotarySide::LEFT:
      RotaryKernel<TT, RotarySide::LEFT>
          <<<num_blocks, num_threads_per_block, 0, stream>>>(output_data, input_data, half_N, half_stride);
      break;
    case RotarySide::RIGHT:
      RotaryKernel<TT, RotarySide::RIGHT>
          <<<num_blocks, num_threads_per_block, 0, stream>>>(output_data, input_data, half_N, half_stride);
      break;
    default:
      return cudaErrorInvalidValue;
  }

  return cudaGetLastError();
}

// float specialization: forwards every argument to the shared launcher.
template <>
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
                                      const float* input, const int64_t* split_data, float* output, RotarySide side) {
  const cudaError_t launch_status =
      _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side);
  return launch_status;
}

// ortc::MFloat16 specialization: forwards every argument to the shared launcher.
template <>
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
                                               const ortc::MFloat16* input, const int64_t* split_data,
                                               ortc::MFloat16* output, RotarySide side) {
  const cudaError_t launch_status =
      _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side);
  return launch_status;
}

0 comments on commit 1c9c4a4

Please sign in to comment.