-
Notifications
You must be signed in to change notification settings - Fork 10k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
rwkv6: add wkv6 support for Vulkan backend (#10829)
* rwkv_wkv6 vulkan shader * RWKV_WKV6 Vulkan op tests passed Signed-off-by: Molly Sophia <[email protected]> * Apply code format changes Signed-off-by: Molly Sophia <[email protected]> * add [[unroll]] and remove unnecessary conditions * add uma support * fix erros in EditorConfig Checker --------- Signed-off-by: Molly Sophia <[email protected]> Co-authored-by: Molly Sophia <[email protected]>
- Loading branch information
1 parent
08ea539
commit 160bc03
Showing
3 changed files
with
245 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
#version 450 | ||
|
||
#extension GL_EXT_control_flow_attributes : require | ||
|
||
#define BLOCK_SIZE 64 | ||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; | ||
|
||
layout(push_constant) uniform Parameters { | ||
uint B; | ||
uint T; | ||
uint C; | ||
uint H; | ||
}; | ||
|
||
layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; | ||
layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; | ||
layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; | ||
layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; | ||
layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; | ||
layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; | ||
layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; | ||
|
||
shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; | ||
|
||
void main() { | ||
const uint head_size = BLOCK_SIZE; | ||
const uint batch_id = gl_WorkGroupID.x / H; | ||
const uint head_id = gl_WorkGroupID.x % H; | ||
const uint tid = gl_LocalInvocationID.x; | ||
|
||
const uint state_size = C * head_size; | ||
const uint n_seq_tokens = T / B; | ||
|
||
if (batch_id >= B || head_id >= H) { | ||
return; | ||
} | ||
|
||
A_TYPE state[BLOCK_SIZE]; | ||
[[unroll]] for (uint i = 0; i < head_size; i++) { | ||
state[i] = state_in[batch_id * state_size + head_id * head_size * head_size | ||
+ i * head_size + tid]; | ||
} | ||
|
||
barrier(); | ||
_tf[tid] = tf[head_id * head_size + tid]; | ||
barrier(); | ||
|
||
const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; | ||
const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; | ||
|
||
for (uint t = start_t; t < end_t; t += C) { | ||
barrier(); | ||
_k[tid] = k[t]; | ||
_r[tid] = r[t]; | ||
_td[tid] = td[t]; | ||
barrier(); | ||
|
||
const A_TYPE v_val = v[t]; | ||
A_TYPE y = 0.0; | ||
|
||
[[unroll]] for (uint j = 0; j < head_size; j += 4) { | ||
vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); | ||
vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); | ||
vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); | ||
vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); | ||
vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); | ||
|
||
vec4 kv = k_vec * v_val; | ||
|
||
vec4 temp = tf_vec * kv + s_vec; | ||
y += dot(r_vec, temp); | ||
|
||
s_vec = s_vec * td_vec + kv; | ||
state[j] = s_vec.x; | ||
state[j+1] = s_vec.y; | ||
state[j+2] = s_vec.z; | ||
state[j+3] = s_vec.w; | ||
} | ||
|
||
dst[t] = y; | ||
} | ||
|
||
[[unroll]] for (uint i = 0; i < head_size; i++) { | ||
dst[T * C + batch_id * state_size + head_id * head_size * head_size | ||
+ i * head_size + tid] = state[i]; | ||
} | ||
} |