Add operators AddAdd, MulMul
xadupre committed Jun 7, 2024
1 parent 1e8c121 commit 07140b4
Showing 5 changed files with 264 additions and 19 deletions.
37 changes: 37 additions & 0 deletions operators/cuda/add_mul.h
@@ -43,4 +43,41 @@ struct AddOrMulSharedInput {
}
};

template <typename T, bool addition>
struct AddOrMulTwice {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& tensor_a,
const ortc::Tensor<T>& tensor_b,
const ortc::Tensor<T>& tensor_c,
ortc::Tensor<T>& output) const {
const T* input_data_a = tensor_a.Data();
const T* input_data_b = tensor_b.Data();
const T* input_data_c = tensor_c.Data();

auto length_a = tensor_a.NumberOfElement();
auto length_b = tensor_b.NumberOfElement();
auto length_c = tensor_c.NumberOfElement();

    T* output_data = output.Allocate(
        length_a <= length_b
            ? (length_c <= length_b ? tensor_b.Shape() : tensor_c.Shape())
            : (length_c <= length_a ? tensor_a.Shape() : tensor_c.Shape()));

if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
return {};
}
LaunchAddOrMulTwiceKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_data_a, input_data_b, input_data_c,
output_data,
length_a, length_b, length_c,
addition);
return {};
}
};


} // namespace contrib
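For reference, the semantics of the new AddOrMulTwice operator can be sketched on the CPU as follows. This is a minimal illustration assuming flat 1-D modulo broadcasting, which is what the CUDA kernel in add_mul_impl.cu below implements; the helper name AddOrMulTwiceRef is hypothetical and not part of the commit.

#include <algorithm>
#include <cstddef>
#include <vector>

// CPU reference for AddOrMulTwice: out[i] = (a[i % nA] op b[i % nB]) op c[i % nC],
// where op is + (addition == true) or *, and the output takes the length of the
// largest input. Illustrative only.
template <typename T>
std::vector<T> AddOrMulTwiceRef(const std::vector<T>& a, const std::vector<T>& b,
                                const std::vector<T>& c, bool addition) {
  const std::size_t n = std::max({a.size(), b.size(), c.size()});
  std::vector<T> out(n);
  for (std::size_t i = 0; i < n; ++i) {
    const T va = a[i % a.size()], vb = b[i % b.size()], vc = c[i % c.size()];
    out[i] = addition ? va + vb + vc : va * vb * vc;
  }
  return out;
}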
147 changes: 132 additions & 15 deletions operators/cuda/add_mul_impl.cu
@@ -12,12 +12,12 @@

using namespace Ort::Custom;

__device__ __forceinline__ void _add3_op(float* ab, float* ac, const float a, const float b, const float c) {
__device__ __forceinline__ void _add3_2_op(float* ab, float* ac, const float a, const float b, const float c) {
*ab = a + b;
*ac = a + c;
}

__device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const half b, const half c) {
__device__ __forceinline__ void _add3_2_op(half* ab, half* ac, const half a, const half b, const half c) {
#if __CUDA_ARCH__ < 700
*ab = __float2half(__half2float(a) + __half2float(b));
*ac = __float2half(__half2float(a) + __half2float(c));
@@ -27,12 +27,12 @@ __device__ __forceinline__ void _add3_op(half* ab, half* ac, const half a, const
#endif
}

__device__ __forceinline__ void _mul3_op(float* ab, float* ac, const float a, const float b, const float c) {
__device__ __forceinline__ void _mul3_2_op(float* ab, float* ac, const float a, const float b, const float c) {
*ab = a * b;
*ac = a * c;
}

__device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const half b, const half c) {
__device__ __forceinline__ void _mul3_2_op(half* ab, half* ac, const half a, const half b, const half c) {
#if __CUDA_ARCH__ < 700
*ab = __float2half(__half2float(a) * __half2float(b));
*ac = __float2half(__half2float(a) * __half2float(c));
@@ -45,21 +45,21 @@ __device__ __forceinline__ void _mul3_op(half* ab, half* ac, const half a, const
template <typename T>
struct Mul3SharedOp {
__device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const {
_mul3_op(ab, ac, a, b, c);
_mul3_2_op(ab, ac, a, b, c);
}
};

template <typename T>
struct Add3SharedOp {
__device__ __forceinline__ void operator()(T* ab, T* ac, const T a, const T b, const T c) const {
_add3_op(ab, ac, a, b, c);
_add3_2_op(ab, ac, a, b, c);
}
};

template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void AddMulKernel(T* output_ab, T* output_ac, const T* pA, const T* pB,
const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
CUDA_LONG N, const TFunc func) {
__global__ void AddMulSharedInputKernel(T* output_ab, T* output_ac, const T* pA, const T* pB,
const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
CUDA_LONG N, const TFunc func) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
CUDA_LONG id = start;
#pragma unroll
@@ -89,14 +89,14 @@ cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
using TT = typename contrib::CudaT<T>::MappedType;

if (addition) {
AddMulKernel<TT, Add3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
AddMulSharedInputKernel<TT, Add3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
reinterpret_cast<TT*>(output_ab), reinterpret_cast<TT*>(output_ac),
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC), static_cast<CUDA_LONG>(countA),
static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
static_cast<CUDA_LONG>(max_count), Add3SharedOp<TT>());
} else {
AddMulKernel<TT, Mul3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
AddMulSharedInputKernel<TT, Mul3SharedOp<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
reinterpret_cast<TT*>(output_ab), reinterpret_cast<TT*>(output_ac),
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC), static_cast<CUDA_LONG>(countA),
@@ -107,15 +107,132 @@ cudaError_t _LaunchAddOrMulSharedInputKernel(cudaStream_t stream,
}

template <>
cudaError_t LaunchAddOrMulSharedInputKernel<float>(cudaStream_t stream, const float* input_a, const float* input_b, const float* input_c,
cudaError_t LaunchAddOrMulSharedInputKernel<float>(cudaStream_t stream,
const float* input_a, const float* input_b, const float* input_c,
float* output_ab, float* output_ac,
int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition);
return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c,
output_ab, output_ac,
length_a, length_b, length_c, addition);
}

template <>
cudaError_t LaunchAddOrMulSharedInputKernel<ortc::MFloat16>(cudaStream_t stream, const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
cudaError_t LaunchAddOrMulSharedInputKernel<ortc::MFloat16>(cudaStream_t stream,
const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
ortc::MFloat16* output_ab, ortc::MFloat16* output_ac,
int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c, output_ab, output_ac, length_a, length_b, length_c, addition);
return _LaunchAddOrMulSharedInputKernel(stream, input_a, input_b, input_c,
output_ab, output_ac,
length_a, length_b, length_c, addition);
}

__device__ __forceinline__ void _add3_op(float* address, const float a, const float b,
                                         const float c) {
*address = a + b + c;
}

__device__ __forceinline__ void _add3_op(half* address, const half a, const half b,
                                         const half c) {
#if __CUDA_ARCH__ < 700
*address = __float2half(__half2float(a) + __half2float(b) + __half2float(c));
#else
*address = a + b + c;
#endif
}

__device__ __forceinline__ void _mul3_op(float* address, const float a, const float b,
                                         const float c) {
*address = a * b * c;
}

__device__ __forceinline__ void _mul3_op(half* address, const half a, const half b,
                                         const half c) {
#if __CUDA_ARCH__ < 700
*address = __float2half(__half2float(a) * __half2float(b) * __half2float(c));
#else
*address = a * b * c;
#endif
}

template <typename T>
struct Mul3Op {
  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
    _mul3_op(address, a, b, c);
  }
};

template <typename T>
struct Add3Op {
  __device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
    _add3_op(address, a, b, c);
  }
};

template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void AddMulTwiceKernel(T* output, const T* pA, const T* pB,
const T* pC, CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC,
CUDA_LONG N, const TFunc func) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
func(output + id, pA[id % nA], pB[id % nB], pC[id % nC]);
id += NumThreadsPerBlock;
}
}
}

template <typename T>
cudaError_t _LaunchAddOrMulTwiceKernel(cudaStream_t stream,
const T* pA, const T* pB, const T* pC,
T* output,
int64_t countA, int64_t countB, int64_t countC, bool addition) {
int64_t max_count = std::max(std::max(countA, countB), countC);
if (max_count == 0) // special case where there's a dim value of 0 in the output shape
return cudaGetLastError();

const int num_elements_per_thread = 4;
const int num_threads_per_block = 256;
const int num_el_th = num_threads_per_block * num_elements_per_thread;

int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;

using TT = typename contrib::CudaT<T>::MappedType;

if (addition) {
AddMulTwiceKernel<TT, Add3Op<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC),
static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
static_cast<CUDA_LONG>(max_count), Add3Op<TT>());
} else {
AddMulTwiceKernel<TT, Mul3Op<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC),
static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
static_cast<CUDA_LONG>(max_count), Mul3Op<TT>());
}
return cudaGetLastError();
}

template <>
cudaError_t LaunchAddOrMulTwiceKernel<float>(cudaStream_t stream,
const float* input_a, const float* input_b, const float* input_c,
float* output,
int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
return _LaunchAddOrMulTwiceKernel(stream, input_a, input_b, input_c,
output,
length_a, length_b, length_c, addition);
}

template <>
cudaError_t LaunchAddOrMulTwiceKernel<ortc::MFloat16>(cudaStream_t stream,
const ortc::MFloat16* input_a, const ortc::MFloat16* input_b, const ortc::MFloat16* input_c,
ortc::MFloat16* output,
int64_t length_a, int64_t length_b, int64_t length_c, bool addition) {
return _LaunchAddOrMulTwiceKernel(stream, input_a, input_b, input_c,
output,
length_a, length_b, length_c, addition);
}
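A possible host-side driver for the new launcher is sketched below: allocate three device buffers, run the addition variant on a stream, and copy the result back. The buffer sizes, variable names, and omission of error checking are illustrative assumptions, not part of this commit.

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdint>
#include <vector>
#include "add_mul_impl.cuh"  // declares LaunchAddOrMulTwiceKernel<T>

int main() {
  const int64_t nA = 8, nB = 8, nC = 1;  // C broadcasts against A and B
  const int64_t n = std::max(std::max(nA, nB), nC);
  std::vector<float> hA(nA, 1.f), hB(nB, 2.f), hC(nC, 3.f), hOut(n);

  float *dA, *dB, *dC, *dOut;
  cudaMalloc(&dA, nA * sizeof(float));
  cudaMalloc(&dB, nB * sizeof(float));
  cudaMalloc(&dC, nC * sizeof(float));
  cudaMalloc(&dOut, n * sizeof(float));
  cudaMemcpy(dA, hA.data(), nA * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB.data(), nB * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC.data(), nC * sizeof(float), cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // addition == true computes out[i] = a[i % nA] + b[i % nB] + c[i % nC].
  LaunchAddOrMulTwiceKernel<float>(stream, dA, dB, dC, dOut, nA, nB, nC, true);
  cudaStreamSynchronize(stream);
  cudaMemcpy(hOut.data(), dOut, n * sizeof(float), cudaMemcpyDeviceToHost);
  // Each element of hOut is now 1 + 2 + 3 = 6.

  cudaStreamDestroy(stream);
  cudaFree(dA); cudaFree(dB); cudaFree(dC); cudaFree(dOut);
  return 0;
}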

6 changes: 5 additions & 1 deletion operators/cuda/add_mul_impl.cuh
@@ -8,4 +8,8 @@
template <typename T>
cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
T* output_ab, T* output_ac,
int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
int64_t length_a, int64_t length_b, int64_t length_c, bool addition);

template <typename T>
cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
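The header now declares both entry points: the SharedInput variant writes two outputs that reuse input_a, while the Twice variant folds all three inputs into a single output. For comparison with the AddOrMulTwice sketch above, here is a minimal CPU rendering of the SharedInput semantics, with hypothetical names and the same 1-D modulo broadcasting assumption.

#include <algorithm>
#include <cstddef>
#include <vector>

// CPU reference for AddOrMulSharedInput: ab[i] = a[i % nA] op b[i % nB] and
// ac[i] = a[i % nA] op c[i % nC], sharing the first operand. Illustrative only.
template <typename T>
void AddOrMulSharedInputRef(const std::vector<T>& a, const std::vector<T>& b,
                            const std::vector<T>& c, std::vector<T>& ab,
                            std::vector<T>& ac, bool addition) {
  const std::size_t n = std::max({a.size(), b.size(), c.size()});
  ab.resize(n);
  ac.resize(n);
  for (std::size_t i = 0; i < n; ++i) {
    const T va = a[i % a.size()], vb = b[i % b.size()], vc = c[i % c.size()];
    ab[i] = addition ? va + vb : va * vb;
    ac[i] = addition ? va + vc : va * vc;
  }
}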
10 changes: 10 additions & 0 deletions operators/cuda/cuda_ops.cc
@@ -14,9 +14,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;

using AddTwiceFloat32Type = typename contrib::AddOrMulTwice<float, true>;
using MulTwiceFloat32Type = typename contrib::AddOrMulTwice<float, false>;

#if ORT_API_VERSION >= 16
using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;

using AddTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
using MulTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;
#endif


@@ -25,15 +25,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
#ifdef USE_CUDA
,
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
CustomCudaStructV2("AddTwice", AddTwiceFloat32Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("MulTwice", MulTwiceFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("AddTwice", AddTwiceFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("MulTwice", MulTwiceFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
#endif
#endif
