From 6e89c9648c529ef6e7e9d981aedb2d1ff35e54fa Mon Sep 17 00:00:00 2001
From: lshAlgorithm
Date: Mon, 12 May 2025 21:47:10 +0800
Subject: [PATCH 1/4] FINISHED!

Signed-off-by: lshAlgorithm
---
 CMakeLists.txt                 | 16 ++++++
 python/generate_completions.py |  2 +-
 rwkv_operators_wkv_v7.inc      | 96 ++++++++++++++++++++++------------
 script.sh                      |  6 +++
 4 files changed, 85 insertions(+), 35 deletions(-)
 create mode 100755 script.sh

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 11384de..9a99a5a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -153,11 +153,27 @@ if (NOT MSVC)
     if (RWKV_GPROF)
         add_compile_options(-pg)
     endif()
+    if (RWKV_AVX2)
+        add_compile_options(-mavx2)
+        add_compile_definitions(__AVX2__)
+    endif()
+    if (RWKV_FMA)
+        add_compile_options(-mfma)
+    endif()
     if (RWKV_NATIVE)
         add_compile_options(-march=native)
     endif()
 endif()
 
+if (CMAKE_BUILD_TYPE STREQUAL "Release")
+    message(STATUS "Here we are in Release")
+    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+        add_compile_options(-O3)
+    elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+        add_compile_options(/O2)
+    endif()
+endif()
+
 #
 # Build libraries
 #

diff --git a/python/generate_completions.py b/python/generate_completions.py
index cc14ef1..3b568c7 100644
--- a/python/generate_completions.py
+++ b/python/generate_completions.py
@@ -17,7 +17,7 @@ Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference.
 This project is **focused on CPU**, but cuBLAS is also supported."""
 
 # How many completions to generate.
-generation_count: int = 3
+generation_count: int = 1
 
 # Token count per single completion.
 tokens_per_generation: int = 100
diff --git a/rwkv_operators_wkv_v7.inc b/rwkv_operators_wkv_v7.inc
index eb4541d..4803aa1 100644
--- a/rwkv_operators_wkv_v7.inc
+++ b/rwkv_operators_wkv_v7.inc
@@ -1,7 +1,7 @@
 // Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
 // Original code by Harrison Vanderbyl.
 // TODO Fix 1. unaligned memory access on Linux with AVX2, 2. tiny-rwkv with AVX-512
-/*#ifdef __AVX512F__
+#ifdef __AVX512F__
 #include <immintrin.h>
 #define SIMD_WIDTH 16
 #define LOAD(x) _mm512_load_ps(x)
@@ -9,7 +9,7 @@
 #define SET1(x) _mm512_set1_ps(x)
 #define MULTIPLY(x, y) _mm512_mul_ps(x, y)
 #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
-#elif __AVX2__
+#elif defined(__AVX2__)
 #include <immintrin.h>
 #define SIMD_WIDTH 8
 #define LOAD(x) _mm256_load_ps(x)
@@ -17,6 +17,7 @@
 #define SET1(x) _mm256_set1_ps(x)
 #define MULTIPLY(x, y) _mm256_mul_ps(x, y)
 #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
+#define ADD(x, y) _mm256_add_ps(x, y)
 #elif defined(__ARM_NEON) || defined(__ARM_NEON__)
 #include <arm_neon.h>
 #define SIMD_WIDTH 4
@@ -25,14 +26,32 @@
 #define SET1(x) vdupq_n_f32(x)
 #define MULTIPLY(x, y) vmulq_f32(x, y)
 #define MULTADD(x, y, z) vmlaq_f32(z, x, y)
-#else*/
+#else
 #define SIMD_WIDTH 1
 #define LOAD(x) *x
 #define STORE(x, y) *x = y
 #define SET1(x) x
 #define MULTIPLY(x, y) x * y
 #define MULTADD(x, y, z) x * y + z
-//#endif
+#endif
+
+
+inline float horizontal_sum_avx(__m256 vec) {
+    // Horizontal add: add the 8 floats pairwise, producing 4 partial sums
+    __m256 sum1 = _mm256_hadd_ps(vec, vec);
+
+    // Horizontal add again: add the 4 partial sums pairwise, producing 2
+    __m256 sum2 = _mm256_hadd_ps(sum1, sum1);
+
+    // Extract and add the low and high 128-bit halves
+    __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
+                               _mm256_extractf128_ps(sum2, 1));
+
+    // Extract the final result from the SSE register
+    float result;
+    _mm_store_ss(&result, sum128);
+    return result;
+}
 
 static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
     // const size_t T = result->ne[1];
@@ -41,7 +60,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
     const size_t H = result->src[1]->ne[1];
     const size_t T = result->src[1]->ne[2];
     GGML_ASSERT(C == S * H);
-    
+
     float * result_data = (float *) result->data;
 
     float * state_out = (float *) result->data + C * T;
@@ -62,41 +81,50 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
         size_t t_offset = t * t_stride;
         float * state_in = (t == 0) ? state : state_out;
-        
+
         // transpose_square_inplace(state_in, C/H);
         for (size_t h = ith; h < H; h += nth) {
             size_t h_offset = h * h_stride;
             size_t t_h_offset = t_offset + h_offset;
             size_t h_2d_offset = h * h_stride_2d;
 
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                auto v_val = v[t_h_i_offset];
-
-                float sa = 0;
-                for (size_t j = 0; j < C / H; j++) {
-                    sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
-                }
-
-                if (i == 0) {
-                    memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
-                }
-
-                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    auto r_val = r[t_h_j_offset];
-                    auto w_val = w[t_h_j_offset];
-                    auto k_val = k[t_h_j_offset];
-                    auto b_val = b[t_h_j_offset];
-                    auto kv_val = v_val * k_val;
-                    auto prev_state_val = state_in[h_2d_i_j_offset];
-                    state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
-                }
+            for (size_t i = 0; i < C / H; i ++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+                if (i == 0) {
+                    memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
+                }
+
+                float sa = .0;
+                for (size_t j = 0; j < C / H; j++) {
+                    sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
+                }
+                auto v_vec = SET1(v[t_h_i_offset]);
+                auto sa_vec = SET1(sa);
+
+                auto sum = _mm256_setzero_ps();
+                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+                    auto r_val = LOAD(&r[t_h_j_offset]);
+                    auto w_val = LOAD(&w[t_h_j_offset]);
+                    auto k_val = LOAD(&k[t_h_j_offset]);
+                    auto b_val = LOAD(&b[t_h_j_offset]);
+                    auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
+                    // auto kv_val = v_val * k_val;
+                    auto kv_val = MULTIPLY(v_vec, k_val);
+                    // state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
+                    auto sab_val = MULTIPLY(sa_vec, b_val);
+                    auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
+                    state_out_val = ADD(state_out_val, sab_val);
+                    STORE(&state_out[h_2d_i_j_offset], state_out_val);
+                    // result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
+                    auto result = MULTIPLY(state_out_val, r_val);
+                    // auto sum = LOAD(&result_data[t_h_i_offset]);
+                    sum = ADD(sum, result);
+                }
+                result_data[t_h_i_offset] = horizontal_sum_avx(sum);
             }
+
         }
     }
 
diff --git a/script.sh b/script.sh
new file mode 100755
index 0000000..6aeb10e
--- /dev/null
+++ b/script.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+rm -rf build
+mkdir build
+cd build
+cmake ..
+cmake --build . --config Release
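A note on the TODO at the top of rwkv_operators_wkv_v7.inc: with AVX2, LOAD and STORE map to _mm256_load_ps and _mm256_store_ps, which fault unless the pointer is 32-byte aligned, and the tensor buffers are not guaranteed to meet that. A minimal sketch of one possible fix (not the fix applied in this series, and sum_next_eight is a hypothetical example function): switch the AVX2 macros to the unaligned intrinsics, which accept any address and cost little or nothing extra on recent x86 when the address happens to be aligned.

#include <immintrin.h>

// Hypothetical unaligned variants of the series' LOAD/STORE macros.
#define LOAD(x) _mm256_loadu_ps(x)
#define STORE(x, y) _mm256_storeu_ps(x, y)

// Example: sums 8 floats starting at p + 1, an address that is
// typically not 32-byte aligned even when p itself is.
float sum_next_eight(const float * p) {
    __m256 v = LOAD(p + 1);
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(v),    // low 128 bits
                          _mm256_extractf128_ps(v, 1)); // high 128 bits
    s = _mm_hadd_ps(s, s); // 4 partial sums -> 2
    s = _mm_hadd_ps(s, s); // 2 partial sums -> 1, in lane 0
    return _mm_cvtss_f32(s);
}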
From 493682dda134213068e53db93bd9eab1adc3f431 Mon Sep 17 00:00:00 2001
From: lshAlgorithm
Date: Wed, 14 May 2025 01:07:18 +0800
Subject: [PATCH 2/4] change format

Signed-off-by: lshAlgorithm
---
 rwkv_operators_wkv_v7.inc | 33 +++++++++++++++++++++------------
 1 file changed, 21 insertions(+), 12 deletions(-)

diff --git a/rwkv_operators_wkv_v7.inc b/rwkv_operators_wkv_v7.inc
index 4803aa1..55dd14e 100644
--- a/rwkv_operators_wkv_v7.inc
+++ b/rwkv_operators_wkv_v7.inc
@@ -9,6 +9,8 @@
 #define SET1(x) _mm512_set1_ps(x)
 #define MULTIPLY(x, y) _mm512_mul_ps(x, y)
 #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
+#define ADD(x, y) _mm512_add_ps(x, y)
+#define ZEROS() _mm512_setzero_ps()
 #elif defined(__AVX2__)
 #include <immintrin.h>
 #define SIMD_WIDTH 8
@@ -18,6 +20,7 @@
 #define MULTIPLY(x, y) _mm256_mul_ps(x, y)
 #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
 #define ADD(x, y) _mm256_add_ps(x, y)
+#define ZEROS() _mm256_setzero_ps()
 #elif defined(__ARM_NEON) || defined(__ARM_NEON__)
 #include <arm_neon.h>
 #define SIMD_WIDTH 4
@@ -36,18 +39,11 @@
 #endif
 
 
-inline float horizontal_sum_avx(__m256 vec) {
-    // Horizontal add: add the 8 floats pairwise, producing 4 partial sums
+inline float horizontal_sum(__m256 vec) {
     __m256 sum1 = _mm256_hadd_ps(vec, vec);
-
-    // Horizontal add again: add the 4 partial sums pairwise, producing 2
     __m256 sum2 = _mm256_hadd_ps(sum1, sum1);
-
-    // Extract and add the low and high 128-bit halves
     __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
                                _mm256_extractf128_ps(sum2, 1));
-
-    // Extract the final result from the SSE register
     float result;
     _mm_store_ss(&result, sum128);
     return result;
@@ -81,7 +77,6 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
         size_t t_offset = t * t_stride;
         float * state_in = (t == 0) ? state : state_out;
 
-        // transpose_square_inplace(state_in, C/H);
         for (size_t h = ith; h < H; h += nth) {
             size_t h_offset = h * h_stride;
             size_t t_h_offset = t_offset + h_offset;
@@ -94,14 +89,24 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
                     memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
                 }
 
+                // auto sa_vec = ZEROS();
+                // for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+                //     sa_vec = ADD(sa_vec, MULTIPLY(
+                //                  LOAD(&a[t_h_offset + j]),
+                //                  LOAD(&state_in[h_2d_i_offset + j])
+                //              )
+                //     );
+                // }
+                // float sa = horizontal_sum(sa_vec);
                 float sa = .0;
                 for (size_t j = 0; j < C / H; j++) {
                     sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
                 }
+
                 auto v_vec = SET1(v[t_h_i_offset]);
-                auto sa_vec = SET1(sa);
+                sa_vec = SET1(sa);
 
-                auto sum = _mm256_setzero_ps();
+                auto sum = ZEROS();
                 for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
                     size_t t_h_j_offset = t_h_offset + j;
                     size_t h_2d_i_j_offset = h_2d_i_offset + j;
@@ -110,19 +115,23 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
                     auto k_val = LOAD(&k[t_h_j_offset]);
                     auto b_val = LOAD(&b[t_h_j_offset]);
                     auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
+                    // auto kv_val = v_val * k_val;
                     auto kv_val = MULTIPLY(v_vec, k_val);
+                    // state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
                     auto sab_val = MULTIPLY(sa_vec, b_val);
                     auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
                     state_out_val = ADD(state_out_val, sab_val);
                     STORE(&state_out[h_2d_i_j_offset], state_out_val);
+                    // result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
                     auto result = MULTIPLY(state_out_val, r_val);
+                    // auto sum = LOAD(&result_data[t_h_i_offset]);
                     sum = ADD(sum, result);
                 }
-                result_data[t_h_i_offset] = horizontal_sum_avx(sum);
+                result_data[t_h_i_offset] = horizontal_sum(sum);
             }
 
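Before PATCH 3/4 reuses horizontal_sum for the sa dot product, it is worth sanity-checking the hadd-based reduction in isolation. A minimal standalone harness, hypothetical and not part of the series; compile with -mavx2:

#include <immintrin.h>
#include <stdio.h>

// Same reduction as horizontal_sum() in the patch above.
static float horizontal_sum(__m256 vec) {
    __m256 sum1 = _mm256_hadd_ps(vec, vec);
    __m256 sum2 = _mm256_hadd_ps(sum1, sum1);
    __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
                               _mm256_extractf128_ps(sum2, 1));
    float result;
    _mm_store_ss(&result, sum128);
    return result;
}

int main(void) {
    float data[8] = {1, 2, 3, 4, 5, 6, 7, 8};
    float expected = 0.0f;
    for (int i = 0; i < 8; i++) {
        expected += data[i];
    }
    // hadd pairs within each 128-bit lane: [3 7 3 7 | 11 15 11 15],
    // then [10 10 10 10 | 26 26 26 26]; low half + high half = 36 everywhere.
    float got = horizontal_sum(_mm256_loadu_ps(data));
    printf("expected %.1f, got %.1f\n", expected, got); // both print 36.0
    return got == expected ? 0 : 1;
}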
Mon Sep 17 00:00:00 2001
From: lshAlgorithm
Date: Wed, 14 May 2025 01:18:28 +0800
Subject: [PATCH 3/4] vectorization on sum of sa

Signed-off-by: lshAlgorithm
---
 rwkv_operators_wkv_v7.inc | 20 ++++++++------------
 1 file changed, 8 insertions(+), 12 deletions(-)

diff --git a/rwkv_operators_wkv_v7.inc b/rwkv_operators_wkv_v7.inc
index 55dd14e..abf17e9 100644
--- a/rwkv_operators_wkv_v7.inc
+++ b/rwkv_operators_wkv_v7.inc
@@ -89,19 +89,15 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
                     memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
                 }
 
-                // auto sa_vec = ZEROS();
-                // for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
-                //     sa_vec = ADD(sa_vec, MULTIPLY(
-                //                  LOAD(&a[t_h_offset + j]),
-                //                  LOAD(&state_in[h_2d_i_offset + j])
-                //              )
-                //     );
-                // }
-                // float sa = horizontal_sum(sa_vec);
-                float sa = .0;
-                for (size_t j = 0; j < C / H; j++) {
-                    sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
+                auto sa_vec = ZEROS();
+                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+                    sa_vec = ADD(sa_vec, MULTIPLY(
+                                 LOAD(&a[t_h_offset + j]),
+                                 LOAD(&state_in[h_2d_i_offset + j])
+                             )
+                    );
                 }
+                float sa = horizontal_sum(sa_vec);
 
                 auto v_vec = SET1(v[t_h_i_offset]);
                 sa_vec = SET1(sa);

From 97a51fc8d62af3f4573426761a7fee12fba93462 Mon Sep 17 00:00:00 2001
From: lshAlgorithm
Date: Wed, 14 May 2025 14:07:18 +0800
Subject: [PATCH 4/4] add comments

Signed-off-by: lshAlgorithm
---
 rwkv_operators_wkv_v7.inc | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/rwkv_operators_wkv_v7.inc b/rwkv_operators_wkv_v7.inc
index abf17e9..51e249a 100644
--- a/rwkv_operators_wkv_v7.inc
+++ b/rwkv_operators_wkv_v7.inc
@@ -38,7 +38,7 @@
 #define MULTADD(x, y, z) x * y + z
 #endif
 
-
+// TODO: This is ONLY for AVX2 (__m256); it should be folded into the macro blocks above in a cleaner way.
 inline float horizontal_sum(__m256 vec) {
     __m256 sum1 = _mm256_hadd_ps(vec, vec);
     __m256 sum2 = _mm256_hadd_ps(sum1, sum1);
@@ -102,7 +102,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
                 auto v_vec = SET1(v[t_h_i_offset]);
                 sa_vec = SET1(sa);
 
-                auto sum = ZEROS();
+                auto sum = ZEROS(); // Initialize the sum vector
                 for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
                     size_t t_h_j_offset = t_h_offset + j;
                     size_t h_2d_i_j_offset = h_2d_i_offset + j;
@@ -127,6 +127,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
                     // auto sum = LOAD(&result_data[t_h_i_offset]);
                     sum = ADD(sum, result);
                 }
+                // Reduce all elements of the vector into one float.
                 result_data[t_h_i_offset] = horizontal_sum(sum);
             }
 
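The TODO added in PATCH 4/4 can be addressed by giving each ISA branch its own reduction macro next to LOAD/STORE/ADD, so the kernel never names an AVX2 type directly. A sketch under that assumption: REDUCE and rwkv_reduce_avx2 are hypothetical names, _mm512_reduce_add_ps is the compiler-provided AVX-512 reduction, and vaddvq_f32 assumes AArch64 rather than 32-bit NEON.

#ifdef __AVX512F__
#include <immintrin.h>
#define REDUCE(x) _mm512_reduce_add_ps(x)   /* compiler-provided reduction */
#elif defined(__AVX2__)
#include <immintrin.h>
static inline float rwkv_reduce_avx2(__m256 v) {
    /* Add the low and high 128-bit halves, then fold 4 lanes down to 1. */
    __m128 s = _mm_add_ps(_mm256_castps256_ps128(v),
                          _mm256_extractf128_ps(v, 1));
    s = _mm_hadd_ps(s, s);
    s = _mm_hadd_ps(s, s);
    return _mm_cvtss_f32(s);
}
#define REDUCE(x) rwkv_reduce_avx2(x)
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#define REDUCE(x) vaddvq_f32(x)             /* AArch64 horizontal add */
#else
#define REDUCE(x) (x)                       /* scalar path: already a float */
#endif

/* The kernel then stays width-agnostic:
 *     result_data[t_h_i_offset] = REDUCE(sum);
 */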