forked from microsoft/onnxruntime-extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
171 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
|
||
enum class RotarySide : int { | ||
LEFT = 1, | ||
RIGHT = 2, | ||
}; | ||
|
||
template <typename T> | ||
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, | ||
const T* input, const int64_t* split_data, T* output, RotarySide side); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "ocos.h" | ||
#include "rotary_impl.cuh" | ||
#include "ortx_common.h" | ||
|
||
namespace contrib { | ||
|
||
template <typename T> | ||
struct Rotary { | ||
template <typename TDict> | ||
OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { | ||
std::string side; | ||
auto status = OrtW::GetOpAttribute(info, "side", side); | ||
if (!status) { | ||
return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."}; | ||
} | ||
if (side == "left") { | ||
side_ = RotarySide::LEFT; | ||
} | ||
else if (side == "right") { | ||
side_ = RotarySide::RIGHT; | ||
} | ||
else { | ||
return {kOrtxErrorInvalidArgument, "side must be 'left' or 'right'."}; | ||
} | ||
|
||
return {}; | ||
} | ||
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, | ||
const ortc::Tensor<T>& input, | ||
const ortc::Tensor<int64_t>& split, | ||
ortc::Tensor<T>& output) const { | ||
const T* input_data = input.Data(); | ||
auto input_shape = input.Shape(); | ||
T* output_data = output.Allocate(input_shape); | ||
auto input_length = input.NumberOfElement(); | ||
if (0 == input_length) { | ||
return {}; | ||
} | ||
|
||
auto shape_split = split.Shape(); | ||
if (shape_split.size() != 1 || shape_split[0] != 2) { | ||
return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."}; | ||
} | ||
if (shape_split[0] != shape_split[1]) { | ||
return {kOrtxErrorInvalidArgument, "Only equal split are allowed."}; | ||
} | ||
if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) { | ||
return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."}; | ||
} | ||
|
||
const int64_t* split_data = split.Data(); | ||
|
||
LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), | ||
input_length, | ||
static_cast<int>(input_shape[input_shape.size()-1]), | ||
input_data, | ||
split_data, | ||
output_data, | ||
side_); | ||
return {}; | ||
} | ||
|
||
static OrtMemType GetInputMemoryType(size_t input_index) { | ||
if (input_index == 1) // split | ||
return OrtMemType::OrtMemTypeCPUInput; | ||
return OrtMemType::OrtMemTypeDefault; | ||
} | ||
|
||
private: | ||
RotarySide side_; | ||
}; | ||
|
||
} // namespace contrib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "device_prop.cuh" | ||
#include "utils.cuh" | ||
#include "Rotary_impl.cuh" | ||
#include "cuda_type.h" | ||
|
||
using namespace Ort::Custom; | ||
|
||
template <typename T> __device__ __inline__ T _neg(const T x) { return -x; } | ||
|
||
#if __CUDA_ARCH__ < 700 | ||
template <> __device__ __inline__ half _neg(const half x) { | ||
return __float2half(-__half2float(x)); | ||
} | ||
#endif | ||
|
||
template <typename T, RotarySide side> | ||
__global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half_N, CUDA_LONG half_stride) { | ||
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x; | ||
if (id >= half_N) | ||
return; | ||
CUDA_LONG last = id % half_stride; | ||
id = (id - last) * 2 + last; | ||
if (side == RotarySide::RIGHT) { | ||
output_data[id + half_stride] = input_data[id]; | ||
output_data[id] = _neg(input_data[id + half_stride]); | ||
} else { | ||
output_data[id + half_stride] = _neg(input_data[id]); | ||
output_data[id] = input_data[id + half_stride]; | ||
} | ||
} | ||
|
||
template <typename T> | ||
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim, | ||
const T* input, const int64_t* split_data, T* output, RotarySide side) { | ||
constexpr int blockSize = 256; | ||
const int gridSize = (input_length + blockSize - 1) / blockSize; | ||
if (input_length == 0) | ||
return; | ||
using TT = typename contrib::CudaT<T>::MappedType; | ||
|
||
CUDA_LONG N = static_cast<CUDA_LONG>(count); | ||
CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim); | ||
|
||
const int num_threads_per_block = GridDim::maxThreadsPerBlock; | ||
const int num_elements_per_thread = | ||
(N / 2 + num_threads_per_block - 1) / num_threads_per_block; | ||
|
||
switch (side) { | ||
case RotarySide::LEFT: | ||
RotaryKernel<T, RotarySide::LEFT> | ||
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(output_data, input_data, | ||
N / 2, stride / 2); | ||
break; | ||
case RotarySide::RIGHT: | ||
RotaryKernel<T, RotarySide::RIGHT> | ||
<<<num_elements_per_thread, num_threads_per_block, 0, stream>>>(output_data, input_data, | ||
N / 2, stride / 2); | ||
break; | ||
} | ||
|
||
RotaryKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length); | ||
return cudaGetLastError(); | ||
} | ||
|
||
template <> | ||
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim, | ||
const float* input, const int64_t* split_data, float* output, RotarySide side) { | ||
return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); | ||
} | ||
|
||
template <> | ||
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim, | ||
const ortc::MFloat16* input, const int64_t* split_data, | ||
ortc::MFloat16* output, RotarySide side) { | ||
return _LaunchRotaryKernel(stream, input_length, last_dim, input, split_data, output, side); | ||
} |