forked from microsoft/onnxruntime-extensions
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/microsoft/onnxruntime-extensions into misc
- Loading branch information
Showing 5 changed files with 335 additions and 1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include "ocos.h" | ||
#include "mul_sigmoid_impl.cuh" | ||
#include "ortx_common.h" | ||
|
||
namespace contrib { | ||
|
||
/** | ||
* MulSigmoid(X) = X * Sigmoid(X) | ||
No shape broadcasting supported. | ||
*/ | ||
template <typename T> | ||
struct MulSigmoid { | ||
template <typename TDict> | ||
OrtxStatus OnModelAttach(const TDict& /*dict*/) { | ||
return {}; | ||
} | ||
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, | ||
const ortc::Tensor<T>& input, | ||
ortc::Tensor<T>& output) const { | ||
const T* input_data = input.Data(); | ||
T* output_data = output.Allocate(input.Shape()); | ||
auto input_length = input.NumberOfElement(); | ||
if (0 == input_length) { | ||
return {}; | ||
} | ||
LaunchMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), | ||
input_length, | ||
input_data, | ||
output_data); | ||
return {}; | ||
} | ||
}; | ||
|
||
/** | ||
* MulSigmoid(X, Y) = X * Y * Sigmoid(Y) | ||
No shape broadcasting supported. | ||
*/ | ||
template <typename T> | ||
struct MulMulSigmoid { | ||
template <typename TDict> | ||
OrtxStatus OnModelAttach(const TDict& /*dict*/) { | ||
return {}; | ||
} | ||
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx, | ||
const ortc::Tensor<T>& input_x, | ||
const ortc::Tensor<T>& input_y, | ||
ortc::Tensor<T>& output) const { | ||
const T* input_data_x = input_x.Data(); | ||
const T* input_data_y = input_y.Data(); | ||
auto input_length_x = input_x.NumberOfElement(); | ||
auto input_length_y = input_y.NumberOfElement(); | ||
if (0 == input_length_x || 0 == input_data_y) { | ||
return {}; | ||
} | ||
T* output_data = output.Allocate(input_length_x > input_length_y ? input_x.Shape() : input_y.Shape()); | ||
LaunchMulMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()), | ||
input_length_x, | ||
input_length_y, | ||
input_data_x, | ||
input_data_y, | ||
output_data); | ||
return {}; | ||
} | ||
}; | ||
|
||
} // namespace contrib |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include <algorithm>

#include "device_prop.cuh"
#include "utils.cuh"
#include "mul_sigmoid_impl.cuh"
#include "cuda_type.h"
|
||
#ifndef CUDA_LONG | ||
#define CUDA_LONG int32_t | ||
#endif | ||
|
||
using namespace Ort::Custom; | ||
|
||
// exp() dispatcher so the same templated device code works for float and half.
template <typename T> __device__ __inline__ T _exp_typed(const T x);

template <> __device__ __inline__ float _exp_typed(const float x) { return expf(x); }

#if __CUDA_ARCH__ < 700
// Pre-SM70 path: do the math in float and convert back to half.
template <> __device__ __inline__ half _exp_typed(const half x) {
  const float fx = __half2float(x);
  return __float2half(expf(fx));
}
#else
// SM70+: use the native half-precision exponential.
template <> __device__ __inline__ half _exp_typed(const half x) { return hexp(x); }
#endif
|
||
// Numerically stable logistic function: choose the formulation whose exp()
// argument is non-positive, so the exponential cannot overflow.
template <typename T> __device__ __inline__ T sigmoid(const T v) {
  if (v > T(0)) {
    return (T)1 / ((T)1. + _exp_typed<T>(-v));
  }
  return (T)1 - (T)1 / ((T)1 + _exp_typed<T>(v));
}

#if __CUDA_ARCH__ < 700
// Pre-SM70 half specialization: evaluate in float for accuracy/support.
template <> __device__ __inline__ half sigmoid(const half v) {
  return __float2half(sigmoid(__half2float(v)));
}
#endif
|
||
// SiLU / swish activation: v * sigmoid(v).
template <typename T> __device__ __inline__ T mul_sigmoid(const T v) { return v * sigmoid(v); }

#if __CUDA_ARCH__ < 700
// Pre-SM70 half specialization: evaluate in float, convert once at the end.
template <> __device__ __inline__ half mul_sigmoid(const half v) {
  const float f = __half2float(v);
  return __float2half(f * sigmoid(f));
}
#endif
|
||
// Fused x * y * sigmoid(y).
template <typename T> __device__ __inline__ T mul_mul_sigmoid(const T x, const T y) {
  return x * y * sigmoid(y);
}

#if __CUDA_ARCH__ < 700
// Pre-SM70 half specialization: evaluate in float, convert once at the end.
template <> __device__ __inline__ half mul_mul_sigmoid(const half x, const half y) {
  const float fy = __half2float(y);
  return __float2half(__half2float(x) * fy * sigmoid(fy));
}
#endif
|
||
// One thread per element: output[i] = input[i] * sigmoid(input[i]).
// Expects a 1D launch covering at least N threads; out-of-range threads exit.
template <typename T>
__global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG N) {
  const CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N) {
    output_data[idx] = mul_sigmoid(input_data[idx]);
  }
}
|
||
// One thread per output element: out[i] = x[i % Nx] * y[i % Ny] * sigmoid(y[i % Ny]).
// Each input is indexed modulo its own length, cycling the shorter input.
// Nx and Ny must be non-zero. Expects a 1D launch covering at least N threads.
template <typename T>
__global__ void MulMulSigmoidKernel(T *output_data, const T *px, const T *py, CUDA_LONG N,
                                    CUDA_LONG Nx, CUDA_LONG Ny) {
  const CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N) {
    output_data[idx] = mul_mul_sigmoid(px[idx % Nx], py[idx % Ny]);
  }
}
|
||
// Launches MulSigmoidKernel over `input_length` elements on `stream`.
// Buffers are reinterpreted through contrib::CudaT (e.g. MFloat16 -> half).
// Returns the launch status from cudaGetLastError().
template <typename T>
cudaError_t _LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output) {
  // Robustness fix: an empty input yields gridSize == 0, which is an invalid
  // launch configuration — skip the launch entirely in that case.
  if (input_length <= 0) {
    return cudaSuccess;
  }
  constexpr int blockSize = 256;
  const int gridSize = (input_length + blockSize - 1) / blockSize;
  using TT = typename contrib::CudaT<T>::MappedType;
  MulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(
      reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
  return cudaGetLastError();
}
|
||
// Public entry point: float specialization forwards to the shared launcher.
template <>
cudaError_t LaunchMulSigmoidKernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
  return _LaunchMulSigmoidKernel(stream, input_length, input, output);
}

// Public entry point: MFloat16 specialization; the shared launcher
// reinterprets the buffers via contrib::CudaT's mapped device type.
template <>
cudaError_t LaunchMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
  return _LaunchMulSigmoidKernel(stream, input_length, input, output);
}
|
||
// Launches MulMulSigmoidKernel over max(input_length_x, input_length_y)
// elements on `stream`. Buffers are reinterpreted through contrib::CudaT.
// Returns the launch status from cudaGetLastError().
template <typename T>
cudaError_t _LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y,
                                       const T* input_data_x, const T* input_data_y, T* output) {
  // Robustness fix: a zero-length input would yield gridSize == 0 (an invalid
  // launch configuration), and the kernel indexes each input modulo its
  // length, so a zero length must never reach the device (id % 0).
  if (input_length_x <= 0 || input_length_y <= 0) {
    return cudaSuccess;
  }
  const int input_length = std::max(input_length_x, input_length_y);
  constexpr int blockSize = 256;
  const int gridSize = (input_length + blockSize - 1) / blockSize;
  using TT = typename contrib::CudaT<T>::MappedType;
  MulMulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output),
                                                              reinterpret_cast<const TT*>(input_data_x),
                                                              reinterpret_cast<const TT*>(input_data_y),
                                                              input_length, input_length_x, input_length_y);
  return cudaGetLastError();
}
|
||
// Public entry point: float specialization forwards to the shared launcher.
template <>
cudaError_t LaunchMulMulSigmoidKernel<float>(cudaStream_t stream, int input_length_x, int input_length_y,
                                             const float* input_data_x, const float* input_data_y, float* output) {
  return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output);
}

// Public entry point: MFloat16 specialization; the shared launcher
// reinterprets the buffers via contrib::CudaT's mapped device type.
template <>
cudaError_t LaunchMulMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length_x, int input_length_y,
                                                      const ortc::MFloat16* input_data_x, const ortc::MFloat16* input_data_y,
                                                      ortc::MFloat16* output) {
  return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
#include <cuda.h> | ||
#include <cuda_runtime.h> | ||
|
||
// Launches the element-wise X * sigmoid(X) kernel on `stream`.
// `input_length` is the element count; returns the CUDA launch status.
template <typename T>
cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output);

// Launches the element-wise X * Y * sigmoid(Y) kernel on `stream`. The kernel
// indexes each input modulo its own length; the output covers
// max(input_length_x, input_length_y) elements. Returns the launch status.
template <typename T>
cudaError_t LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y,
                                      const T* input_data_x, const T* input_data_y, T* output);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters