Skip to content

Commit

Permalink
Fix implementation of Rotary
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 6, 2024
1 parent 1c9c4a4 commit 52f351c
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 35 deletions.
5 changes: 4 additions & 1 deletion operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#include "cuda/rotary.h"
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
Expand All @@ -28,13 +29,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("Rotary", contrib::Rotary<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>)
#endif
#endif
);
Expand Down
16 changes: 6 additions & 10 deletions operators/cuda/rotary.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,9 @@ namespace contrib {
template <typename T>
struct Rotary {
template <typename TDict>
OrtxStatus OnModelAttach(OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
std::string side;
auto status = OrtW::GetOpAttribute(info, "side", side);
if (!status) {
return {kOrtxErrorInvalidArgument, "Missing or wrong argument side."};
}
OrtxStatus OnModelAttach(const TDict& dict) {
std::string empty;
std::string side = dict.TryToGetAttributeWithDefault("side", empty);
if (side == "left") {
side_ = RotarySide::LEFT;
}
Expand Down Expand Up @@ -45,15 +42,14 @@ struct Rotary {
if (shape_split.size() != 1 || shape_split[0] != 2) {
return {kOrtxErrorInvalidArgument, "Rotary only works when there are two sides."};
}
if (shape_split[0] != shape_split[1]) {
const int64_t* split_data = split.Data();
if (split_data[0] != split_data[1]) {
return {kOrtxErrorInvalidArgument, "Only equal split are allowed."};
}
if (shape_split[0] * 2 != input_shape[input_shape.size()-1]) {
if (split_data[0] * 2 != input_shape[input_shape.size()-1]) {
return {kOrtxErrorInvalidArgument, "Sum of the splits are not equal to the last dimension."};
}

const int64_t* split_data = split.Data();

LaunchRotaryKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_length,
static_cast<int>(input_shape[input_shape.size()-1]),
Expand Down
38 changes: 20 additions & 18 deletions operators/cuda/rotary_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@

#include "device_prop.cuh"
#include "utils.cuh"
#include "Rotary_impl.cuh"
#include "rotary_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

// Device-side helper: returns the unary negation of `x` for any type T that
// provides operator-.
template <typename T>
__device__ __inline__ T _neg(const T x) {
  return -x;
}
Expand Down Expand Up @@ -34,46 +38,44 @@ __global__ void RotaryKernel(T *output_data, const T *input_data, CUDA_LONG half

// Queues RotaryKernel on `stream` for the requested `side`.
//
// The element type T is mapped to its device representation through
// contrib::CudaT<T>::MappedType and both buffers are reinterpreted accordingly.
// The kernel receives half of `input_length` and half of `last_dim`; the grid
// is sized so that 256-thread blocks cover `input_length / 2` work items.
// `split_data` is kept in the signature for the public launcher but is not
// read here.
// NOTE(review): `N / 2` and `stride / 2` truncate, so this presumably relies on
// the caller having checked that the last dimension is even -- confirm.
//
// Returns the value of cudaGetLastError(): nothing is launched when
// `input_length` is 0, otherwise the call follows the kernel launch.
template <typename T>
cudaError_t _LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
                                const T* input_data, const int64_t* /* split_data */, T* output_data, RotarySide side) {
  if (input_length == 0)
    return cudaGetLastError();
  using TT = typename contrib::CudaT<T>::MappedType;

  CUDA_LONG N = static_cast<CUDA_LONG>(input_length);
  CUDA_LONG stride = static_cast<CUDA_LONG>(last_dim);

  const int num_threads_per_block = 256;
  // Ceil-division: number of blocks needed for N / 2 work items.
  const int num_blocks =
      (N / 2 + num_threads_per_block - 1) / num_threads_per_block;

  switch (side) {
    case RotarySide::LEFT:
      RotaryKernel<TT, RotarySide::LEFT>
          <<<num_blocks, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
                                                             reinterpret_cast<const TT*>(input_data),
                                                             N / 2, stride / 2);
      break;
    case RotarySide::RIGHT:
      RotaryKernel<TT, RotarySide::RIGHT>
          <<<num_blocks, num_threads_per_block, 0, stream>>>(reinterpret_cast<TT*>(output_data),
                                                             reinterpret_cast<const TT*>(input_data),
                                                             N / 2, stride / 2);
      break;
  }

  // Kernel launches do not return a status; report launch-configuration errors.
  return cudaGetLastError();
}

// float specialization of LaunchRotaryKernel: forwards every argument unchanged
// to the shared _LaunchRotaryKernel implementation and returns its status.
template <>
cudaError_t LaunchRotaryKernel<float>(cudaStream_t stream, int input_length, int last_dim,
                                      const float* input_data, const int64_t* split_data, float* output_data, RotarySide side) {
  return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}

// ortc::MFloat16 specialization of LaunchRotaryKernel: forwards every argument
// unchanged to the shared _LaunchRotaryKernel implementation and returns its
// status.
template <>
cudaError_t LaunchRotaryKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, int last_dim,
                                               const ortc::MFloat16* input_data, const int64_t* split_data,
                                               ortc::MFloat16* output_data, RotarySide side) {
  return _LaunchRotaryKernel(stream, input_length, last_dim, input_data, split_data, output_data, side);
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ enum class RotarySide : int {

// Launcher for the rotary kernel, provided as explicit specializations per
// element type T. Takes the CUDA stream, the input length and last dimension,
// the input/output buffers, the split tensor data and the side to rotate, and
// returns a cudaError_t status.
// NOTE(review): `input_length` looks like the total element count of the input
// tensor -- confirm against the caller.
template <typename T>
cudaError_t LaunchRotaryKernel(cudaStream_t stream, int input_length, int last_dim,
                               const T* input_data, const int64_t* split_data, T* output_data, RotarySide side);
69 changes: 64 additions & 5 deletions test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,6 @@ def test_cuda_negxplus1(self):
self._negxplus1_cuda(TensorProto.FLOAT16)

def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs

model1 = helper.make_model(
helper.make_graph(
[
Expand Down Expand Up @@ -181,7 +179,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
f"{op_type}SharedInput",
["X", "Y", "Z"],
["XY", "XZ"],
domain="onnx_extended.ortops.optim.cuda",
domain="ai.onnx.contrib",
)
],
"nd",
Expand All @@ -197,7 +195,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)
Expand All @@ -212,7 +210,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
expected = ref.run(None, feeds1)

opts = _ort.SessionOptions()
opts.register_custom_ops_library(get_ort_ext_libs()[0])
opts.register_custom_ops_library(_get_library_path())
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
got = sess.run(None, feeds1)
for i in range(2):
Expand Down Expand Up @@ -262,6 +260,67 @@ def test_add_shared_input_cuda_broadcast2(self):
shapec=(3, 2, 3),
)

def _rotary_cuda(self, itype, side, input_shape=(3, 2, 3, 4)):
    """Run the CUDA ``Rotary`` contrib op and compare it with a NumPy reference.

    Args:
        itype: ONNX element type of ``X`` and ``Y``. ``TensorProto.FLOAT`` is fed
            as ``np.float32``; any other value is fed as ``np.float16``.
        side: value of the node's ``side`` attribute. ``"left"`` selects the
            first reference branch below, any other value the second one.
        input_shape: shape of the generated input tensor (any rank >= 1). The
            last axis is cut into two halves of ``shape[-1] // 2`` elements.

    Raises:
        AssertionError: if the operator output differs from the reference.
    """
    # Declare as many dynamic dimensions as the generated input has.
    dynamic_dims = [None] * len(input_shape)
    model2 = helper.make_model(
        helper.make_graph(
            [
                helper.make_node(
                    "Rotary",
                    ["X", "splits"],
                    ["Y"],
                    domain="ai.onnx.contrib",
                    side=side,
                )
            ],
            "nd",
            [
                helper.make_tensor_value_info("X", itype, dynamic_dims),
                helper.make_tensor_value_info("splits", TensorProto.INT64, [2]),
            ],
            [helper.make_tensor_value_info("Y", itype, dynamic_dims)],
        ),
        opset_imports=[
            helper.make_opsetid("", 18),
            helper.make_opsetid("ai.onnx.contrib", 1),
        ],
        ir_version=9,
    )

    dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
    x = (np.arange(np.prod(input_shape)) + 1).reshape(input_shape).astype(dtype)
    half = x.shape[-1] // 2
    splits = np.array([half, half], dtype=np.int64)

    # NumPy reference: the two halves of the last axis swap places and one of
    # them is negated, depending on the side. Ellipsis indexing keeps the
    # reference independent of the input rank.
    expected = x.copy()
    if side == "left":
        expected[..., :half] = x[..., half:]
        expected[..., half:] = -x[..., :half]
    else:
        expected[..., :half] = -x[..., half:]
        expected[..., half:] = x[..., :half]

    feeds = dict(X=x, splits=splits)
    opts = _ort.SessionOptions()
    opts.register_custom_ops_library(_get_library_path())
    sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
    got = sess.run(None, feeds)[0]
    # numpy.testing expects (actual, desired); this order keeps failure
    # reports labelled correctly.
    assert_almost_equal(got, expected)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_rotary_cuda(self):
    """Check Rotary on the default input shape for both element types and both sides."""
    for element_type in (TensorProto.FLOAT, TensorProto.FLOAT16):
        for rotation_side in ("left", "right"):
            self._rotary_cuda(element_type, rotation_side)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_bigger_rotary_cuda(self):
    """Check Rotary on a (2, 2, 1024, 8) input for both element types and both sides."""
    sh = (2, 2, 1024, 8)
    for element_type in (TensorProto.FLOAT, TensorProto.FLOAT16):
        for rotation_side in ("left", "right"):
            self._rotary_cuda(element_type, rotation_side, input_shape=sh)


if __name__ == "__main__":
unittest.main()

0 comments on commit 52f351c

Please sign in to comment.