Skip to content

Commit

Permalink
Fix typical sampling.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Aug 11, 2024
1 parent dc1bda2 commit 0a2c49a
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 11 deletions.
6 changes: 3 additions & 3 deletions crates/ai00-core/src/sampler/mirostat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{utils, Sampler};
use super::{radix, Sampler};
use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
Expand Down Expand Up @@ -49,13 +49,13 @@ impl Sampler for MirostatSampler {
.iter()
.copied()
.enumerate()
.map(|(id, x)| utils::F32WithIndex(id, x))
.map(|(id, x)| radix::F32WithIndex(id, x))
.collect_vec();
sorted.voracious_sort();
let sorted = sorted
.into_iter()
.rev()
.scan((0, 0.0, 0.0), |(_, cum, _), utils::F32WithIndex(id, x)| {
.scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| {
// if *cum > params.top_p {
// None
// } else {
Expand Down
4 changes: 3 additions & 1 deletion crates/ai00-core/src/sampler/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ pub mod bnf;
pub mod mirostat;
pub mod nucleus;
pub mod typical;
mod utils;

mod radix;

pub trait Sampler {
/// Initialize the sampler state.
fn init(&mut self, model_tokens: &[u16]);
Expand Down
6 changes: 3 additions & 3 deletions crates/ai00-core/src/sampler/nucleus.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::collections::HashMap;

use super::{utils, Sampler};
use super::{radix, Sampler};
use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
Expand Down Expand Up @@ -72,14 +72,14 @@ impl Sampler for NucleusSampler {
.iter()
.copied()
.enumerate()
.map(|(id, x)| utils::F32WithIndex(id, x))
.map(|(id, x)| radix::F32WithIndex(id, x))
.collect_vec();
sorted.voracious_sort();
let sorted = sorted
.into_iter()
.rev()
.take(params.top_k)
.scan((0, 0.0, 0.0), |(_, cum, _), utils::F32WithIndex(id, x)| {
.scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| {
if *cum > params.top_p {
None
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,49 @@
use std::cmp::Ordering;

use voracious_radix_sort::Radixable;

#[derive(Copy, Clone, Debug)]
pub struct F32WithIndex(pub usize, pub f32);

impl PartialOrd for F32WithIndex {
fn partial_cmp(&self, other: &F32WithIndex) -> Option<Ordering> {
self.1.partial_cmp(&other.1)
}
}

impl PartialEq for F32WithIndex {
fn eq(&self, other: &Self) -> bool {
self.1 == other.1
}
}

impl Radixable<f32> for F32WithIndex {
type Key = f32;

#[inline]
fn key(&self) -> Self::Key {
self.1
}
}

#[derive(Copy, Clone, Debug)]
pub struct DoubleF32WithIndex(pub usize, pub f32, pub f32);

impl PartialOrd for DoubleF32WithIndex {
fn partial_cmp(&self, other: &DoubleF32WithIndex) -> Option<Ordering> {
self.2.partial_cmp(&other.2)
}
}

impl PartialEq for DoubleF32WithIndex {
fn eq(&self, other: &Self) -> bool {
self.2 == other.2
}
}

impl Radixable<f32> for DoubleF32WithIndex {
type Key = f32;

#[inline]
fn key(&self) -> Self::Key {
self.2
Expand Down
7 changes: 3 additions & 4 deletions crates/ai00-core/src/sampler/typical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use voracious_radix_sort::RadixSort;

use super::{utils, Sampler};
use super::{radix, Sampler};

#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
#[derivative(Default)]
Expand Down Expand Up @@ -79,13 +79,12 @@ impl Sampler for TypicalSampler {
let entropy = probs.iter().map(|(_, x, y)| x * y).sum::<f32>();
let mut sorted = probs
.into_iter()
.map(|(id, x, y)| utils::DoubleF32WithIndex(id, x, (y - entropy).abs()))
.map(|(id, x, y)| radix::DoubleF32WithIndex(id, x, (y - entropy).abs()))
.collect_vec();
sorted.voracious_sort();
let sorted = sorted
.into_iter()
.rev()
.map(|utils::DoubleF32WithIndex(id, x, _)| (id, x))
.map(|radix::DoubleF32WithIndex(id, x, _)| (id, x))
.take(params.top_k)
.scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
if *cum > params.tau {
Expand Down

0 comments on commit 0a2c49a

Please sign in to comment.