Custom Ops are based on the ONNXRuntime-extensions API, in particular the **OrtLiteCustomOp** and **Tensor** classes. C++ template metaprogramming is used heavily under the hood to give Custom Op authors great flexibility in the number, types, and order of parameters.
## Basic scenario
There are two ways to write a custom op: as a function or as a structure.
### Custom op in the form of function
If your kernel is simple, you can use this option by providing just a function that computes the customized kernel. The function can have an arbitrary number of inputs and outputs. For mandatory inputs, the parameter type is:
```C++
const Ort::Custom::Tensor<T>&
// or
const Ort::Custom::Tensor<T>*
```
For optional inputs, the parameter type is:
```C++
std::optional<const Ort::Custom::Tensor<T>*>
```
If the kernel needs to run on a CUDA GPU, the function can also accept a pointer to **CUDAKernelContext**, from which you can retrieve the CUDA stream and other CUDA resources.
The function returns the type **OrtStatusPtr**.
Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) as an example and [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types.
### Custom op in the form of structure
If the kernel is complicated and the custom op has extra properties, you can use this option by providing a C++ structure and storing those properties as its member variables. In addition, you need to provide the following member functions:
```C++
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) // This function initializes the properties of the custom op
OrtStatusPtr Compute(...) const // This function computes the customized kernel.
```
The parameters of the Compute function follow the same specification as in the first approach (custom op in the form of a function).
## Advanced scenario
In some cases you need more control over the parameters. You then have to use the structure form and provide implementations of additional member functions, such as:
```C++
// By default the function will return OrtMemType::OrtMemTypeDefault for all the inputs,
// you can provide your own implementation to specify whether the ith input is on CPU or GPU.
static OrtMemType GetInputMemoryType(size_t input_index);
```