Skip to content

Commit

Permalink
Add top_k in nucleus sampler.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 22, 2024
1 parent 8c862c0 commit 567d2db
Showing 1 changed file with 4 additions and 25 deletions.
29 changes: 4 additions & 25 deletions crates/ai00-core/src/sampler/nucleus.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,44 +9,22 @@ use super::Sampler;

// Tunable parameters read by the nucleus sampler.
//
// `Default` is generated by `derivative` from the per-field `Default(value = ...)`
// attributes. The container-level `#[serde(default)]` makes deserialization fall back
// to that `Default` implementation for any field missing from the input; the
// field-level `#[serde(default = "...")]` attributes name helper functions that
// return the same per-field default.
#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
#[serde(default)]
pub struct NucleusParams {
// Cutoff for the sampler's scan over the sorted candidates: the scan stops once its
// running accumulator exceeds this value. Defaults to 0.5.
// NOTE(review): the accumulator is presumably cumulative probability — confirm in the sampler.
#[derivative(Default(value = "0.5"))]
#[serde(default = "default_top_p")]
pub top_p: f32,
// Maximum number of candidates kept (via `take`) after the descending sort and before
// the `top_p` cutoff is applied. Defaults to 128.
#[derivative(Default(value = "128"))]
pub top_k: usize,
// Defaults to 1.0.
// NOTE(review): presumably the softmax temperature; its use is not visible here — confirm.
#[derivative(Default(value = "1.0"))]
#[serde(default = "default_temperature")]
pub temperature: f32,
// Defaults to 0.3.
// NOTE(review): presumably a flat penalty for tokens that already appeared — confirm.
#[derivative(Default(value = "0.3"))]
#[serde(default = "default_presence_penalty")]
pub presence_penalty: f32,
// Defaults to 0.3.
// NOTE(review): presumably a penalty scaled by how often a token appeared — confirm.
#[derivative(Default(value = "0.3"))]
#[serde(default = "default_frequency_penalty")]
pub frequency_penalty: f32,
// Defaults to 0.99654026.
// NOTE(review): presumably the per-step decay applied to accumulated penalties — confirm.
#[derivative(Default(value = "0.99654026"))]
#[serde(default = "default_penalty_decay")]
pub penalty_decay: f32,
}

/// Serde fallback for `NucleusParams::top_p`: yields the value from the struct's
/// `Default` implementation.
fn default_top_p() -> f32 {
    let defaults = NucleusParams::default();
    defaults.top_p
}

/// Serde fallback for `NucleusParams::temperature`: yields the value from the struct's
/// `Default` implementation.
fn default_temperature() -> f32 {
    let defaults = NucleusParams::default();
    defaults.temperature
}

/// Serde fallback for `NucleusParams::presence_penalty`: yields the value from the
/// struct's `Default` implementation.
fn default_presence_penalty() -> f32 {
    let defaults = NucleusParams::default();
    defaults.presence_penalty
}

/// Serde fallback for `NucleusParams::frequency_penalty`: yields the value from the
/// struct's `Default` implementation.
fn default_frequency_penalty() -> f32 {
    let defaults = NucleusParams::default();
    defaults.frequency_penalty
}

/// Serde fallback for `NucleusParams::penalty_decay`: yields the value from the
/// struct's `Default` implementation.
fn default_penalty_decay() -> f32 {
    let defaults = NucleusParams::default();
    defaults.penalty_decay
}

#[derive(Debug, Default, Clone)]
pub struct NucleusState {
pub penalties: HashMap<u16, f32>,
Expand Down Expand Up @@ -95,6 +73,7 @@ impl Sampler for NucleusSampler {
.iter()
.enumerate()
.sorted_unstable_by(|(_, x), (_, y)| x.total_cmp(y).reverse())
.take(params.top_k)
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
if *cum > params.top_p {
None
Expand Down

0 comments on commit 567d2db

Please sign in to comment.