[Hardware][Intel] fp8 kv cache support for CPU #5492

Draft · wants to merge 1 commit into base: main
cmake/cpu_extension.cmake (1 addition, 1 deletion)

@@ -57,7 +57,7 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"-mavx512dq")

find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
-    if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
+    if (AVX512BF16_FOUND AND ENABLE_AVX512BF16)
Collaborator comment: I think we want to compile for AVX512BF16 if we simply find it, not only if we both find it and force it.

if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
csrc/cpu/attention.cpp (182 additions, 108 deletions)

Large diffs are not rendered by default.

csrc/cpu/cache.cpp (68 additions, 25 deletions)

@@ -32,13 +32,33 @@ void copy_blocks_cpu_impl(std::vector<torch::Tensor> const& key_caches,
}
}

-template <typename scalar_t>
-void reshape_and_cache_cpu_impl(
-    const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
-    scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
-    const int64_t* __restrict__ slot_mapping, const int num_tokens,
-    const int key_stride, const int value_stride, const int num_heads,
-    const int head_size, const int block_size, const int x) {
+template <typename scalar_t, typename cache_t = scalar_t>
+cache_t assign_cache_value(const scalar_t* src) {
+  return *src;
+}
+
+template <>
+uint8_t assign_cache_value<float, uint8_t>(const float* src) {
+  uint8_t res = cast_fp32x1_to_fp8x1(*src);
+  return res;
+}
+
+template <>
+uint8_t assign_cache_value<int16_t, uint8_t>(const int16_t* src) {
+  uint8_t res = cast_bf16x1_to_fp8x1(*src);
+  return res;
+}
+
+template <typename scalar_t, typename cache_t = scalar_t, bool use_fp8 = false>
+void reshape_and_cache_cpu_impl(const scalar_t* __restrict__ key,
+                                const scalar_t* __restrict__ value,
+                                cache_t* __restrict__ key_cache,
+                                cache_t* __restrict__ value_cache,
+                                const int64_t* __restrict__ slot_mapping,
+                                const int num_tokens, const int key_stride,
+                                const int value_stride, const int num_heads,
+                                const int head_size, const int block_size,
+                                const int kv_cache_stride, const int x) {
const int block_elem_num = num_heads * head_size * block_size;

#pragma omp parallel for collapse(2)
@@ -53,19 +73,20 @@ void reshape_and_cache_cpu_impl(
const scalar_t* src_value_head_ptr = value + src_value_head_idx;
const int64_t block_index = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
-      scalar_t* target_key_head_ptr = key_cache +
-                                      block_elem_num * block_index +
-                                      head_idx * block_size * head_size;
-      scalar_t* target_value_head_ptr = value_cache +
-                                        block_elem_num * block_index +
-                                        head_idx * block_size * head_size;
+      cache_t* target_key_head_ptr = key_cache +
+                                     kv_cache_stride * block_index +
+                                     head_idx * block_size * head_size;
+      cache_t* target_value_head_ptr = value_cache +
+                                       kv_cache_stride * block_index +
+                                       head_idx * block_size * head_size;

for (int src_key_idx = 0; src_key_idx < head_size; src_key_idx += x) {
const int64_t target_offset =
src_key_idx * block_size + block_offset * x;
for (int i = 0; i < x; ++i) {
          target_key_head_ptr[target_offset + i] =
-             src_key_head_ptr[src_key_idx + i];
+             assign_cache_value<scalar_t, cache_t>(src_key_head_ptr +
+                                                   src_key_idx + i);
}
}

@@ -74,7 +95,8 @@
const int64_t target_offset =
src_value_idx * block_size + block_offset;
target_value_head_ptr[target_offset] =
-            src_value_head_ptr[src_value_idx];
+            assign_cache_value<scalar_t, cache_t>(src_value_head_ptr +
+                                                  src_value_idx);
}
}
}
@@ -104,6 +126,17 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
});
}

+#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_KV_CACHE)                \
+  CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)                             \
+  reshape_and_cache_cpu_impl<KV_T, CACHE_T, IS_FP8_KV_CACHE>(                 \
+      reinterpret_cast<KV_T*>(key.data_ptr()),                                \
+      reinterpret_cast<KV_T*>(value.data_ptr()),                              \
+      reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                       \
+      reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                     \
+      slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride, \
+      num_heads, head_size, block_size, kv_cache_stride, x);                  \
+  CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
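
As a reading aid, one concrete instantiation of this macro, the bf16 input / fp8 cache case selected in reshape_and_cache below, expands to roughly the following (guard macros elided):

// Expansion of CALL_RESHAPE_AND_CACHE(int16_t, uint8_t, true): bf16 keys and
// values travel as raw int16_t bit patterns and are stored into a uint8_t
// (fp8) cache. Shown only to illustrate how the macro is used.
reshape_and_cache_cpu_impl<int16_t, uint8_t, true>(
    reinterpret_cast<int16_t*>(key.data_ptr()),
    reinterpret_cast<int16_t*>(value.data_ptr()),
    reinterpret_cast<uint8_t*>(key_cache.data_ptr()),
    reinterpret_cast<uint8_t*>(value_cache.data_ptr()),
    slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride, value_stride,
    num_heads, head_size, block_size, kv_cache_stride, x);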

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
@@ -115,20 +148,30 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
int head_size = key.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
+  int kv_cache_stride = key_cache.stride(0);

int key_stride = key.stride(0);
int value_stride = value.stride(0);

-  VLLM_DISPATCH_FLOATING_TYPES(
-      key.scalar_type(), "reshape_and_cache_cpu_impl", [&] {
-        CPU_KERNEL_GUARD_IN(reshape_and_cache_cpu_impl)
-        reshape_and_cache_cpu_impl<scalar_t>(
-            key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
-            key_cache.data_ptr<scalar_t>(), value_cache.data_ptr<scalar_t>(),
-            slot_mapping.data_ptr<int64_t>(), num_tokens, key_stride,
-            value_stride, num_heads, head_size, block_size, x);
-        CPU_KERNEL_GUARD_OUT(reshape_and_cache_cpu_impl)
-      });
+  if (kv_cache_dtype == "auto") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, float, false);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      TORCH_CHECK(false, "Unsupported data type: Half");
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(int16_t, int16_t, false);
+    }
+  } else if (kv_cache_dtype == "fp8") {
+    if (key.dtype() == at::ScalarType::Float) {
+      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
+    } else if (key.dtype() == at::ScalarType::Half) {
+      TORCH_CHECK(false, "Unsupported data type: Half");
+    } else if (key.dtype() == at::ScalarType::BFloat16) {
+      CALL_RESHAPE_AND_CACHE(int16_t, uint8_t, true);
+    }
+  } else {
+    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
+  }
}

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
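The scalar converters cast_fp32x1_to_fp8x1 and cast_bf16x1_to_fp8x1 called by the assign_cache_value specializations come from the new fp8_utils.h header, which is not rendered in this diff. As a rough, hypothetical illustration only, assuming an E5M2-style fp8 layout (the PR's actual helpers and chosen fp8 format may differ), such conversions could look like:

#include <cstdint>
#include <cstring>
#include <immintrin.h>

// Hypothetical sketch, not the PR's implementation: fp32 -> fp8 (E5M2).
// E5M2 shares the fp16 exponent layout, so convert to fp16 first (F16C),
// then keep the top byte: sign(1) + exponent(5) + mantissa(2).
inline uint8_t fp32_to_fp8_e5m2_sketch(float src) {
  uint16_t h = _cvtss_sh(src, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  return static_cast<uint8_t>(h >> 8);
}

// Hypothetical sketch: bf16, passed as raw int16_t bits as in the dispatch
// above, -> fp8. bf16 is the upper half of an fp32 bit pattern, so widen to
// fp32 and reuse the fp32 path.
inline uint8_t bf16_to_fp8_e5m2_sketch(int16_t src_bits) {
  uint32_t bits = static_cast<uint32_t>(static_cast<uint16_t>(src_bits)) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return fp32_to_fp8_e5m2_sketch(f);
}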
csrc/cpu/cpu_types.hpp (0 additions, 1 deletion)

@@ -1,4 +1,3 @@
-
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP

csrc/cpu/cpu_types_x86.hpp (19 additions, 0 deletions)

@@ -5,6 +5,10 @@
#include <immintrin.h>
#include <torch/all.h>

+#include "fp8_utils.h"
+
+typedef uint8_t cpu_fp8;

#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif
@@ -50,6 +54,19 @@ template <typename T> struct Vec {
struct FP32Vec8;
struct FP32Vec16;

+struct FP8Vec16 : public Vec<FP8Vec16> {
+  constexpr static int VEC_ELEM_NUM = 16;
+  union AliasReg {
+    __m128 reg;
+    cpu_fp8 values[VEC_ELEM_NUM];
+  };
+  __m128 reg;
+
+  explicit FP8Vec16() : reg(_mm_set1_ps(0)) {}
+  explicit FP8Vec16(const cpu_fp8 *ptr) : reg((__m128)_mm_loadu_epi8(ptr)) {}
+
+};

#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
@@ -279,6 +296,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {

explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}

+  explicit FP32Vec16(const FP8Vec16 &data) : reg(cast_fp8x16_to_fp32x16((__m128)data.reg)) {}

explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
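
The attention.cpp changes (not rendered above) are where these types get consumed on the read side. As a minimal usage sketch only, assuming the PR's headers are on the include path and that the vector structs sit in the vec_op namespace used by the other CPU kernels (names here are illustrative, not taken from the PR), a 16-element fp8 cache fragment can be widened back to fp32 like this:

#include "cpu_types.hpp"  // dispatches to cpu_types_x86.hpp on x86 builds

// Hypothetical helper: load 16 packed fp8 cache values and widen them to
// fp32 so they can feed the existing fp32 dot-product code.
inline vec_op::FP32Vec16 widen_fp8_fragment(const cpu_fp8 *fp8_cache_ptr) {
  vec_op::FP8Vec16 packed(fp8_cache_ptr);  // 16 x uint8_t fp8 values in an __m128
  return vec_op::FP32Vec16(packed);        // widened via cast_fp8x16_to_fp32x16
}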