Skip to content

Commit

Permalink
add_mul
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 8, 2024
1 parent 07140b4 commit 31d8feb
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 19 deletions.
84 changes: 77 additions & 7 deletions operators/cuda/add_mul.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,91 @@ struct AddOrMulTwice {
auto length_c = tensor_c.NumberOfElement();

T* output_data_ab = output_ab.Allocate(
length_a <= length_b
? lenght_c <= length_b ? tensor_b.Shape() : tensor_c.Shape()
: lenght_a <= length_b ? tensor_b.Shape() : tensor_a.Shape());
length_a <= length_b
? lenght_c <= length_b ? tensor_b.Shape() : tensor_c.Shape()
: lenght_a <= length_b ? tensor_b.Shape()
: tensor_a.Shape());

if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
return {};
}
LaunchAddOrMulTwiceKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_data_a, input_data_b, input_data_c,
output_data,
length_a, length_b, length_c,
addition);
input_data_a, input_data_b, input_data_c,
output_data,
length_a, length_b, length_c,
addition);
return {};
}
};

template <typename T, bool addition_first>
struct AddAndMul {
template <typename TDict>
OrtxStatus OnModelAttach(const TDict& dict) {
return {};
}
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
const ortc::Tensor<T>& tensor_a,
const ortc::Tensor<T>& tensor_b,
const ortc::Tensor<T>& tensor_c,
ortc::Tensor<T>& output) const {
const T* input_data_a = tensor_a.Data();
const T* input_data_b = tensor_b.Data();
const T* input_data_c = tensor_c.Data();

auto length_a = tensor_a.NumberOfElement();
auto length_b = tensor_b.NumberOfElement();
auto length_c = tensor_c.NumberOfElement();
if (0 == input_data_a || 0 == input_data_b || 0 == input_data_c) {
return {};
}

std::vector<int64_t> dimsA = tensor_a.Shape();
std::vector<int64_t> dimsB = tensor_b.Shape();
std::vector<int64_t> dimsC = tensor_c.Shape();

auto max_length = std::max(length_a, std::max(length_b, length_c));

auto max_rank = std::max(dimsA.size(), std::max(dimsB.size(), dimsC.size()));
while (dimsA.size() < max_rank)
dimsA.insert(dimsA.begin(), 1);
while (dimsB.size() < max_rank)
dimsB.insert(dimsB.begin(), 1);
while (dimsC.size() < max_rank)
dimsC.insert(dimsC.begin(), 1);

std::vector<int64_t> output_dims(dimsA.size());
for (size_t i = 0; i < dimsA.size(); ++i) {
output_dims[i] = std::max(std::max(dimsA[i], dimsB[i]), dimsC[i]);
}

if (switchMiddelAxis_) {
if (output_dims.size() != 4) {
ORTX_CXX_API_THROW("switchMiddleAxes only works with 4D tensors", ORT_RUNTIME_EXCEPTION);
}
int64_t d4 = output_dims[output_dims.size() - 1];
int64_t d3 = output_dims[output_dims.size() - 2];
int64_t d2 = output_dims[output_dims.size() - 3];
output_dims[1] = d3;
output_dims[2] = d2;
LaunchAddAndMulSwitchMiddleAxesKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_data_a, input_data_b, input_data_c,
output_data,
length_a, length_b, length_c,
addition_first, d2, d3, d4);
} else {
T* output_data_ab = output_ab.Allocate(output_dims);
LaunchAddAndMulKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
input_data_a, input_data_b, input_data_c,
output_data,
length_a, length_b, length_c,
addition_first, switchMiddelAxis_);
}
return {};
}

private:
bool switchMiddelAxis_;
};

} // namespace contrib
181 changes: 172 additions & 9 deletions operators/cuda/add_mul_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@ cudaError_t LaunchAddOrMulSharedInputKernel<ortc::MFloat16>(cudaStream_t stream,
length_a, length_b, length_c, addition);
}

__device__ __forceinline__ void _add3_op(float *address, const float a, const float b,
__device__ __forceinline__ void _add3_op(float* address, const float a, const float b,
const float c) {
*address = a + b + c;
}

__device__ __forceinline__ void _add3_op(half *address, const half a, const half b,
__device__ __forceinline__ void _add3_op(half* address, const half a, const half b,
const half c) {
#if __CUDA_ARCH__ < 700
*address = __float2half(__half2float(a) + __half2float(b) + __half2float(c));
Expand All @@ -140,12 +140,12 @@ __device__ __forceinline__ void _add3_op(half *address, const half a, const half
#endif
}

__device__ __forceinline__ void _mul3_op(float *address, const float a, const float b,
__device__ __forceinline__ void _mul3_op(float* address, const float a, const float b,
const float c) {
*address = a * b * c;
}

__device__ __forceinline__ void _mul3_op(half *address, const half a, const half b,
__device__ __forceinline__ void _mul3_op(half* address, const half a, const half b,
const half c) {
#if __CUDA_ARCH__ < 700
*address = __float2half(__half2float(a) * __half2float(b) * __half2float(c));
Expand All @@ -154,14 +154,16 @@ __device__ __forceinline__ void _mul3_op(half *address, const half a, const half
#endif
}

template <typename T> struct Mul3Op {
__device__ __inline__ void operator()(T *address, const T a, const T b, const T c) const {
template <typename T>
struct Mul3Op {
__device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
_mul3_op(address, a, b, c);
}
};

template <typename T> struct Add3Op {
__device__ __inline__ void operator()(T *address, const T a, const T b, const T c) const {
template <typename T>
struct Add3Op {
__device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
_add3_op(address, a, b, c);
}
};
Expand Down Expand Up @@ -201,7 +203,7 @@ cudaError_t _LaunchAddOrMulTwiceKernel(cudaStream_t stream,
if (addition) {
AddMulTwiceKernel<TT, Add3Op<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
reinterpret_cast<TT*>(output),
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA), reinterpret_cast<const TT*>(pB), reinterpret_cast<const TT*>(pC),
static_cast<CUDA_LONG>(countA), static_cast<CUDA_LONG>(countB), static_cast<CUDA_LONG>(countC),
static_cast<CUDA_LONG>(max_count), Add3SharedOp<TT>());
Expand Down Expand Up @@ -236,3 +238,164 @@ cudaError_t LaunchAddOrMulSharedInputKernel<ortc::MFloat16>(cudaStream_t stream,
length_a, length_b, length_c, addition);
}

__device__ __forceinline__ void _addmul_op(float* address, const float a, const float b,
const float c) {
*address = (a + b) * c;
}

__device__ __forceinline__ void _addmul_op(half* address, const half a, const half b,
const half c) {
#if __CUDA_ARCH__ < 700
*address = __float2half((__half2float(a) + __half2float(b)) * __half2float(c));
#else
*address = (a + b) * c;
#endif
}

__device__ __forceinline__ void _muladd_op(float* address, const float a, const float b,
const float c) {
*address = a * b + c;
}

__device__ __forceinline__ void _muladd_op(half* address, const half a, const half b,
const half c) {
#if __CUDA_ARCH__ < 700
*address = __float2half(__half2float(a) * __half2float(b) + __half2float(c));
#else
*address = a * b + c;
#endif
}

template <typename T>
struct AddMul {
__device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
_addmul_op(address, a, b, c);
}
};

template <typename T>
struct MulAdd {
__device__ __inline__ void operator()(T* address, const T a, const T b, const T c) const {
_muladd_op(address, a, b, c);
}
};

template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _AddAndMulKernel(T* output_data, const T* pA, const T* pB, const T* pC,
CUDA_LONG nA, CUDA_LONG nB, CUDA_LONG nC, CUDA_LONG N,
const TFunc func) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
CUDA_LONG id = start;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
func(output_data + id, pA[id % nA], pB[id % nB], pC[id % nC]);
id += NumThreadsPerBlock;
}
}
}

template <typename T, typename TFunc, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void _AddAndMulSwitchMiddleAxesKernel(T* output_data, const T* pA, const T* pB,
const T* pC, CUDA_LONG nA, CUDA_LONG nB,
CUDA_LONG nC, CUDA_LONG N,
const TFunc func, CUDA_LONG d2,
CUDA_LONG d3, CUDA_LONG d4) {
CUDA_LONG start = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
CUDA_LONG id = start;
CUDA_LONG k, j, ido;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
k = (id / d4) % d3;
j = (id / (d4 * d3)) % d2;
ido = id + d4 * ((k * d2 + j) - (j * d3 + k));
func(output_data + ido, pA[id % nA], pB[id % nB], pC[id % nC]);
id += NumThreadsPerBlock;
}
}
}

template <typename T>
cudaError_t _LaunchAddAndMulKernel(cudaStream_t stream,
const T* pA, const T* pB, const T* pC,
T* output,
int64_t countA, int64_t countB, int64_t countC,
bool addition_first, bool switchMiddleAxes) {
int64_t max_count = std::max(std::max(countA, countB), countC);
if (max_count == 0) // special case where there's a dim value of 0 in the output shape
return cudaGetLastError();

const int num_elements_per_thread = 4;
const int num_threads_per_block = 256;
const int num_el_th = num_threads_per_block * num_elements_per_thread;

int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;

using TT = typename contrib::CudaT<T>::MappedType;

if (addition_first) {
AddAndMulKernel<TT, AddMul<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
cuda_stream,
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA),
reinterpret_cast<const TT*>(pB),
reinterpret_cast<const TT*>(pC),
countA, countB, countC,
max_size, AddMul<TT>());
} else {
AddAndMulKernel<TT, MulAdd<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
cuda_stream,
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA),
reinterpret_cast<const TT*>(pB),
reinterpret_cast<const TT*>(pC),
countA, countB, countC,
max_size, MulAdd<TT>());
}
return cudaGetLastError();
}

template <typename T>
cudaError_t _LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream,
const T* pA, const T* pB, const T* pC,
T* output,
int64_t countA, int64_t countB, int64_t countC,
bool addition_first, int64_t d2, int64_t d3, int64_t d4) {
int64_t max_count = std::max(std::max(countA, countB), countC);
if (max_count == 0) // special case where there's a dim value of 0 in the output shape
return cudaGetLastError();

const int num_elements_per_thread = 4;
const int num_threads_per_block = 256;
const int num_el_th = num_threads_per_block * num_elements_per_thread;

int blocksPerGrid = (max_count + num_el_th - 1) / num_el_th;

using TT = typename contrib::CudaT<T>::MappedType;

if (addition_first) {
AddAndMulSwitchMiddleAxesKernel<TT, AddMul<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
cuda_stream,
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA),
reinterpret_cast<const TT*>(pB),
reinterpret_cast<const TT*>(pC),
countA, countB, countC,
max_size, AddMul<TT>());
} else {
AddAndMulSwitchMiddleAxesKernel<TT, MulAdd<TT>, num_threads_per_block, num_elements_per_thread>
<<<blocksPerGrid, num_threads_per_block, 0, stream>>>(
cuda_stream,
reinterpret_cast<TT*>(output),
reinterpret_cast<const TT*>(pA),
reinterpret_cast<const TT*>(pB),
reinterpret_cast<const TT*>(pC),
countA, countB, countC,
max_size, MulAdd<TT>());
}
return cudaGetLastError();
}
13 changes: 12 additions & 1 deletion operators/cuda/add_mul_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,15 @@ cudaError_t LaunchAddOrMulSharedInputKernel(cudaStream_t stream, const T* input_

template <typename T>
cudaError_t LaunchAddOrMulTwiceKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition);
T* output, int64_t length_a, int64_t length_b, int64_t length_c, bool addition);

template <typename T>
cudaError_t LaunchAddAndMulKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
T* output, int64_t length_a, int64_t length_b, int64_t length_c,
bool addition, bool switchMiddleAxis);

template <typename T>
cudaError_t LaunchAddAndMulSwitchMiddleAxesKernel(cudaStream_t stream, const T* input_a, const T* input_b, const T* input_c,
T* output, int64_t length_a, int64_t length_b, int64_t length_c,
bool addition,
int64_t d2, int64_t d3, int64_t d4);
12 changes: 10 additions & 2 deletions operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,46 @@
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {

using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;

using AddTwiceFloat32Type = typename contrib::AddOrMulTwice<float, true>;
using MulTwiceFloat32Type = typename contrib::AddOrMulTwice<float, false>;

using AddAndMulFloat32Type = typename contrib::AddOrMulTwice<float, true>;
using MulAndAddFloat32Type = typename contrib::AddOrMulTwice<float, false>;

#if ORT_API_VERSION >= 16
using AddSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, true>;
using MulSharedInputFloat16Type = typename contrib::AddOrMulSharedInput<ortc::MFloat16, false>;

using AddTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
using MulTwiceFloat16Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;
#endif

using AddAndMulFloat32Type = typename contrib::AddOrMulTwice<ortc::MFloat16, true>;
using MulAndAddFloat32Type = typename contrib::AddOrMulTwice<ortc::MFloat16, false>;
#endif

static OrtOpLoader op_loader(
[]() { return nullptr; }
#ifdef USE_CUDA
,
CustomCudaStructV2("AddMul", AddAndMulFloat32Type),
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
CustomCudaStructV2("AddTwice", AddTwiceFloat32Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MulAdd", MulAndAddFloat32Type),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("MulTwice", MulTwiceFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("AddMul", AddAndMulFloat16Type),
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("AddTwice", AddTwiceFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MulAdd", MulAndAddFloat16Type),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("MulTwice", MulTwiceFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
Expand Down
Loading

0 comments on commit 31d8feb

Please sign in to comment.