16 changes: 16 additions & 0 deletions CMakeLists.txt
@@ -153,11 +153,27 @@ if (NOT MSVC)
if (RWKV_GPROF)
add_compile_options(-pg)
endif()
if (RWKV_AVX2)
add_compile_options(-mavx2)
add_compile_definitions(__AVX2__)
endif()
if (RWKV_FMA)
add_compile_options(-mfma)
endif()
if (RWKV_NATIVE)
add_compile_options(-march=native)
endif()
endif()

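# Note: CMake's stock Release flags (CMAKE_CXX_FLAGS_RELEASE) already include
# -O3 for GCC/Clang and /O2 for MSVC, so the block below mainly makes the
# optimization level explicit.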
if (CMAKE_BUILD_TYPE STREQUAL "Release")
message(STATUS "Release build: enabling compiler optimizations")
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
add_compile_options(-O3)
elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC")
add_compile_options(/O2)
endif()
endif()

#
# Build libraries
#
2 changes: 1 addition & 1 deletion python/generate_completions.py
@@ -17,7 +17,7 @@
Besides the usual **FP32**, it supports **FP16**, **quantized INT4, INT5 and INT8** inference. This project is **focused on CPU**, but cuBLAS is also supported."""

# How many completions to generate.
generation_count: int = 3
generation_count: int = 1
# Token count per single completion.
tokens_per_generation: int = 100

102 changes: 68 additions & 34 deletions rwkv_operators_wkv_v7.inc
@@ -1,22 +1,26 @@
// Ported from https://github.com/harrisonvanderbyl/RNN-Factory/blob/3b696b547cc9e25de04a077602c3fe1133d8984c/src/models/modules/cuda/cpuonly.cpp#L8
// Original code by Harrison Vanderbyl.
// TODO: Fix (1) unaligned memory access on Linux with AVX2, (2) tiny-rwkv with AVX-512.
/*#ifdef __AVX512F__
#ifdef __AVX512F__
#include <immintrin.h>
#define SIMD_WIDTH 16
#define LOAD(x) _mm512_load_ps(x)
#define STORE(x, y) _mm512_store_ps(x, y)
#define SET1(x) _mm512_set1_ps(x)
#define MULTIPLY(x, y) _mm512_mul_ps(x, y)
#define MULTADD(x, y, z) _mm512_fmadd_ps(x, y, z)
#elif __AVX2__
#define ADD(x, y) _mm512_add_ps(x, y)
#define ZEROS() _mm512_setzero_ps()
#elif defined(__AVX2__)
#include <immintrin.h>
#define SIMD_WIDTH 8
#define LOAD(x) _mm256_load_ps(x)
#define STORE(x, y) _mm256_store_ps(x, y)
#define SET1(x) _mm256_set1_ps(x)
#define MULTIPLY(x, y) _mm256_mul_ps(x, y)
#define MULTADD(x, y, z) _mm256_fmadd_ps(x, y, z)
#define ADD(x, y) _mm256_add_ps(x, y)
#define ZEROS() _mm256_setzero_ps()
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
#include <arm_neon.h>
#define SIMD_WIDTH 4
@@ -25,14 +29,25 @@
#define SET1(x) vdupq_n_f32(x)
#define MULTIPLY(x, y) vmulq_f32(x, y)
#define MULTADD(x, y, z) vmlaq_f32(z, x, y)
#define ADD(x, y) vaddq_f32(x, y)
#define ZEROS() vdupq_n_f32(0.0f)
#else*/
#else
#define SIMD_WIDTH 1
#define LOAD(x) *x
#define STORE(x, y) *x = y
#define SET1(x) x
#define MULTIPLY(x, y) x * y
#define MULTADD(x, y, z) x * y + z
#define ADD(x, y) x + y
#define ZEROS() 0.0f
//#endif
#endif

// Horizontal sum of one SIMD vector into a scalar, guarded per SIMD path
// so that non-AVX2 builds still compile. TODO: AVX-512 (see sketch below).
#if defined(__AVX2__)
inline float horizontal_sum(__m256 vec) {
__m256 sum1 = _mm256_hadd_ps(vec, vec);
__m256 sum2 = _mm256_hadd_ps(sum1, sum1);
__m128 sum128 = _mm_add_ps(_mm256_extractf128_ps(sum2, 0),
_mm256_extractf128_ps(sum2, 1));
float result;
_mm_store_ss(&result, sum128);
return result;
}
#elif defined(__ARM_NEON) || defined(__ARM_NEON__)
inline float horizontal_sum(float32x4_t vec) {
float32x2_t sum2 = vadd_f32(vget_low_f32(vec), vget_high_f32(vec));
return vget_lane_f32(vpadd_f32(sum2, sum2), 0);
}
#else
inline float horizontal_sum(float vec) { return vec; }
#endif
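// Sketch (an assumption, not part of this change): a matching AVX-512
// overload. _mm512_reduce_add_ps is a helper provided by GCC, Clang and
// MSVC; verify it against the toolchains this project targets.
#if defined(__AVX512F__)
inline float horizontal_sum(__m512 vec) {
return _mm512_reduce_add_ps(vec);
}
#endif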

static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tensor * src, int ith, int nth, void * userdata) {
// const size_t T = result->ne[1];
@@ -41,7 +56,7 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
const size_t H = result->src[1]->ne[1];
const size_t T = result->src[1]->ne[2];
GGML_ASSERT(C == S * H);
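// Sketch of an extra guard (an assumption, not part of this change): the
// vectorized j-loops below step by SIMD_WIDTH with no scalar tail, so the
// head size C / H must be a multiple of SIMD_WIDTH.
GGML_ASSERT((C / H) % SIMD_WIDTH == 0);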

float * result_data = (float *) result->data;
float * state_out = (float *) result->data + C * T;

@@ -62,41 +77,60 @@ static void rwkv_wkv_v7_impl(struct ggml_tensor * result, const struct ggml_tens
size_t t_offset = t * t_stride;

float * state_in = (t == 0) ? state : state_out;

for (size_t h = ith; h < H; h += nth) {
size_t h_offset = h * h_stride;
size_t t_h_offset = t_offset + h_offset;
size_t h_2d_offset = h * h_stride_2d;

for (size_t i = 0; i < C / H; i++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;

auto v_val = v[t_h_i_offset];

float sa = 0;
for (size_t j = 0; j < C / H; j++) {
sa += a[t_h_offset + j] * state_in[h_2d_i_offset + j];
}

if (i == 0) {
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
}

for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;

auto r_val = r[t_h_j_offset];
auto w_val = w[t_h_j_offset];
auto k_val = k[t_h_j_offset];
auto b_val = b[t_h_j_offset];
auto kv_val = v_val * k_val;
auto prev_state_val = state_in[h_2d_i_j_offset];
state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
}
for (size_t i = 0; i < C / H; i ++) {
size_t t_h_i_offset = t_h_offset + i;
size_t h_2d_i_offset = h_2d_offset + i * h_stride;
if (i == 0) {
memset(&result_data[t_h_offset], 0, h_stride * sizeof(float));
}

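// First pass: SIMD dot product of a_t with row i of the incoming state,
// reduced below to the scalar sa.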
auto sa_vec = ZEROS();
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
sa_vec = ADD(sa_vec, MULTIPLY(
LOAD(&a[t_h_offset + j]),
LOAD(&state_in[h_2d_i_offset + j])
)
);
}
float sa = horizontal_sum(sa_vec);

auto v_vec = SET1(v[t_h_i_offset]);
sa_vec = SET1(sa);

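// Second pass: update row i of the state and accumulate the output
// dot product with r_t entirely in SIMD registers.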
auto sum = ZEROS(); // Initialize the sum vector
for (size_t j = 0; j < C / H; j += SIMD_WIDTH) {
size_t t_h_j_offset = t_h_offset + j;
size_t h_2d_i_j_offset = h_2d_i_offset + j;
auto r_val = LOAD(&r[t_h_j_offset]);
auto w_val = LOAD(&w[t_h_j_offset]);
auto k_val = LOAD(&k[t_h_j_offset]);
auto b_val = LOAD(&b[t_h_j_offset]);
auto prev_state_val = LOAD(&state_in[h_2d_i_j_offset]);

// auto kv_val = v_val * k_val;
auto kv_val = MULTIPLY(v_vec, k_val);

// state_out[h_2d_i_j_offset] = prev_state_val * w_val + kv_val + sa * b_val;
auto sab_val = MULTIPLY(sa_vec, b_val);
auto state_out_val = MULTADD(prev_state_val, w_val, kv_val);
state_out_val = ADD(state_out_val, sab_val);
STORE(&state_out[h_2d_i_j_offset], state_out_val);

// result_data[t_h_i_offset] += state_out[h_2d_i_j_offset] * r_val;
auto result = MULTIPLY(state_out_val, r_val);

// auto sum = LOAD(&result_data[t_h_i_offset]);
sum = ADD(sum, result);
}
// Reduce all elements of the vector into a single scalar.
result_data[t_h_i_offset] = horizontal_sum(sum);
}

}
}

6 changes: 6 additions & 0 deletions script.sh
@@ -0,0 +1,6 @@
#!/bin/bash
set -e
rm -rf build
mkdir build
cd build
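# The options added in this PR can be passed to the configure step, e.g.:
#   cmake -DRWKV_AVX2=ON -DRWKV_FMA=ON -DRWKV_NATIVE=ON ..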
cmake ..
cmake --build . --config Release