
Commit 95d65e4

sync to flash attention kernel 2.5.9 and add document of how to write custom op (microsoft#757)
* sync to flash attention kernel 2.5.9
* support users to overload GetMayInplace and ReleaseMayInplace
* Undo the change for pybind11 dependency
1 parent b436d09 commit 95d65e4

File tree

12 files changed: +1315 -1006 lines changed


cmake/ext_cuda.cmake

-2
@@ -30,8 +30,6 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcudafe \"--diag_suppress=expr_has_no
 
 add_compile_definitions(USE_CUDA)
 
-set(OCOS_USE_MEMORY_EFFICIENT_ATTENTION OFF) # turn off for the build time. Turn them on when these 2 libs are really in use
-set(OCOS_USE_FLASH_ATTENTION OFF)
 if (OCOS_USE_FLASH_ATTENTION)
   message(STATUS "Enable flash attention")
   add_compile_definitions(OCOS_USE_FLASH_ATTENTION)

docs/How_to_write_custom_op.md

+60
@@ -0,0 +1,60 @@
# How to write custom ops

Custom ops are built on the ONNXRuntime-extensions API, in particular the **OrtLiteCustomOp** and **Tensor** classes. C++ template metaprogramming is used heavily under the hood to give custom op authors great flexibility in the count, type, and order of parameters.

## Basic scenario

There are two ways to write a custom op: as a function or as a structure.

### Custom op in the form of function

If your kernel is simple, you can implement it as a single function that computes the customized kernel. The function can have an arbitrary number of inputs and outputs. Mandatory inputs are declared as:

```C++
const Ort::Custom::Tensor<T>&
// or
const Ort::Custom::Tensor<T>*
```

Optional inputs are declared as:

```C++
std::optional<const Ort::Custom::Tensor<T>*>
```

If the kernel needs to run on a CUDA GPU, the function can also accept a pointer to **CUDAKernelContext**, from which you can retrieve the CUDA stream and other CUDA resources.

The function returns an **OrtStatusPtr**.

Please refer to [negpos_def.h](https://github.com/microsoft/onnxruntime-extensions/blob/main/operators/math/cuda/negpos_def.h) for an example and to [tensor_tuple.inc](https://github.com/microsoft/onnxruntime-extensions/blob/main/include/custom_op/tensor_tuple.inc) for more possible parameter types.
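
As a rough illustration, a function-form kernel that negates its input could look like the sketch below. This is not an operator from the repository: the kernel logic and name are made up, and the included header and the `Data`/`Allocate`/`Shape`/`NumberOfElement` accessors are assumptions about the **Tensor** class, so check the headers linked above for the authoritative API.

```C++
// Hypothetical sketch of a function-form custom op kernel (not from the repository).
#include "ocos.h"  // assumed umbrella header pulling in the Tensor class and OrtStatusPtr

OrtStatusPtr KernelNegate(const Ort::Custom::Tensor<float>& input,
                          Ort::Custom::Tensor<float>& output) {
  const float* in_data = input.Data();
  // Allocate the output with the same shape as the input (accessor names are assumptions).
  float* out_data = output.Allocate(input.Shape());
  for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
    out_data[i] = -in_data[i];
  }
  return nullptr;  // a null OrtStatusPtr signals success
}
```
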
### Custom op in the form of structure

If the kernel is more complicated, or the custom op carries extra properties, you can provide a C++ structure and store those properties as member variables. In addition, the structure needs to provide the following member functions:

```C++
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info)  // initializes the properties of the custom op
OrtStatusPtr Compute(...) const                                           // computes the customized kernel
```

The parameters of the Compute function follow the same rules as in the function form above.
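
Below is a hedged sketch of the structure form. The op itself (scaling the input by an attribute named `scale`) and all member names are invented for illustration; only the `OnModelAttach`/`Compute` shape comes from the API described above, and the attribute is read through the standard `OrtApi::KernelInfoGetAttribute_float` call.

```C++
// Hypothetical sketch of a structure-form custom op with one attribute (not from the repository).
struct KernelScale {
  // Initializes the properties of the custom op from the node's attributes.
  OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
    if (OrtStatusPtr status = api.KernelInfoGetAttribute_float(&info, "scale", &scale_)) {
      api.ReleaseStatus(status);  // attribute missing: keep the default
      scale_ = 1.0f;
    }
    return nullptr;
  }

  // Computes the customized kernel; the parameter rules match the function form.
  OrtStatusPtr Compute(const Ort::Custom::Tensor<float>& input,
                       Ort::Custom::Tensor<float>& output) const {
    const float* in_data = input.Data();
    float* out_data = output.Allocate(input.Shape());
    for (int64_t i = 0; i < input.NumberOfElement(); ++i) {
      out_data[i] = scale_ * in_data[i];
    }
    return nullptr;
  }

  float scale_{1.0f};  // extra property stored as a member variable
};
```
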
## Advanced scenario

In some cases you need more control over the parameters. You then have to use the structure form and implement additional member functions such as:

```C++
// By default this function returns OrtMemType::OrtMemTypeDefault for all inputs;
// provide your own implementation to specify whether the i-th input lives on CPU or GPU.
static OrtMemType GetInputMemoryType(size_t input_index)

// You can specify that input i may share memory with output j by allocating two
// arrays of the same length for the pointers input_index and output_index, and
// then setting (*input_index)[k] = i and (*output_index)[k] = j.
// The return value is the length of the allocated arrays.
static size_t GetMayInplace(int** input_index, int** output_index)

// Releases the arrays allocated by GetMayInplace().
static void ReleaseMayInplace(int* input_index, int* output_index)
```
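
As a concrete (hypothetical) sketch of that contract, a kernel whose input 0 may reuse the buffer of output 0 could implement the pair as follows. The `new[]`/`delete[]` allocation scheme is an assumption; only the index-pair contract comes from the description above.

```C++
// Hypothetical sketch: declare one aliasing pair, input 0 -> output 0.
static size_t GetMayInplace(int** input_index, int** output_index) {
  constexpr size_t kNumPairs = 1;
  *input_index = new int[kNumPairs];
  *output_index = new int[kNumPairs];
  (*input_index)[0] = 0;   // input 0 ...
  (*output_index)[0] = 0;  // ... may share memory with output 0
  return kNumPairs;        // length of both allocated arrays
}

// Free the arrays allocated by GetMayInplace().
static void ReleaseMayInplace(int* input_index, int* output_index) {
  delete[] input_index;
  delete[] output_index;
}
```
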

include/custom_op/custom_op_lite.h

+7
@@ -886,6 +886,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
       return INPUT_OUTPUT_OPTIONAL;
     };
 #endif
+
+#if ORT_API_VERSION >= 18
+    OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
+      return 0;
+    };
+    OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
+#endif
   }
 
   const std::string op_name_;

include/op_def_struct.h

+25
@@ -106,6 +106,18 @@ struct CustomOp_defined_getInputMemoryType : std::false_type {};
 template <typename T>
 struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>> : std::true_type {};
 
+template <typename T, typename = void>
+struct CustomOp_defined_getMayInplace : std::false_type {};
+
+template <typename T>
+struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>> : std::true_type {};
+
+template <typename T, typename = void>
+struct CustomOp_defined_releaseMayInplace : std::false_type {};
+
+template <typename T>
+struct CustomOp_defined_releaseMayInplace<T, std::void_t<decltype(&T::ReleaseMayInplace)>> : std::true_type {};
+
 template <typename CustomOpKernel>
 struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
   using ComputeFunction = decltype(&CustomOpKernel::Compute);
@@ -192,6 +204,19 @@ struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
       };
     }
 
+#if ORT_API_VERSION >= 18
+    if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
+      OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
+        return CustomOpKernel::GetMayInplace(input_index, output_index);
+      };
+    }
+    if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
+      OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
+        CustomOpKernel::ReleaseMayInplace(input_index, output_index);
+      };
+    }
+#endif
+
     OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
                                      const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
       if (api == nullptr) {

onnxruntime_extensions/__init__.py

+4-1
@@ -10,7 +10,6 @@
 
 __author__ = "Microsoft"
 
-
 from ._version import __version__
 from ._ocos import get_library_path
 from ._ocos import Opdef, PyCustomOpDef
@@ -66,6 +65,10 @@ def _unimplemented(*args, **kwargs):
     gen_processing_models = _unimplemented
     OrtPyFunction = _unimplemented
     ort_inference = _unimplemented
+    PyOrtFunction = _unimplemented
+    optimize_model = _unimplemented
+    make_onnx_model = _unimplemented
+    ONNXRuntimeError = _unimplemented
 
 else:
     __all__ += _offline_api

operators/cuda/attention_lib/flash_attention/flash.h

+10
@@ -87,6 +87,13 @@ struct Flash_fwd_params : public Qkv_params {
   // The indices to index into the KV cache.
   int* __restrict__ cache_batch_idx = nullptr;
 
+  // Paged KV cache
+  int * __restrict__ block_table;
+  index_t block_table_batch_stride;
+  int page_block_size;
+
+  float rp_dropout;
+
   // Local window size
   int window_size_left = -1;
   int window_size_right = -1;
@@ -102,6 +109,9 @@ struct Flash_fwd_params : public Qkv_params {
 
   int num_splits = 0;  // For split-KV version
 
+  void * __restrict__ alibi_slopes_ptr;
+  index_t alibi_slopes_batch_stride;
+
   const cudaDeviceProp* dprops = nullptr;
 };

operators/cuda/attention_lib/flash_attention/flash_api.cc

+22-7
@@ -32,7 +32,9 @@ void set_params_fprop(Flash_fwd_params& params,
                       bool is_bf16,
                       bool kv_bsnh = true,
                       int window_size_left = -1,
-                      int window_size_right = -1) {
+                      int window_size_right = -1,
+                      bool paged_KV = false,
+                      int page_block_size = -1) {
   // Set the pointers and strides.
   params.q_ptr = q;
   params.k_ptr = k;
@@ -64,8 +66,8 @@ void set_params_fprop(Flash_fwd_params& params,
 
   if (cu_seqlens_q_d == nullptr) {
     params.q_batch_stride = seqlen_q * num_heads * head_size;    // stride(0)
-    params.k_batch_stride = seqlen_k * num_heads_k * head_size;  // stride(0)
-    params.v_batch_stride = seqlen_k * num_heads_k * head_size;  // stride(0)
+    params.k_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size;  // stride(0)
+    params.v_batch_stride = (paged_KV ? page_block_size : seqlen_k) * num_heads_k * head_size;  // stride(0)
     params.o_batch_stride = seqlen_q * num_heads * head_size;    // stride(0)
   } else {
     params.q_batch_stride = 0;
@@ -99,6 +101,10 @@ void set_params_fprop(Flash_fwd_params& params,
   params.scale_softmax = softmax_scale;
   params.scale_softmax_log2 = softmax_scale * M_LOG2E;
 
+  params.rp_dropout = 1.f;
+  params.alibi_slopes_ptr = nullptr;
+  params.alibi_slopes_batch_stride = 0;
+
   // In our API, causal/unidirectional determines if we only look at prior tokens. However, the flash API seperates
   // local and causal, meaning when we have local window size
   params.is_causal = is_causal;
@@ -349,8 +355,8 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
 OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
                              cudaStream_t stream,
                              void* q,       // batch_size x seqlen_q x num_heads x head_size
-                             void* kcache,  // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
-                             void* vcache,  // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
+                             void* kcache,  // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
+                             void* vcache,  // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
                              void* k_new,   // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
                              void* v_new,   // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
                              void* out,     // batch_size x seqlen_q x num_heads x head_size
@@ -374,7 +380,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
                              void* out_accum,  // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
                              int local_window_size,
                              bool is_rotary_interleaved,
-                             bool is_packed_qkv) {
+                             bool is_packed_qkv,
+                             int32_t* block_table,  // batch_size x max_num_blocks_per_seq
+                             int32_t max_num_blocks_per_seq,
+                             int32_t page_block_size) {
   auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
   const int head_size_rounded = round_multiple(head_size, 32);
   const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
@@ -398,7 +407,9 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
                    is_bf16,
                    past_bsnh,
                    local_window_size,
-                   is_causal ? 0 : -1);
+                   is_causal ? 0 : -1,
+                   block_table != nullptr,
+                   page_block_size);
   params.dprops = &dprops;
 
   if (k_new != nullptr && v_new != nullptr) {
@@ -454,6 +465,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
     params.oaccum_ptr = nullptr;
   }
 
+  params.block_table = block_table;
+  params.block_table_batch_stride = max_num_blocks_per_seq;
+  params.page_block_size = page_block_size;
+
   // Only split kernel supports appending to KV cache
   run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
 
operators/cuda/attention_lib/flash_attention/flash_api.h

+6-3
@@ -53,8 +53,8 @@ OrtStatusPtr mha_varlen_fwd(const cudaDeviceProp& dprops,
 OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
                              cudaStream_t stream,
                              void* q,       // batch_size x seqlen_q x num_heads x head_size
-                             void* kcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
-                             void* vcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size
+                             void* kcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
+                             void* vcache,  // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table
                              void* k,       // batch_size x seqlen_k_new x num_heads_k x head_size
                              void* v,       // batch_size x seqlen_k_new x num_heads_k x head_size
                              void* out,     // batch_size x seqlen_q x num_heads x head_size
@@ -78,7 +78,10 @@ OrtStatusPtr mha_fwd_kvcache(const cudaDeviceProp& dprops,
                              void* out_accum = nullptr,  // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
                              int local_window_size = -1,
                              bool is_rotary_interleaved = false,
-                             bool is_packed_qkv = false);
+                             bool is_packed_qkv = false,
+                             int32_t* block_table = nullptr,  // batch_size x max_num_blocks_per_seq
+                             int32_t max_num_blocks_per_seq = -1,
+                             int32_t page_block_size = 1);
 
 size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
 