Skip to content

Commit

Permalink
Add custom ops ReplaceZero
Browse files Browse the repository at this point in the history
  • Loading branch information
xadupre committed Jun 6, 2024
1 parent 1e8c121 commit d58c35c
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 5 deletions.
5 changes: 4 additions & 1 deletion operators/cuda/cuda_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "cuda/add_mul.h"
#include "cuda/fast_gelu.h"
#include "cuda/negxplus1.h"
#include "cuda/replace_zero.h"
#endif

FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
Expand All @@ -28,13 +29,15 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
#if ORT_API_VERSION >= 16

CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>)
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>)
#endif
#endif
);
Expand Down
42 changes: 42 additions & 0 deletions operators/cuda/replace_zero.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "ocos.h"
#include "replace_zero_impl.cuh"
#include "ortx_common.h"

namespace contrib {

// Custom CUDA op "ReplaceZero": writes a copy of `input` into `output` in which
// every element equal to zero is replaced by the float attribute "by".
// The element-wise work is delegated to LaunchReplaceZeroKernel<T>.
template <typename T>
struct ReplaceZero {
  // Reads the float attribute "by" from the node attributes, requesting 0 as
  // the default value, and stores it for later Compute calls.
  template <typename TDict>
  OrtxStatus OnModelAttach(const TDict& dict) {
    float default_value = 0.0f;
    by_ = dict.TryToGetAttributeWithDefault("by", default_value);
    return {};
  }

  // Allocates `output` with the shape of `input` and launches the replacement
  // kernel on the CUDA stream provided by `ctx`. Returns a default-constructed
  // status; an empty input returns right after the output allocation.
  OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
                     const ortc::Tensor<T>& input,
                     ortc::Tensor<T>& output) const {
    const T* input_data = input.Data();
    auto input_shape = input.Shape();
    T* output_data = output.Allocate(input_shape);
    auto input_length = input.NumberOfElement();
    if (0 == input_length) {
      return {};
    }

    // NOTE(review): the cudaError_t returned by the launcher is discarded here,
    // so a failed launch is not reported through the returned status — consider
    // propagating it.
    LaunchReplaceZeroKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
                               input_length,
                               input_data,
                               output_data,
                               by_);
    return {};
  }

 private:
  // Replacement value for zero elements. Initialized so that Compute never
  // reads an indeterminate value if it runs before OnModelAttach.
  float by_ = 0.0f;
};

} // namespace contrib
64 changes: 64 additions & 0 deletions operators/cuda/replace_zero_impl.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "device_prop.cuh"
#include "utils.cuh"
#include "replace_zero_impl.cuh"
#include "cuda_type.h"

#ifndef CUDA_LONG
#define CUDA_LONG int32_t
#endif

using namespace Ort::Custom;

// Device helper: yields `by` when `x` compares equal to zero, otherwise `x`.
template <typename T>
__device__ __inline__ T _replace_zero(const T x, const T by) {
  if (x == static_cast<T>(0)) {
    return by;
  }
  return x;
}

// half specialization of _replace_zero.
// When __CUDA_ARCH__ is 700 or above the comparison is done directly on half;
// otherwise (including when __CUDA_ARCH__ is undefined and therefore evaluates
// to 0 in the preprocessor) the value is widened to float before comparing.
template <>
__device__ __inline__ half _replace_zero(const half x, const half by) {
#if __CUDA_ARCH__ >= 700
  const bool is_zero = (x == static_cast<half>(0));
#else
  const bool is_zero = (__half2float(x) == 0);
#endif
  return is_zero ? by : x;
}

// Element-wise kernel: each thread handles the single element whose flat index
// is blockDim.x * blockIdx.x + threadIdx.x and writes
// _replace_zero(input_data[i], by) to output_data[i]. Threads whose index is
// not below N do nothing.
template <typename T>
__global__ void ReplaceZeroKernel(T *output_data, const T *input_data, CUDA_LONG N, const T by) {
  const CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N) {
    output_data[idx] = _replace_zero(input_data[idx], by);
  }
}

// Converts a float to the target type T with a plain value conversion.
template <typename T>
T _cvt(float value) {
  return static_cast<T>(value);
}

// half specialization of _cvt: converts through __float2half instead of a cast.
template <>
half _cvt<half>(float value) {
  return __float2half(value);
}

// Launches ReplaceZeroKernel on `stream` for `input_length` elements.
//
// T is the tensor element type seen by the caller; it is mapped to the type
// used inside the kernel through contrib::CudaT<T>::MappedType, and `by` is
// converted to that type with _cvt before the launch.
//
// Launch layout: 1-D grid, 256 threads per block, one element per thread,
// no dynamic shared memory.
//
// Returns:
//   - cudaGetLastError() without launching when input_length is 0,
//   - cudaErrorInvalidValue when input_length is negative (for example a
//     64-bit element count that wrapped when narrowed to int),
//   - cudaGetLastError() right after the launch otherwise. The launch is
//     asynchronous, so this is read without waiting for the kernel to run.
template <typename T>
cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) {
  if (input_length == 0)
    return cudaGetLastError();
  if (input_length < 0)
    return cudaErrorInvalidValue;
  using TT = typename contrib::CudaT<T>::MappedType;

  const CUDA_LONG N = static_cast<CUDA_LONG>(input_length);

  constexpr int num_threads_per_block = 256;
  // Ceil-division written as (N - 1) / threads + 1 (valid because N > 0 here):
  // unlike (N + threads - 1) / threads it cannot overflow for N close to the
  // maximum representable value.
  const int num_blocks = static_cast<int>((N - 1) / num_threads_per_block + 1);

  const TT cby = _cvt<TT>(by);
  ReplaceZeroKernel<TT><<<num_blocks, num_threads_per_block, 0, stream>>>(
      reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N, cby);
  return cudaGetLastError();
}

// float specialization of the public launcher: forwards every argument
// unchanged to the shared implementation.
template <>
cudaError_t LaunchReplaceZeroKernel<float>(cudaStream_t stream, int input_length, const float* input_data, float* output_data, float by) {
  return _LaunchReplaceZeroKernel<float>(stream, input_length, input_data, output_data, by);
}

// ortc::MFloat16 specialization of the public launcher: forwards every
// argument unchanged to the shared implementation.
template <>
cudaError_t LaunchReplaceZeroKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input_data, ortc::MFloat16* output_data, float by) {
  return _LaunchReplaceZeroKernel<ortc::MFloat16>(stream, input_length, input_data, output_data, by);
}
9 changes: 9 additions & 0 deletions operators/cuda/replace_zero_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <cuda.h>
#include <cuda_runtime.h>

// Launches the ReplaceZero kernel on `stream`: writes to `output_data` a copy
// of the `input_length` elements of `input_data` in which every element equal
// to zero is replaced by `by` (converted to the element type).
//
// Only the declaration lives here; specializations for float and
// ortc::MFloat16 are defined in replace_zero_impl.cu.
//
// Returns the value of cudaGetLastError() read after the launch (or read
// immediately when `input_length` is 0, in which case nothing is launched).
template <typename T>
cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by);
68 changes: 64 additions & 4 deletions test/cuda/test_cudaops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import numpy as np
from numpy.testing import assert_almost_equal
from numpy.testing import assert_almost_equal, assert_allclose
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
Expand Down Expand Up @@ -151,7 +151,7 @@ def test_cuda_negxplus1(self):
self._negxplus1_cuda(TensorProto.FLOAT16)

def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3, 2, 3), shapec=(3, 2, 3)):
from onnx_extended.ortops.optim.cuda import get_ort_ext_libs
from ai.onnx.contrib import get_ort_ext_libs

model1 = helper.make_model(
helper.make_graph(
Expand Down Expand Up @@ -181,7 +181,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
f"{op_type}SharedInput",
["X", "Y", "Z"],
["XY", "XZ"],
domain="onnx_extended.ortops.optim.cuda",
domain="ai.onnx.contrib",
)
],
"nd",
Expand All @@ -197,7 +197,7 @@ def _addmul_shared_input_cuda(self, itype, op_type, shapea=(3, 2, 3), shapeb=(3,
),
opset_imports=[
helper.make_opsetid("", 18),
helper.make_opsetid("onnx_extended.ortops.optim.cuda", 1),
helper.make_opsetid("ai.onnx.contrib", 1),
],
ir_version=9,
)
Expand Down Expand Up @@ -262,6 +262,66 @@ def test_add_shared_input_cuda_broadcast2(self):
shapec=(3, 2, 3),
)

def _replace_zero_cuda(self, itype):
    """Compare the ``ReplaceZero`` custom op with an Equal + Where reference graph.

    A reference model made of standard ONNX operators is evaluated with
    ``ReferenceEvaluator``; a second model holding the single custom node is run
    through onnxruntime with the CUDA execution provider. Both receive the same
    3x2x3 input (values -4..13, so it contains a zero) and their outputs must
    agree within ``atol=1e-5``.

    :param itype: ONNX element type of the input/output,
        ``TensorProto.FLOAT`` or ``TensorProto.FLOAT16``
    """
    np_dtype = np.float16
    if itype == TensorProto.FLOAT:
        np_dtype = np.float32
    replacement = 1.67
    rank3 = [None, None, None]

    # Reference graph: Y = Where(X == zero, cst, X).
    reference_nodes = [
        helper.make_node("Equal", ["X", "zero"], ["cond"]),
        helper.make_node("Where", ["cond", "cst", "X"], ["Y"]),
    ]
    reference_initializers = [
        numpy_helper.from_array(np.array([0], dtype=np_dtype), name="zero"),
        numpy_helper.from_array(np.array([replacement], dtype=np_dtype), name="cst"),
    ]
    reference_graph = helper.make_graph(
        reference_nodes,
        "nd",
        [helper.make_tensor_value_info("X", itype, rank3)],
        [helper.make_tensor_value_info("Y", itype, rank3)],
        reference_initializers,
    )
    reference_model = helper.make_model(
        reference_graph,
        opset_imports=[helper.make_opsetid("", 18)],
        ir_version=9,
    )

    # Graph under test: a single ReplaceZero node from the contrib domain.
    custom_node = helper.make_node(
        "ReplaceZero",
        ["X"],
        ["Y"],
        by=replacement,
        domain="ai.onnx.contrib",
    )
    custom_graph = helper.make_graph(
        [custom_node],
        "nd",
        [helper.make_tensor_value_info("X", itype, rank3)],
        [helper.make_tensor_value_info("Y", itype, rank3)],
    )
    custom_model = helper.make_model(
        custom_graph,
        opset_imports=[
            helper.make_opsetid("", 18),
            helper.make_opsetid("ai.onnx.contrib", 1),
        ],
        ir_version=9,
    )

    x = (np.arange(18) - 4).reshape((3, 2, 3)).astype(np_dtype)
    feeds = {"X": x}

    expected = ReferenceEvaluator(reference_model).run(None, feeds)[0]

    session_options = _ort.SessionOptions()
    session_options.register_custom_ops_library(_get_library_path())
    session = _ort.InferenceSession(
        custom_model.SerializeToString(),
        session_options,
        providers=["CUDAExecutionProvider"],
    )
    got = session.run(None, feeds)[0]
    assert_allclose(expected, got, atol=1e-5)

@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_replace_zero_cuda(self):
    """Run the ReplaceZero comparison for float32, then for float16."""
    for element_type in (TensorProto.FLOAT, TensorProto.FLOAT16):
        self._replace_zero_cuda(element_type)


if __name__ == "__main__":
unittest.main()

0 comments on commit d58c35c

Please sign in to comment.