diff --git a/CMakeLists.txt b/CMakeLists.txt
index 11384de..9a99a5a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -153,11 +153,26 @@ if (NOT MSVC)
     if (RWKV_GPROF)
         add_compile_options(-pg)
     endif()
+    if (RWKV_AVX2)
+        add_compile_options(-mavx2)
+    endif()
+    if (RWKV_FMA)
+        add_compile_options(-mfma)
+    endif()
     if (RWKV_NATIVE)
         add_compile_options(-march=native)
     endif()
 endif()
 
+if (CMAKE_BUILD_TYPE STREQUAL "Release")
+    message(STATUS "Release build: enabling compiler optimizations")
+    if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
+        add_compile_options(-O3)
+    elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
+        add_compile_options(/O2)
+    endif()
+endif()
+
 #
 # Build libraries
 #
diff --git a/python/generate_completions.py b/python/generate_completions.py
index cc14ef1..3b568c7 100644
--- a/python/generate_completions.py
+++ b/python/generate_completions.py
@@ -17,7 +17,7 @@ Besides the usual **FP32**, it supports **FP16**,
 **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported."""
 
 # How many completions to generate.
-generation_count: int = 3
+generation_count: int = 1
 
 # Token count per single completion.
 tokens_per_generation: int = 100
diff --git a/rwkv_operators_wkv_v7.inc b/rwkv_operators_wkv_v7.inc
index eb4541d..51e249a 100644
--- a/rwkv_operators_wkv_v7.inc
+++ b/rwkv_operators_wkv_v7.inc
@@ -1,7 +1,7 @@
 // Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
 // Original code by Harrison Vanderbyl.
 // TODO Fix 1. unaligned memory access on Linux with AVX2, 2. tiny-rwkv with AVX-512
-/*#ifdef __AVX512F__
+#ifdef __AVX512F__
 #include <immintrin.h>
 #define SIMD_WIDTH 16
 #define LOAD(x) _mm512_load_ps(x)
@@ -9,7 +9,10 @@
 #define SET1(x) _mm512_set1_ps(x)
 #define MULTIPLY(x, y) _mm512_mul_ps(x, y)
 #define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
-#elif __AVX2__
+ #define ADD(x, y) _mm512_add_ps(x, y)
+ #define ZEROS() _mm512_setzero_ps()
+ #define HSUM(x) _mm512_reduce_add_ps(x)
+#elif defined(__AVX2__)
 #include <immintrin.h>
 #define SIMD_WIDTH 8
 #define LOAD(x) _mm256_load_ps(x)
@@ -17,6 +20,8 @@
 #define SET1(x) _mm256_set1_ps(x)
 #define MULTIPLY(x, y) _mm256_mul_ps(x, y)
 #define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
+ #define ADD(x, y) _mm256_add_ps(x, y)
+ #define ZEROS() _mm256_setzero_ps()
 #elif defined(__ARM_NEON) || defined(__ARM_NEON__)
 #include <arm_neon.h>
 #define SIMD_WIDTH 4
@@ -25,14 +30,43 @@
 #define SET1(x) vdupq_n_f32(x)
 #define MULTIPLY(x, y) vmulq_f32(x, y)
 #define MULTADD(x, y, z) vmlaq_f32(z, x, y)
-#else*/
+ #define ADD(x, y) vaddq_f32(x, y)
+ #define ZEROS() vdupq_n_f32(0.0f)
+ static inline float rwkv_hsum_neon(float32x4_t v) {
+     float32x2_t t = vadd_f32(vget_low_f32(v), vget_high_f32(v));
+     t = vpadd_f32(t, t);
+     return vget_lane_f32(t, 0);
+ }
+ #define HSUM(x) rwkv_hsum_neon(x)
+#else
 #define SIMD_WIDTH 1
 #define LOAD(x) *x
 #define STORE(x, y) *x = y
 #define SET1(x) x
 #define MULTIPLY(x, y) x * y
 #define MULTADD(x, y, z) x * y + z
-//#endif
+ #define ADD(x, y) x + y
+ #define ZEROS() 0.0f
+ #define HSUM(x) x
+#endif
+
+// Horizontal sum: reduce one SIMD register to a single float. The AVX-512,
+// NEON and scalar paths define HSUM in their branches above; AVX2 uses the helper below.
+#if defined(__AVX2__) && !defined(__AVX512F__)
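+// _mm256_hadd_ps adds adjacent pairs within each 128-bit lane, so two rounds
+// leave each lane's total in every slot of that lane; adding the low and high
+// 128-bit halves then yields the sum of all 8 floats.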
+static inline float horizontal_sum(__m256 vec) {
+    __m256 sum1 = _mm256_hadd_ps(vec, vec);
+    __m256 sum2 = _mm256_hadd_ps(sum1, sum1);
+    __m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
+                               _mm256_extractf128_ps(sum2, 1));
+    float result;
+    _mm_store_ss(&result, sum128);
+    return result;
+}
+#define HSUM(x) horizontal_sum(x)
+#endif
 
 static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
     // const size_t T = result->ne[1];
@@ -41,7 +75,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
     const size_t H = result->src[1]->ne[1];
     const size_t T = result->src[1]->ne[2];
     GGML_ASSERT(C == S * H);
-    
+
     float * result_data = (float *) result->data;
     float * state_out = (float *) result->data + C * T;
 
@@ -62,40 +96,54 @@
         size_t t_offset = t * t_stride;
 
         float * state_in = (t == 0) ? state : state_out;
-
         for (size_t h = ith; h < H; h += nth) {
             size_t h_offset = h * h_stride;
             size_t t_h_offset = t_offset + h_offset;
             size_t h_2d_offset = h * h_stride_2d;
-
-            for (size_t i = 0; i < C / H; i++) {
-                size_t t_h_i_offset = t_h_offset + i;
-                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
-
-                auto v_val = v[t_h_i_offset];
-
-                float sa = 0;
-                for (size_t j = 0; j < C / H; j++) {
-                    sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
-                }
-
-                if (i == 0) {
-                    memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
-                }
-
-                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
-                    size_t t_h_j_offset = t_h_offset + j;
-                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
-
-                    auto r_val = r[t_h_j_offset];
-                    auto w_val = w[t_h_j_offset];
-                    auto k_val = k[t_h_j_offset];
-                    auto b_val = b[t_h_j_offset];
-                    auto kv_val = v_val * k_val;
-                    auto prev_state_val = state_in[h_2d_i_j_offset];
-                    state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
-                    result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
-                }
-        }
+            for (size_t i = 0; i < C / H; i++) {
+                size_t t_h_i_offset = t_h_offset + i;
+                size_t h_2d_i_offset = h_2d_offset + i * h_stride;
+
+                // sa = dot(a_t, row i of the state), accumulated in SIMD lanes.
+                auto sa_vec = ZEROS();
+                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+                    sa_vec = ADD(sa_vec, MULTIPLY(
+                        LOAD(&a[t_h_offset + j]),
+                        LOAD(&state_in[h_2d_i_offset + j])
+                    ));
+                }
+                float sa = HSUM(sa_vec);
+
+                auto v_vec = SET1(v[t_h_i_offset]);
+                sa_vec = SET1(sa);
+
+                // Accumulator for result_data[t_h_i_offset]; makes the scalar
+                // version's memset-then-add unnecessary.
+                auto sum = ZEROS();
+                for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
+                    size_t t_h_j_offset = t_h_offset + j;
+                    size_t h_2d_i_j_offset = h_2d_i_offset + j;
+
+                    auto r_val = LOAD(&r[t_h_j_offset]);
+                    auto w_val = LOAD(&w[t_h_j_offset]);
+                    auto k_val = LOAD(&k[t_h_j_offset]);
+                    auto b_val = LOAD(&b[t_h_j_offset]);
+                    auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);
+
+                    // kv_val = v_val * k_val
+                    auto kv_val = MULTIPLY(v_vec, k_val);
+
+                    // state_out = prev_state * w + kv + sa * b
+                    auto sab_val = MULTIPLY(sa_vec, b_val);
+                    auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
+                    state_out_val = ADD(state_out_val, sab_val);
+                    STORE(&state_out[h_2d_i_j_offset], state_out_val);
+
+                    sum = ADD(sum, MULTIPLY(state_out_val, r_val));
+                }
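+                // sum now holds SIMD_WIDTH partial sums of state_out * r;
+                // HSUM collapses them into the scalar the old code built with +=.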
+                result_data[t_h_i_offset] = HSUM(sum);
             }
         }
     }
diff --git a/script.sh b/script.sh
new file mode 100755
index 0000000..6aeb10e
--- /dev/null
+++ b/script.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+set -e
+rm -rf build
+mkdir build
+cd build
+cmake .. -DCMAKE_BUILD_TYPE=Release
+cmake --build . --config Release
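
Note: the options added in CMakeLists.txt are opt-in, so script.sh as written builds without AVX2/FMA. Assuming an x86-64 host, a configure line along these lines would exercise the new SIMD path (option names as defined in the CMakeLists.txt hunk above; -DCMAKE_BUILD_TYPE=Release is what actually triggers the new optimization block on single-config generators):

    cmake .. -DCMAKE_BUILD_TYPE=Release -DRWKV_AVX2=ON -DRWKV_FMA=ON
    cmake --build . --config Release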