Skip to content

Commit

Permalink
Implement typical sampler and refactor advanced sampler selection API.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 22, 2024
1 parent abd291f commit 7319310
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 27 deletions.
22 changes: 9 additions & 13 deletions crates/ai00-core/src/sampler/mirostat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,13 @@ use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct MirostatParams {
#[derivative(Default(value = "3.0"))]
pub tau: f32,
#[derivative(Default(value = "0.1"))]
#[serde(alias = "learning_rate")]
pub rate: f32,
#[serde(default = "default_top_p")]
#[derivative(Default(value = "0.95"))]
pub top_p: f32,
}

fn default_top_p() -> f32 {
MirostatParams::default().top_p
}

#[derive(Debug, Clone, Default)]
Expand Down Expand Up @@ -55,12 +49,14 @@ impl Sampler for MirostatSampler {
.enumerate()
.sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse())
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
if *cum > params.top_p {
None
} else {
*cum += x;
Some((id, *cum, *x))
}
// if *cum > params.top_p {
// None
// } else {
// *cum += x;
// Some((id, *cum, *x))
// }
*cum += x;
Some((id, *cum, *x))
})
.collect_vec();
let k = sorted
Expand Down
1 change: 1 addition & 0 deletions crates/ai00-core/src/sampler/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
pub mod bnf;
pub mod mirostat;
pub mod nucleus;
pub mod typical;

pub trait Sampler {
/// Initialize the sampler state.
Expand Down
122 changes: 122 additions & 0 deletions crates/ai00-core/src/sampler/typical.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
use std::collections::HashMap;

use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};

use super::Sampler;

/// Tunable parameters for [`TypicalSampler`].
///
/// Each field carries its own default through `derivative`, and the
/// container-level `#[serde(default)]` lets any field be omitted when
/// deserializing.
#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct TypicalParams {
/// Cut-off for the running total accumulated while collecting candidates in
/// `sample`; collection stops once the total exceeds it. Defaults to `0.5`.
#[derivative(Default(value = "0.5"))]
pub tau: f32,
/// Upper bound on the number of sorted candidates considered in `sample`.
/// Defaults to `128`.
#[derivative(Default(value = "128"))]
pub top_k: usize,
/// Each kept probability is raised to `1 / temperature` before
/// renormalizing. Defaults to `1.0`.
#[derivative(Default(value = "1.0"))]
pub temperature: f32,
/// Penalty a token starts from the first time it is recorded.
/// Defaults to `0.3`.
#[derivative(Default(value = "0.3"))]
pub presence_penalty: f32,
/// Amount added to a token's penalty on further occurrences.
/// Defaults to `0.3`.
#[derivative(Default(value = "0.3"))]
pub frequency_penalty: f32,
/// Factor every stored penalty is multiplied by after each sampled token;
/// also the base of the per-position decay applied in `init`.
/// Defaults to `0.99654026`.
#[derivative(Default(value = "0.99654026"))]
pub penalty_decay: f32,
}

/// Mutable state carried by [`TypicalSampler`] between calls.
#[derive(Debug, Default, Clone)]
pub struct TypicalState {
/// Current penalty per token id; `transform` subtracts each value from the
/// entry of the output slice indexed by that token.
pub penalties: HashMap<u16, f32>,
}

/// Sampler combining its configuration with the per-token penalty state.
#[derive(Debug, Default, Clone)]
pub struct TypicalSampler {
/// Configuration supplied at construction time.
pub params: TypicalParams,
/// Penalty bookkeeping updated by `init` and `sample`.
pub state: TypicalState,
}

impl TypicalSampler {
pub fn new(params: TypicalParams) -> Self {
Self {
params,
state: Default::default(),
}
}
}

impl Sampler for TypicalSampler {
    /// Seeds the penalty table from the tokens already present in the context.
    ///
    /// Tokens are visited from the most recent to the oldest. A token seen for
    /// the first time starts at `presence_penalty`; every occurrence then adds
    /// `frequency_penalty * penalty_decay^age`, where `age` is the distance
    /// from the end of `model_tokens`.
    fn init(&mut self, model_tokens: &[u16]) {
        let TypicalSampler { params, state } = self;
        let presence = params.presence_penalty;
        let frequency = params.frequency_penalty;
        let decay = params.penalty_decay;

        for (age, &token) in model_tokens.iter().rev().enumerate() {
            let penalty = state.penalties.entry(token).or_insert(presence);
            *penalty += frequency * decay.powf(age as f32);
        }
    }

    /// Subtracts each stored penalty from the matching entry of `output`.
    ///
    /// # Panics
    ///
    /// Panics if a penalized token id is not a valid index into `output`.
    fn transform(&self, output: &mut [f32]) {
        for (token, penalty) in &self.state.penalties {
            // .filter(|(token, _)| !penalty_free_tokens.contains(token))
            output[*token as usize] -= penalty;
        }
    }

    /// Draws a token with locally typical sampling and updates the penalties.
    ///
    /// Candidates are ranked by how close their surprisal `-ln p` is to the
    /// entropy of `probs`. At most `top_k` of them are considered, and they
    /// are kept until their accumulated probability mass exceeds `tau` (the
    /// candidate that crosses the threshold is included). The survivors are
    /// re-weighted by `p^(1 / temperature)`, renormalized, and one is drawn at
    /// random. Returns token `0` when no candidate survives (for example when
    /// `top_k` is `0` or `probs` contains no positive entry).
    fn sample(&mut self, probs: &[f32]) -> u16 {
        let TypicalSampler { params, state } = self;

        // Shannon entropy in nats. Zero-probability entries contribute nothing
        // (p ln p -> 0 as p -> 0); skipping them also avoids `0 * -inf = NaN`.
        let entropy: f32 = probs
            .iter()
            .filter(|&&p| p > 0.0)
            .map(|&p| -p * p.ln())
            .sum();

        let candidates = probs
            .iter()
            .enumerate()
            // Zero (or NaN) probabilities can never be sampled and have no
            // finite surprisal, so they are dropped up front.
            .filter(|&(_, &p)| p > 0.0)
            .map(|(id, &p)| (id, p, (-p.ln() - entropy).abs()))
            // Most typical first: smallest |surprisal - entropy|.
            .sorted_unstable_by(|(_, _, x), (_, _, y)| x.total_cmp(y))
            .take(params.top_k)
            .scan(0.0_f32, |mass, (id, p, _)| {
                if *mass > params.tau {
                    None
                } else {
                    *mass += p;
                    Some((id, p.powf(1.0 / params.temperature)))
                }
            })
            .collect_vec();

        let total: f32 = candidates.iter().map(|(_, weight)| weight).sum();
        let rand = fastrand::f32();
        let token = candidates
            .into_iter()
            .scan(0.0_f32, |cum, (id, weight)| {
                *cum += weight / total;
                Some((id, *cum))
            })
            // Falls back to the first (most typical) candidate if rounding
            // leaves the final cumulative weight below `rand`.
            .find_or_first(|&(_, cum)| rand <= cum)
            .map(|(id, _)| id)
            .unwrap_or_default();
        let token = token as u16;

        // Age all existing penalties, then record the freshly sampled token.
        let decay = params.penalty_decay;
        state
            .penalties
            .values_mut()
            .for_each(|penalty| *penalty *= decay);

        let frequency = params.frequency_penalty;
        state
            .penalties
            .entry(token)
            .and_modify(|penalty| *penalty += frequency)
            .or_insert(params.presence_penalty);

        token
    }
}
21 changes: 13 additions & 8 deletions crates/ai00-server/src/api/oai/chat.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use std::{collections::HashMap, sync::Arc};

use ai00_core::{
run::StateId, sampler::Sampler, FinishReason, GenerateRequest, ThreadRequest, Token,
TokenCounter, MAX_TOKENS,
run::StateId, FinishReason, GenerateRequest, ThreadRequest, Token, TokenCounter, MAX_TOKENS,
};
use futures_util::StreamExt;
use itertools::Itertools;
use regex::Regex;
use salvo::{oapi::extract::JsonBody, prelude::*, sse::SseEvent, Depot, Writer};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

use super::*;
use crate::{
Expand Down Expand Up @@ -51,6 +49,8 @@ pub struct ChatRequest {
messages: Array<ChatRecord>,
#[serde(default)]
names: HashMap<Role, String>,
#[serde(default)]
state: StateId,
#[serde(default = "default_max_tokens")]
max_tokens: usize,
#[serde(default = "default_stop")]
Expand All @@ -61,22 +61,23 @@ pub struct ChatRequest {
#[serde(alias = "logit_bias")]
bias: HashMap<u16, f32>,
#[serde(flatten)]
sampler: SamplerParams,
sampler: NucleusParams,
#[serde(default)]
state: StateId,
sampler_override: Option<SamplerParams>,
}

impl Default for ChatRequest {
fn default() -> Self {
Self {
messages: Array::default(),
names: HashMap::new(),
state: Default::default(),
max_tokens: 256,
stop: Array::Item("\n\n".into()),
stream: false,
bias: HashMap::new(),
sampler: Default::default(),
state: Default::default(),
sampler_override: None,
}
}
}
Expand All @@ -94,11 +95,12 @@ impl From<ChatRequest> for GenerateRequest {
let ChatRequest {
messages,
names,
state,
max_tokens,
stop,
sampler,
sampler_override,
bias,
state,
..
} = value;

Expand Down Expand Up @@ -128,7 +130,10 @@ impl From<ChatRequest> for GenerateRequest {
let max_tokens = max_tokens.min(MAX_TOKENS);
let stop = stop.into();
let bias = Arc::new(bias);
let sampler: Arc<RwLock<dyn Sampler + Send + Sync>> = sampler.into();
let sampler = match sampler_override {
Some(sampler) => sampler.into(),
None => SamplerParams::Nucleus(sampler).into(),
};

Self {
prompt,
Expand Down
17 changes: 12 additions & 5 deletions crates/ai00-server/src/api/oai/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use crate::{
pub struct CompletionRequest {
#[serde(default)]
prompt: Array<String>,
#[serde(default)]
state: StateId,
#[serde(default = "default_max_tokens")]
max_tokens: usize,
#[serde(default)]
Expand All @@ -35,22 +37,23 @@ pub struct CompletionRequest {
#[serde(default)]
bnf_schema: Option<String>,
#[serde(flatten)]
sampler: SamplerParams,
sampler: NucleusParams,
#[serde(default)]
state: StateId,
sampler_override: Option<SamplerParams>,
}

impl Default for CompletionRequest {
fn default() -> Self {
Self {
prompt: Array::default(),
state: Default::default(),
max_tokens: 256,
stop: Array::default(),
stream: false,
bias: HashMap::new(),
bnf_schema: Default::default(),
sampler: Default::default(),
state: Default::default(),
sampler_override: None,
}
}
}
Expand All @@ -63,20 +66,24 @@ impl From<CompletionRequest> for GenerateRequest {
fn from(value: CompletionRequest) -> Self {
let CompletionRequest {
prompt,
state,
max_tokens,
stop,
sampler,
sampler_override,
bias,
bnf_schema,
state,
..
} = value;

let prompt = Vec::from(prompt).join("");
let max_tokens = max_tokens.min(MAX_TOKENS);
let stop = stop.into();
let bias = Arc::new(bias);
let sampler = sampler.into();
let sampler = match sampler_override {
Some(sampler) => sampler.into(),
None => SamplerParams::Nucleus(sampler).into(),
};

Self {
prompt,
Expand Down
5 changes: 4 additions & 1 deletion crates/ai00-server/src/api/oai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::sync::Arc;
use ai00_core::sampler::{
mirostat::{MirostatParams, MirostatSampler},
nucleus::{NucleusParams, NucleusSampler},
typical::{TypicalParams, TypicalSampler},
Sampler,
};
use salvo::oapi::ToSchema;
Expand All @@ -20,9 +21,10 @@ pub use embedding::embeddings;
pub use info::models;

#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(untagged)]
#[serde(tag = "type")]
pub enum SamplerParams {
Mirostat(MirostatParams),
Typical(TypicalParams),
Nucleus(NucleusParams),
}

Expand All @@ -36,6 +38,7 @@ impl From<SamplerParams> for Arc<RwLock<dyn Sampler + Send + Sync>> {
fn from(value: SamplerParams) -> Self {
match value {
SamplerParams::Mirostat(params) => Arc::new(RwLock::new(MirostatSampler::new(params))),
SamplerParams::Typical(params) => Arc::new(RwLock::new(TypicalSampler::new(params))),
SamplerParams::Nucleus(params) => Arc::new(RwLock::new(NucleusSampler::new(params))),
}
}
Expand Down

0 comments on commit 7319310

Please sign in to comment.