diff --git a/crates/ai00-core/src/sampler/mirostat.rs b/crates/ai00-core/src/sampler/mirostat.rs index 5f03af47..d3bbc6aa 100644 --- a/crates/ai00-core/src/sampler/mirostat.rs +++ b/crates/ai00-core/src/sampler/mirostat.rs @@ -1,4 +1,4 @@ -use super::{utils, Sampler}; +use super::{radix, Sampler}; use derivative::Derivative; use itertools::Itertools; use salvo::oapi::ToSchema; @@ -49,13 +49,13 @@ impl Sampler for MirostatSampler { .iter() .copied() .enumerate() - .map(|(id, x)| utils::F32WithIndex(id, x)) + .map(|(id, x)| radix::F32WithIndex(id, x)) .collect_vec(); sorted.voracious_sort(); let sorted = sorted .into_iter() .rev() - .scan((0, 0.0, 0.0), |(_, cum, _), utils::F32WithIndex(id, x)| { + .scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| { // if *cum > params.top_p { // None // } else { diff --git a/crates/ai00-core/src/sampler/mod.rs b/crates/ai00-core/src/sampler/mod.rs index 72eb1052..1475afe8 100644 --- a/crates/ai00-core/src/sampler/mod.rs +++ b/crates/ai00-core/src/sampler/mod.rs @@ -2,7 +2,9 @@ pub mod bnf; pub mod mirostat; pub mod nucleus; pub mod typical; -mod utils; + +mod radix; + pub trait Sampler { /// Initialize the sampler state. fn init(&mut self, model_tokens: &[u16]); diff --git a/crates/ai00-core/src/sampler/nucleus.rs b/crates/ai00-core/src/sampler/nucleus.rs index 314df06c..485ff284 100644 --- a/crates/ai00-core/src/sampler/nucleus.rs +++ b/crates/ai00-core/src/sampler/nucleus.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use super::{utils, Sampler}; +use super::{radix, Sampler}; use derivative::Derivative; use itertools::Itertools; use salvo::oapi::ToSchema; @@ -72,14 +72,14 @@ impl Sampler for NucleusSampler { .iter() .copied() .enumerate() - .map(|(id, x)| utils::F32WithIndex(id, x)) + .map(|(id, x)| radix::F32WithIndex(id, x)) .collect_vec(); sorted.voracious_sort(); let sorted = sorted .into_iter() .rev() .take(params.top_k) - .scan((0, 0.0, 0.0), |(_, cum, _), utils::F32WithIndex(id, x)| { + .scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| { if *cum > params.top_p { None } else { diff --git a/crates/ai00-core/src/sampler/utils.rs b/crates/ai00-core/src/sampler/radix.rs similarity index 98% rename from crates/ai00-core/src/sampler/utils.rs rename to crates/ai00-core/src/sampler/radix.rs index 7dfdaec7..492ea854 100644 --- a/crates/ai00-core/src/sampler/utils.rs +++ b/crates/ai00-core/src/sampler/radix.rs @@ -1,38 +1,49 @@ use std::cmp::Ordering; + use voracious_radix_sort::Radixable; + #[derive(Copy, Clone, Debug)] pub struct F32WithIndex(pub usize, pub f32); + impl PartialOrd for F32WithIndex { fn partial_cmp(&self, other: &F32WithIndex) -> Option { self.1.partial_cmp(&other.1) } } + impl PartialEq for F32WithIndex { fn eq(&self, other: &Self) -> bool { self.1 == other.1 } } + impl Radixable for F32WithIndex { type Key = f32; + #[inline] fn key(&self) -> Self::Key { self.1 } } + #[derive(Copy, Clone, Debug)] pub struct DoubleF32WithIndex(pub usize, pub f32, pub f32); + impl PartialOrd for DoubleF32WithIndex { fn partial_cmp(&self, other: &DoubleF32WithIndex) -> Option { self.2.partial_cmp(&other.2) } } + impl PartialEq for DoubleF32WithIndex { fn eq(&self, other: &Self) -> bool { self.2 == other.2 } } + impl Radixable for DoubleF32WithIndex { type Key = f32; + #[inline] fn key(&self) -> Self::Key { self.2 diff --git a/crates/ai00-core/src/sampler/typical.rs b/crates/ai00-core/src/sampler/typical.rs index e8a65ea6..4f4966b3 100644 --- a/crates/ai00-core/src/sampler/typical.rs +++ b/crates/ai00-core/src/sampler/typical.rs @@ -6,7 +6,7 @@ use salvo::oapi::ToSchema; use serde::{Deserialize, Serialize}; use voracious_radix_sort::RadixSort; -use super::{utils, Sampler}; +use super::{radix, Sampler}; #[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)] #[derivative(Default)] @@ -79,13 +79,12 @@ impl Sampler for TypicalSampler { let entropy = probs.iter().map(|(_, x, y)| x * y).sum::(); let mut sorted = probs .into_iter() - .map(|(id, x, y)| utils::DoubleF32WithIndex(id, x, (y - entropy).abs())) + .map(|(id, x, y)| radix::DoubleF32WithIndex(id, x, (y - entropy).abs())) .collect_vec(); sorted.voracious_sort(); let sorted = sorted .into_iter() - .rev() - .map(|utils::DoubleF32WithIndex(id, x, _)| (id, x)) + .map(|radix::DoubleF32WithIndex(id, x, _)| (id, x)) .take(params.top_k) .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| { if *cum > params.tau {