Api-doc-improvement (#129)
cryscan authored Jun 8, 2024
1 parent 4a2ad45 commit 45c3488
Showing 9 changed files with 307 additions and 55 deletions.
32 changes: 30 additions & 2 deletions crates/ai00-core/src/lib.rs
@@ -172,11 +172,12 @@ pub struct GenerateRequest {
pub state: StateId,
}

#[derive(Debug, Derivative, Clone, Serialize, Deserialize)]
#[derive(Debug, Derivative, Clone, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct ReloadRequest {
/// Path to the model.
#[salvo(schema(value_type = String))]
pub model_path: PathBuf,
/// List of LoRAs blended onto the model.
pub lora: Vec<reload::Lora>,
@@ -185,6 +186,7 @@ pub struct ReloadRequest {
/// Specify layers that need to be quantized.
pub quant: usize,
/// Quantization type (`Int8` or `NF4`).
#[salvo(schema(value_type = sealed::Quant))]
pub quant_type: Quant,
/// Precision for intermediate tensors (`Fp16` or `Fp32`).
pub precision: Precision,
@@ -195,20 +197,23 @@ pub struct ReloadRequest {
#[derivative(Default(value = "8"))]
pub max_batch: usize,
/// Device to put the embed tensor.
#[salvo(schema(value_type = sealed::EmbedDevice))]
pub embed_device: EmbedDevice,
/// Path to the tokenizer.
#[salvo(schema(value_type = String))]
pub tokenizer_path: PathBuf,
/// BNF options.
pub bnf: BnfOption,
/// Adapter selection.
pub adapter: AdapterOption,
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, Serialize, Deserialize, ToSchema)]
#[serde(default)]
pub struct SaveRequest {
/// Path to save the model.
#[serde(alias = "model_path")]
#[salvo(schema(value_type = String))]
pub path: PathBuf,
}

@@ -745,3 +750,26 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
}
}
}

#[allow(dead_code)]
mod sealed {
use salvo::oapi::ToSchema;

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash, ToSchema)]
pub enum Quant {
/// No quantization.
#[default]
None,
/// Use `Int8` quantization.
Int8,
/// Use `NF4` quantization.
NF4,
}

#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, ToSchema)]
pub enum EmbedDevice {
#[default]
Cpu,
Gpu,
}
}
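
The new `sealed` module exists because `Quant` and `EmbedDevice` come from `web_rwkv` and do not implement salvo's `ToSchema`, so the request structs point their OpenAPI schema at local mirror enums via `value_type`, while `PathBuf` fields are advertised as plain strings. A minimal sketch of the pattern, assuming a crate that depends on `salvo` (with its `oapi` feature), `serde`, and `web_rwkv`; `Example` and its fields are hypothetical:

```rust
use std::path::PathBuf;

use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};

/// Local mirror of the foreign enum; it is never constructed and only gives
/// the schema generator a concrete shape to describe.
#[allow(dead_code)]
#[derive(ToSchema)]
enum Quant {
    None,
    Int8,
    NF4,
}

#[derive(Serialize, Deserialize, ToSchema)]
struct Example {
    /// `PathBuf` is documented as a plain string in the generated schema.
    #[salvo(schema(value_type = String))]
    model_path: PathBuf,
    /// The field keeps the real `web_rwkv` type; only the schema uses the mirror.
    #[salvo(schema(value_type = Quant))]
    quant_type: web_rwkv::runtime::model::Quant,
}
```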
22 changes: 15 additions & 7 deletions crates/ai00-core/src/reload.rs
@@ -1,25 +1,29 @@
use std::path::PathBuf;

use derivative::Derivative;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use web_rwkv::runtime::model::{EmbedDevice, Quant};

use crate::run::StateId;

#[derive(Debug, Clone, Derivative, Serialize, Deserialize)]
#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct Model {
/// Path to the folder containing all models.
#[derivative(Default(value = "\"assets/models\".into()"))]
#[serde(alias = "model_path")]
#[salvo(schema(value_type = String))]
pub path: PathBuf,
/// Name of the model.
#[serde(alias = "model_name")]
#[salvo(schema(value_type = String))]
pub name: PathBuf,
/// Specify layers that need to be quantized.
pub quant: usize,
/// Quantization type (`Int8` or `NF4`).
#[salvo(schema(value_type = super::sealed::Quant))]
pub quant_type: Quant,
/// Precision for intermediate tensors (`Fp16` or `Fp32`).
pub precision: Precision,
@@ -30,27 +34,30 @@ pub struct Model {
#[derivative(Default(value = "8"))]
pub max_batch: usize,
/// Device to put the embed tensor.
#[salvo(schema(value_type = super::sealed::EmbedDevice))]
pub embed_device: EmbedDevice,
}

/// Low-rank adaptor.
#[derive(Debug, Clone, Derivative, Serialize, Deserialize)]
#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct Lora {
/// Path to the LoRA.
#[salvo(schema(value_type = String))]
pub path: PathBuf,
/// Blend factor.
#[derivative(Default(value = "1.0"))]
pub alpha: f32,
}

/// State-tuned initial state.
#[derive(Debug, Clone, Derivative, Serialize, Deserialize)]
#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct State {
/// Path to the initial state.
#[salvo(schema(value_type = String))]
pub path: PathBuf,
/// Given name for the state.
pub name: Option<String>,
@@ -61,15 +68,16 @@ pub struct State {
pub default: bool,
}

#[derive(Debug, Derivative, Clone, Serialize, Deserialize)]
#[derive(Debug, Derivative, Clone, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct Tokenizer {
#[derivative(Default(value = "\"assets/tokenizer/rwkv_vocab_v20230424.json\".into()"))]
#[salvo(schema(value_type = String))]
pub path: PathBuf,
}

#[derive(Debug, Derivative, Clone, Serialize, Deserialize)]
#[derive(Debug, Derivative, Clone, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct BnfOption {
@@ -81,14 +89,14 @@ pub struct BnfOption {
pub start_nonterminal: String,
}

#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, ToSchema)]
pub enum Precision {
#[default]
Fp16,
Fp32,
}

#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize, ToSchema)]
pub enum AdapterOption {
#[default]
Auto,
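
Every struct above is `#[serde(default)]` with `Derivative`-provided defaults, so callers only need to spell out the fields they override. A construction sketch under the assumption that these types are reachable at `ai00_core::reload` (implied by the file location); the file names are placeholders:

```rust
use ai00_core::reload::{Lora, Model};

fn example() -> (Model, Lora) {
    // Only the overridden fields are written out; `path` falls back to
    // "assets/models" and `max_batch` to 8, per the defaults shown above.
    let model = Model {
        name: "model.st".into(),
        quant: 16,
        ..Default::default()
    };
    // `alpha` keeps its default blend factor of 1.0.
    let lora = Lora {
        path: "assets/lora/example.st".into(),
        ..Default::default()
    };
    (model, lora)
}
```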
100 changes: 63 additions & 37 deletions crates/ai00-core/src/run.rs
@@ -243,12 +243,12 @@ pub struct GenerateContext {
#[derive(Debug, Clone)]
struct CachedItem {
state: TensorCpu<f32>,
output: Option<TensorCpu<f32>>,
output: TensorCpu<f32>,
instant: Instant,
}

impl CachedItem {
pub fn new(state: TensorCpu<f32>, output: Option<TensorCpu<f32>>) -> Self {
pub fn new(state: TensorCpu<f32>, output: TensorCpu<f32>) -> Self {
Self {
state,
output,
@@ -265,6 +265,12 @@ impl CachedItem {
}
}

struct CacheCheckout {
prefix: Vec<u16>,
state: TensorCpu<f32>,
output: Option<TensorCpu<f32>>,
}

#[derive(Debug, Default)]
struct Cache {
state: Option<InitState>,
@@ -508,7 +514,7 @@ impl Runtime {

/// Search for the longest common prefix in the memory cache and check out the state from that point.
/// Should there be a cache miss, an initial state is returned.
async fn checkout(&self, id: StateId, tokens: &[u16], batch: usize) -> (Vec<u16>, CachedItem) {
async fn checkout(&self, id: StateId, tokens: &[u16], batch: usize) -> CacheCheckout {
let mut caches = self.caches.lock().await;

let Cache { state, cache } = caches.fetch(id);
@@ -521,15 +527,27 @@

let prefix = prefix[0..len].to_vec();
let state = state.clone().map(|state| state.data);
let item = match cache.remove(prefix[..].as_token_slice()) {
Some(item) => CachedItem::update(item),
None => CachedItem::new(state.unwrap_or_else(|| self.state.init()), None),
};
if len > 0 {
let key = Tokens(prefix.clone());
cache.insert(key, item.clone());

match cache.remove(prefix[..].as_token_slice()) {
Some(item) => {
let item = CachedItem::update(item);
let key = Tokens(prefix.clone());
cache.insert(key, item.clone());
CacheCheckout {
prefix,
state: item.state,
output: Some(item.output),
}
}
None => {
let state = state.unwrap_or_else(|| self.state.init());
CacheCheckout {
prefix,
state,
output: None,
}
}
}
(prefix, item)
}

/// Compile and cache the given schema into a BNF sampler.
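
On a cache hit, `checkout` refreshes the item's timestamp via `CachedItem::update`, re-inserts it under the matched prefix, and returns its state together with `Some(output)`; on a miss it falls back to the id's initial state (or `self.state.init()`) with `output: None`. The longest-common-prefix search itself is elided in this hunk; a toy stand-in for that idea, not the crate's actual cache type:

```rust
/// Toy illustration: among the cached keys, find the length of the longest
/// one that is a prefix of `tokens`. The real cache is keyed by token slices;
/// this stand-in simply scans a list of keys.
fn longest_prefix_len(cached_keys: &[Vec<u16>], tokens: &[u16]) -> usize {
    cached_keys
        .iter()
        .map(|key| key.as_slice())
        .filter(|key| tokens.starts_with(key))
        .map(|key| key.len())
        .max()
        .unwrap_or(0)
}
```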
@@ -600,15 +618,15 @@ impl Runtime {
// back a non-relative and non-empty slot and use it for our new context
Some(SlotChoice::Back(batch)) => {
log::info!("start at non-empty slot {}", batch);
let (prefix, reload) = self.checkout(context.request.state, &tokens, batch).await;
self.state.load(reload.state, batch)?;
let checkout = self.checkout(context.request.state, &tokens, batch).await;
self.state.load(checkout.state, batch)?;

let len = prefix.len();
let len = checkout.prefix.len();
let mut state = SlotState::Wait(
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
output: reload.output,
output: checkout.output,
transformers,
..context
}
@@ -621,15 +639,15 @@
// directly occupy an empty slot, so no backing is needed
Some(SlotChoice::Empty(batch)) => {
log::info!("start at empty slot {}", batch);
let (prefix, reload) = self.checkout(context.request.state, &tokens, batch).await;
self.state.load(reload.state, batch)?;
let checkout = self.checkout(context.request.state, &tokens, batch).await;
self.state.load(checkout.state, batch)?;

let len = prefix.len();
let len = checkout.prefix.len();
let state = SlotState::Wait(
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
output: reload.output,
output: checkout.output,
transformers,
..context
}
@@ -641,10 +659,20 @@
// continue from an existing slot; no backing is needed either
Some(SlotChoice::Continue(batch, len)) => {
log::info!("continue at slot {}", batch);
let checkout = self.checkout(context.request.state, &tokens, batch).await;

// retrieve the last output from the cache
assert!(checkout.prefix.len() <= len);
if checkout.prefix.len() < len {
self.state.load(checkout.state, batch)?;
}
let len = checkout.prefix.len();

let state = SlotState::Wait(
GenerateContext {
prefix: Tokens(tokens[..len].to_vec()),
suffix: Tokens(tokens[len..].to_vec()),
output: checkout.output,
transformers,
..context
}
@@ -682,17 +710,16 @@
let _ = context.sender.send(Token::Embed(embed));
}

let mut caches = self.caches.lock().await;
let cache = &mut caches.fetch(context.request.state).cache;
cache.insert(
context.prefix.clone(),
CachedItem::new(backed, context.output),
);
log::info!(
"backed completed slot {} of length {}",
batch,
context.prefix.len()
);
if let Some(output) = context.output {
let mut caches = self.caches.lock().await;
let cache = &mut caches.fetch(context.request.state).cache;
cache.insert(context.prefix.clone(), CachedItem::new(backed, output));
log::info!(
"backed completed slot {} of length {}",
batch,
context.prefix.len()
);
}

assert!(matches!(slots[batch], SlotState::Busy));
slots[batch] = SlotState::Idle(context.prefix, Instant::now());
@@ -801,8 +828,6 @@ impl Runtime {
}

async fn process(&self, payloads: &mut [Payload]) -> Result<()> {
self.prepare(payloads).await?;

let outputs = payloads
.iter()
.map(|payload| match payload {
@@ -827,15 +852,14 @@
};

// cache the prompt if it is long enough.
if !context.prompt_cached && context.prompt_tokens.len() > PROMPT_CACHE_TOKENS {
let cache_prompt = !context.prompt_cached;
let cache_prompt = cache_prompt && context.prompt_tokens.len() > PROMPT_CACHE_TOKENS;
if let Some(output) = cache_prompt.then_some(()).and(context.output.clone()) {
let mut caches = self.caches.lock().await;
let cache = &mut caches.fetch(context.request.state).cache;
let backed = self.state.back(batch).await?;

cache.insert(
context.prefix.clone(),
CachedItem::new(backed, context.output.clone()),
);
cache.insert(context.prefix.clone(), CachedItem::new(backed, output));
context.prompt_cached = true;

log::info!(
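
The prompt is now cached only when the slot already has an output to pair with the backed state: `then_some(()).and(..)` collapses the flag-and-option check into a single `if let`. A self-contained illustration of that idiom with toy values:

```rust
fn main() {
    let cache_prompt = true;
    let output: Option<u32> = Some(42);

    // `bool::then_some(())` yields `Some(())` only when the flag is set, and
    // `Option::and` then keeps `output` only in that case.
    if let Some(output) = cache_prompt.then_some(()).and(output) {
        println!("caching with output {output}");
    }
}
```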
Expand Down Expand Up @@ -990,6 +1014,8 @@ impl Runtime {
done.then(|| payload.finalize());
}

self.prepare(payloads).await?;

let option = InferOption::Last;
let batches = payloads
.iter()