Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 44 additions & 4 deletions csrc/cache/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,20 @@ StaticKVCacheConfig::StaticKVCacheConfig(
max_cache_len_(_max_cache_len) {
}

// Construct a static-cache config with an explicit KV-cache dtype string.
// An empty string means "dtype not specified": the set-flag stays false and
// kv_cache_dtype_ is left untouched.
// NOTE(review): in the empty case kv_cache_dtype_ remains uninitialized
// (the member has no default in the header) — confirm all readers check
// kv_cache_dtype_is_set() first.
StaticKVCacheConfig::StaticKVCacheConfig(
    infinicore::Size _max_batch_size,
    infinicore::Size _max_cache_len,
    std::string kv_cache_dtype)
    : max_batch_size_(_max_batch_size),
      max_cache_len_(_max_cache_len) {
    kv_cache_dtype_set_ = !kv_cache_dtype.empty();
    if (kv_cache_dtype_set_) {
        // parse_dtype maps the string (e.g. "int8") to an infinicore::DataType.
        kv_cache_dtype_ = parse_dtype(kv_cache_dtype);
    }
}

std::unique_ptr<CacheConfig>
StaticKVCacheConfig::unique_copy() const {
return std::make_unique<StaticKVCacheConfig>(*this);
Expand All @@ -42,7 +56,6 @@ StaticKVCache::StaticKVCache(
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
Expand All @@ -53,7 +66,7 @@ StaticKVCache::StaticKVCache(
rank_batch_size_(config.max_batch_size()),
cache_len_(config.max_cache_len() == std::numeric_limits<infinicore::Size>::max() || config.max_cache_len() == 0 ? max_positional_embedding : config.max_cache_len()),
rank_num_layers_(num_layers),
dtype_(dtype) {
dtype_(config.kv_cache_dtype()) {

// Allocate K cache
k_caches_ = infinicore::Tensor::empty(
Expand Down Expand Up @@ -115,9 +128,28 @@ StaticKVCache::update(size_t layer_idx,
return {k_cache_layer, v_cache_layer};
}

// Returns the configured KV-cache data type.
// NOTE(review): when no dtype was supplied (kv_cache_dtype_is_set() == false)
// kv_cache_dtype_ was never assigned, so this returns an indeterminate value —
// confirm callers gate on the flag before reading.
infinicore::DataType
StaticKVCacheConfig::kv_cache_dtype() const {
    return kv_cache_dtype_;
}

// Overrides the KV-cache dtype. Declared const and backed by a `mutable`
// member so it can be applied to const config references.
// NOTE(review): this does not raise kv_cache_dtype_set_, so
// kv_cache_dtype_is_set() still reports false after an explicit set; the flag
// is not `mutable`, so fixing that also needs a header change — confirm the
// intended semantics.
void StaticKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const {
    kv_cache_dtype_ = dtype;
}

// ==========================
// PagedKVCacheConfig
// ==========================
PagedKVCacheConfig::PagedKVCacheConfig(
size_t num_blocks,
std::string kv_cache_dtype,
size_t block_size)
: num_blocks_(num_blocks),
block_size_(block_size),
kv_cache_dtype_(parse_dtype(kv_cache_dtype)) {
kv_cache_dtype_set_ = true;
}

PagedKVCacheConfig::PagedKVCacheConfig(
size_t num_blocks,
size_t block_size)
Expand All @@ -140,6 +172,15 @@ PagedKVCacheConfig::block_size() const {
return block_size_;
}

// Returns the configured KV-cache data type.
// NOTE(review): the 2-arg constructor path never assigns kv_cache_dtype_, so
// this can return an indeterminate value when kv_cache_dtype_set() is false —
// confirm callers check the flag first.
infinicore::DataType
PagedKVCacheConfig::kv_cache_dtype() const {
    return kv_cache_dtype_;
}

// Overrides the KV-cache dtype. Declared const and backed by a `mutable`
// member so it can be applied to const config references.
// NOTE(review): kv_cache_dtype_set_ is not updated here, mirroring the
// StaticKVCacheConfig setter — kv_cache_dtype_set() stays false after an
// explicit set; confirm whether the flag should be raised (needs a header
// change since the flag is not `mutable`).
void PagedKVCacheConfig::set_kv_cache_dtype(infinicore::DataType dtype) const {
    kv_cache_dtype_ = dtype;
}

// ==========================
// PagedKVCache
// ==========================
Expand All @@ -149,7 +190,6 @@ PagedKVCache::PagedKVCache(
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::DataType dtype,
const PagedKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info)
: Cache(),
Expand All @@ -158,7 +198,7 @@ PagedKVCache::PagedKVCache(
num_rank_k_heads_(num_k_heads / rank_info.tp_size),
num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_num_layers_(num_layers),
dtype_(dtype),
dtype_(config.kv_cache_dtype()),
num_blocks_per_layer_(config.num_blocks()),
block_size_(config.block_size()) {
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
Expand Down
26 changes: 24 additions & 2 deletions csrc/cache/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "base_cache.hpp"

#include "../utils.hpp"
#include "infinicore/context/context.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
Expand All @@ -22,13 +23,25 @@ class StaticKVCacheConfig final : public CacheConfig {
infinicore::Size _max_batch_size = 1,
infinicore::Size _max_cache_len = std::numeric_limits<infinicore::Size>::max());

StaticKVCacheConfig(
infinicore::Size _max_batch_size,
infinicore::Size _max_cache_len,
std::string kv_cache_dtype);

std::unique_ptr<CacheConfig> unique_copy() const override;
infinicore::Size max_batch_size() const;
infinicore::Size max_cache_len() const;

infinicore::DataType kv_cache_dtype() const;
void set_kv_cache_dtype(infinicore::DataType dtype) const;
bool kv_cache_dtype_is_set() const { return kv_cache_dtype_set_; }

private:
infinicore::Size max_batch_size_;
infinicore::Size max_cache_len_;

bool kv_cache_dtype_set_ = false;
mutable infinicore::DataType kv_cache_dtype_;
};

class StaticKVCache final : public Cache {
Expand All @@ -41,7 +54,6 @@ class StaticKVCache final : public Cache {
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::Size max_positional_embedding,
infinicore::DataType dtype,
const StaticKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);

Expand Down Expand Up @@ -88,13 +100,24 @@ class PagedKVCacheConfig final : public CacheConfig {
size_t num_blocks,
size_t block_size = 256);

PagedKVCacheConfig(
size_t num_blocks,
std::string kv_cache_dtype,
size_t block_size = 16);

std::unique_ptr<CacheConfig> unique_copy() const override;
size_t num_blocks() const;
size_t block_size() const;
infinicore::DataType kv_cache_dtype() const;
void set_kv_cache_dtype(infinicore::DataType dtype) const;
bool kv_cache_dtype_set() const { return kv_cache_dtype_set_; }

private:
size_t num_blocks_;
size_t block_size_;

bool kv_cache_dtype_set_ = false;
mutable infinicore::DataType kv_cache_dtype_;
};

class PagedKVCache final : public Cache {
Expand All @@ -106,7 +129,6 @@ class PagedKVCache final : public Cache {
infinicore::Size num_k_heads,
infinicore::Size num_v_heads,
infinicore::Size num_layers,
infinicore::DataType dtype,
const PagedKVCacheConfig &config,
const engine::distributed::RankInfo &rank_info);

Expand Down
21 changes: 3 additions & 18 deletions csrc/config/model_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,8 @@ ModelConfig::get_rope_scaling() const {
}
}

infinicore::DataType
ModelConfig::get_dtype() const {
try {
std::string dtype_str = this->get<std::string>("torch_dtype");
if (dtype_str == "float32") {
return infinicore::DataType::F32;
} else if (dtype_str == "float16") {
return infinicore::DataType::F16;
} else if (dtype_str == "bfloat16") {
return infinicore::DataType::BF16;
} else if (dtype_str == "int8") {
return infinicore::DataType::I8;
} else {
throw std::runtime_error("Unsupported dtype string: " + dtype_str);
}
} catch (const std::exception &e) {
throw std::runtime_error("Error getting dtype from config: " + std::string(e.what()));
}
// Resolve the model's compute dtype from the "torch_dtype" entry of the
// loaded config JSON. String-to-DataType mapping (and rejection of
// unsupported names) is delegated to parse_dtype().
infinicore::DataType ModelConfig::get_dtype() const {
    return parse_dtype(this->get<std::string>("torch_dtype"));
}
} // namespace infinilm::config
9 changes: 9 additions & 0 deletions csrc/config/model_config.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include "../utils.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "quant_config.hpp"
Expand Down Expand Up @@ -63,6 +64,14 @@ class ModelConfig {
infinicore::DataType get_dtype() const;
infinicore::quantization::QuantScheme get_quant_scheme() const;
std::shared_ptr<infinicore::nn::RoPE::ScalingConfig> get_rope_scaling() const;
// Enable KV-cache quantization when a supported dtype string is given.
//
// Only "int8" is recognized in this version; any other value (including the
// empty string used for "not specified") is deliberately ignored so the
// scheme stays at its NONE default — presumably this also shields
// parse_dtype() downstream from unset/invalid strings; verify against
// InferEngine, which forwards the raw CLI value here.
//
// Takes the string by const reference: it is only compared, never stored.
void set_kv_quant_scheme(const std::string &kv_cache_dtype) {
    if (kv_cache_dtype == "int8") {
        this->quant_config.set_kv_quant_scheme(kv_cache_dtype);
    }
}
// Returns the KV-cache quantization algorithm chosen for this model
// (delegates to the nested QuantConfig; NONE unless set_kv_quant_scheme()
// selected otherwise).
infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
    return quant_config.get_kv_quant_scheme();
}

private:
nlohmann::json config_json;
Expand Down
20 changes: 19 additions & 1 deletion csrc/config/quant_config.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#pragma once
// #include "../quantization/quantization.hpp"
#include "../utils.hpp"
#include "infinicore/quantization.hpp"
#include "nlohmann/json.hpp"

Expand All @@ -22,9 +22,27 @@ class QuantConfig {
}
}

void set_kv_quant_scheme(std::string kv_cache_dtype) {
switch (parse_dtype(kv_cache_dtype)) {
case infinicore::DataType::I8: {
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::INT8;
break;
}
default: {
this->kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
break;
}
}
}

// Returns the currently selected KV-cache quantization algorithm
// (NONE by default, per the member initializer).
infinicore::quantization::KVQuantAlgo get_kv_quant_scheme() const {
    return kv_quant_scheme;
}

private:
nlohmann::json quantization_config;
std::shared_ptr<infinicore::quantization::BaseQuantization> quantization_method;
infinicore::quantization::KVQuantAlgo kv_quant_scheme = infinicore::quantization::KVQuantAlgo::NONE;
};

} // namespace infinilm::config
5 changes: 4 additions & 1 deletion csrc/engine/infer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ InferEngine::InferEngine(
infinicore::Device::Type device_type,
const cache::CacheConfig *cache_config,
bool enable_graph_compiling,
backends::AttentionBackend attention_backend) // Changed parameter
backends::AttentionBackend attention_backend,
const std::string &kv_cache_dtype) // Changed parameter
: communication_group_(distributed_config, device_type), attention_backend_(attention_backend) {
if (cache_config != nullptr) {
cache_config_ = cache_config->unique_copy();
}

// Load model config if model_path is provided, model_path must be valid, and config.json exists
this->model_config_ = std::make_shared<infinilm::config::ModelConfig>(model_path + "/config.json");
// Only support offline int8 kv cache quantization in this version
this->model_config_->set_kv_quant_scheme(kv_cache_dtype);
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
Expand Down
3 changes: 2 additions & 1 deletion csrc/engine/infer_engine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ class InferEngine {
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig *cache_config = nullptr,
bool enable_graph_compiling = false,
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default);
backends::AttentionBackend attention_backend = backends::AttentionBackend::Default,
const std::string &kv_cache_dtype = "");

// Load a parameter to all workers (each can extract its shard inside RankWorker)
void load_param(const std::string &name, const infinicore::Tensor &param);
Expand Down
47 changes: 43 additions & 4 deletions csrc/models/llama/llama_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@
#include "infinicore/ops/mha_kvcache.hpp"
#include "infinicore/ops/mha_varlen.hpp"
#include "infinicore/ops/mul.hpp"
#include "infinicore/ops/per_tensor_dequant_i8.hpp"
#include "infinicore/ops/per_tensor_quant_i8.hpp"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <optional>
#include <spdlog/spdlog.h>
#include <stdexcept>
Expand Down Expand Up @@ -137,6 +140,17 @@ LlamaAttention::LlamaAttention(std::shared_ptr<infinilm::config::ModelConfig> mo
INFINICORE_NN_MODULE_INIT(q_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
INFINICORE_NN_MODULE_INIT(k_norm, head_dim_, model_config_->get<double>("rms_norm_eps"), dtype, device);
}

switch (this->model_config_->get_kv_quant_scheme()) {
case (infinicore::quantization::KVQuantAlgo::INT8): {
INFINICORE_NN_PARAMETER_INIT(kv_cache_k_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
INFINICORE_NN_PARAMETER_INIT(kv_cache_v_scale, ({1}, infinicore::DataType::F32, device, 0, 0, 1));
break;
}
default: {
break;
}
}
}

infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
Expand Down Expand Up @@ -184,6 +198,17 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true); // [bs, seq_len, n_kv_head, head_dim]

switch (this->model_config_->get_kv_quant_scheme()) {
case (infinicore::quantization::KVQuantAlgo::INT8): {
k_reshaped = infinicore::op::per_tensor_quant_i8(k_reshaped, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
v_reshaped = infinicore::op::per_tensor_quant_i8(v_reshaped, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()), true);
break;
}
default: {
break;
}
}

// 5. Prepare KV caches
// Convert to [batch, n_head, seq_len, head_dim] for cache
// Ensure contiguous after permute for F16 compatibility with cache operations
Expand Down Expand Up @@ -212,6 +237,21 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
} else {
size_t total_seq_len = reinterpret_cast<int32_t *>(total_sequence_lengths.value()->to(infinicore::Device::cpu())->data())[0];

switch (this->model_config_->get_kv_quant_scheme()) {
case (infinicore::quantization::KVQuantAlgo::INT8): {
auto k_total_dequant = infinicore::Tensor::strided_empty(k_total->shape(), k_total->strides(), q_reshaped->dtype(), q_reshaped->device());
auto v_total_dequant = infinicore::Tensor::strided_empty(v_total->shape(), v_total->strides(), q_reshaped->dtype(), q_reshaped->device());
infinicore::op::per_tensor_dequant_i8_(k_total_dequant, k_total, this->kv_cache_k_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
infinicore::op::per_tensor_dequant_i8_(v_total_dequant, v_total, this->kv_cache_v_scale(), infinicore::Tensor::zeros({1}, k_reshaped->dtype(), k_reshaped->device()));
k_total = k_total_dequant;
v_total = v_total_dequant;
break;
}
default: {
break;
}
}
k_total = k_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]
v_total = v_total->narrow({{2, 0, total_seq_len}}); // [bs, n_kv_head, total_seq_len, head_dim]

Expand Down Expand Up @@ -342,10 +382,10 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
auto q_for_fa = q_reshaped->view({seq_len, 1, num_attention_heads_, head_dim_});
auto attn_out_4d = infinicore::op::mha_kvcache(
q_for_fa,
k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
k_total->permute({0, 2, 1, 3}), // [num_blocks, block_size, num_kv_heads, head_dim]
v_total->permute({0, 2, 1, 3}),
total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32
total_sequence_lengths.value(), // [seq_len] int32 (one entry per sequence)
block_tables.value(), // [seq_len, max_num_blocks_per_seq] int32
std::nullopt,
scaling_);
attn_output = attn_out_4d->view({seq_len, num_attention_heads_, head_dim_});
Expand All @@ -361,7 +401,6 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
scaling_);
}
}


// 7. Project output
attn_output
Expand Down
Loading