Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Update to llm-samplers v0.0.7 #440

Merged
merged 2 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ clap = { version = "4.1.8", features = ["derive"] }
memmap2 = "0.5.10"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing = { version = "0.1", features = ["log"] }
llm-samplers = "=0.0.6"
llm-samplers = "=0.0.7"

# Config for 'cargo dist'
[workspace.metadata.dist]
Expand Down
9 changes: 9 additions & 0 deletions binaries/llm-cli/src/cli_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ pub struct Generate {
/// top_p - The probability for the top tokens are added until the result is greater or equal to P and at least min_keep tokens have been seen.
/// p(0.95): The cumulative probability after which no more tokens are kept for sampling.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// top_a (default: disabled) - This sampler prunes tokens that don't meet a threshold based on the most probable token. The formula is `a1 * pow(max_prob, a2)`. See https://github.com/BlinkDL/RWKV-LM#the-top-a-sampling-method for more information.
/// a1(0.0): Threshold scale. A reasonable value is 0.2. Setting either a1 or a2 to 0 disables the sampler.
/// a2(0.0): Threshold power. A reasonable value is 2.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
///
/// min_p (default: disabled) - This sampler prunes tokens that don't meet a certain percentage of the most probable token. For example if `p` is `0.05` then after `min_keep` is satisfied, other tokens must be at least 5% of the most probable token. See https://github.com/ggerganov/llama.cpp/issues/3483 for more information.
/// p(0.0): Probability threshold. 0.05 to 0.2 are good starting values to try. Setting this to 0 disables the sampler.
/// min_keep(1): Minimum tokens to keep. Setting this to 0 is not recommended.
#[arg(long = "sampler", short = 's', verbatim_doc_comment)]
pub sampler_options: Vec<String>,

Expand Down
10 changes: 5 additions & 5 deletions binaries/llm-test/src/inference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ fn run_inference(
// Takes the most likely element from the logits, except if they've appeared in `previous_tokens`
// at all
#[derive(Debug, Default)]
struct DeterministicSampler(SampleGreedy<TokenId>);
struct DeterministicSampler(SampleGreedy);

impl Sampler<TokenId, f32> for DeterministicSampler {
impl Sampler for DeterministicSampler {
fn sample<'a>(
&mut self,
res: &mut dyn HasSamplerResources<TokenId = TokenId>,
logits: &'a mut Logits<TokenId, f32>,
) -> anyhow::Result<&'a mut Logits<TokenId, f32>> {
res: &mut dyn HasSamplerResources,
logits: &'a mut Logits,
) -> anyhow::Result<&'a mut Logits> {
let mut flat_bias = Default::default();

// This might look a little weird, but it's necessary because the resource
Expand Down
2 changes: 1 addition & 1 deletion crates/llm-base/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub struct InferenceParameters {
/// This can be anything that implements [Sampler]. Refer to
/// the `llm-samplers` documentation for possible samplers and suggested
/// combinations: <https://docs.rs/llm-samplers>
pub sampler: Arc<Mutex<dyn Sampler<TokenId, f32>>>,
pub sampler: Arc<Mutex<dyn Sampler>>,
}

//Since Sampler implements Send and Sync, InferenceParameters should too.
Expand Down
52 changes: 34 additions & 18 deletions crates/llm-base/src/samplers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ pub enum SamplingError {
/// to ensure a valid configuration.
pub struct ConfiguredSamplers {
/// A builder from the `llm-samplers` crate.
pub builder: SamplerChainBuilder,
pub builder: SamplerChainBuilder<usize, f32>,
/// Mirostat 1 is present.
pub mirostat1: bool,
/// Mirostat 2 is present.
Expand All @@ -74,15 +74,17 @@ pub struct ConfiguredSamplers {
/// We call a configuration of samplers that run in a certain order a "chain".
/// Here is a description of the default chain `llm` uses:
///
/// 1. Repetition (present by default, multiple allowed)
/// 2. Frequency/Presence (optional, multiple allowed)
/// 3. Sequence Repetition (optional, multiple allowed)
/// 4. Top-K (present by default - incompatible with Mirostat)
/// 5. Tail Free (optional - incompatible with Mirostat)
/// 6. Locally Typical (optional - incompatible with Mirostat)
/// 7. Top-P (present by default - incompatible with Mirostat)
/// 8. Temperature (present by default)
/// 9. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution.
/// 1. Repetition (present by default, multiple allowed)
/// 2. Frequency/Presence (optional, multiple allowed)
/// 3. Sequence Repetition (optional, multiple allowed)
/// 4. Top-K (present by default - incompatible with Mirostat)
/// 5. Tail Free (optional - incompatible with Mirostat)
/// 6. Locally Typical (optional - incompatible with Mirostat)
/// 7. Top-P (present by default - incompatible with Mirostat)
/// 8. Top-A (optional - incompatible with Mirostat)
/// 9. Min-P (optional - incompatible with Mirostat)
/// 10. Temperature (present by default)
/// 11. A Mirostat 1 or 2 sampler if configured, otherwise Random Distribution.
///
/// Samplers listed as "present by default" but incompatible with Mirostat will
/// only be enabled by default if there is no Mirostat sampler enabled.
Expand Down Expand Up @@ -142,6 +144,20 @@ impl Default for ConfiguredSamplers {
Option::<SampleTopP>::None,
),
),
(
"topa",
SamplerSlot::new_single(
|| Box::new(SampleTopA::default().a1(0.0).a2(0.0)),
Option::<SampleTopA>::None,
),
),
(
"minp",
SamplerSlot::new_single(
|| Box::new(SampleMinP::default().p(0.0)),
Option::<SampleMinP>::None,
),
),
(
"temperature",
SamplerSlot::new_single(
Expand Down Expand Up @@ -203,7 +219,7 @@ impl ConfiguredSamplers {
))?
} else if (self.mirostat1 || self.mirostat2) && self.incompat_mirostat {
Err(SamplerConfigurationError::SamplerCombinationError(
"Cannot enable top-p, top-k, locally typical or tail free samplers with Mirostat 1 or 2".to_string(),
"Cannot enable top-p, top-k, top-a, min-p, locally typical or tail free samplers with Mirostat 1 or 2".to_string(),
))?
}
Ok(())
Expand Down Expand Up @@ -245,7 +261,9 @@ impl FromStr for ConfiguredSamplers {
.inspect(|(name, _slot)| match name.as_str() {
"mirostat1" => result.mirostat1 = true,
"mirostat2" => result.mirostat2 = true,
"topp" | "topk" | "locallytypical" | "tailfree" => result.incompat_mirostat = true,
"topa" | "minp" | "topp" | "topk" | "locallytypical" | "tailfree" => {
result.incompat_mirostat = true
}
_ => (),
})
.collect::<Vec<_>>();
Expand All @@ -269,7 +287,7 @@ impl FromStr for ConfiguredSamplers {
/// Sample a token. This convenience function handles building
/// the sampler resources and logits objects the sampler needs.
pub fn sample_token(
mut sampler: impl Sampler<TokenId, f32>,
mut sampler: impl Sampler,
rng: &mut impl rand::Rng,
previous_tokens: &[TokenId],
last_logits: impl IntoIterator<Item = f32>,
Expand Down Expand Up @@ -297,7 +315,7 @@ pub fn build_sampler(
n_vocab: usize,
bias: &[(TokenId, f32)],
args: &[impl AsRef<str>],
) -> Result<Arc<Mutex<dyn Sampler<TokenId, f32>>>, SamplerConfigurationError> {
) -> Result<Arc<Mutex<dyn Sampler>>, SamplerConfigurationError> {
let mut samplers = SamplerChain::new();

if !bias.is_empty() {
Expand Down Expand Up @@ -326,7 +344,7 @@ pub fn build_sampler(
}

/// Get the default sampler chain.
pub fn default_samplers() -> Arc<Mutex<dyn Sampler<TokenId, f32>>> {
pub fn default_samplers() -> Arc<Mutex<dyn Sampler>> {
let mut result = ConfiguredSamplers::default();
result.ensure_default_slots();
Arc::new(Mutex::new(result.builder.into_chain()))
Expand All @@ -349,8 +367,6 @@ impl<'pt, 'r> fmt::Debug for SamplerResources<'pt, 'r> {
}

impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> {
type TokenId = TokenId;

fn with_rng_mut(
&mut self,
fun: &mut dyn FnMut(&mut dyn rand::RngCore),
Expand All @@ -359,7 +375,7 @@ impl<'pt, 'r> HasSamplerResources for SamplerResources<'pt, 'r> {
Ok(())
}

fn with_last_tokens(&self, fun: &mut dyn FnMut(&[Self::TokenId])) -> Result<(), SamplerError> {
fn with_last_tokens(&self, fun: &mut dyn FnMut(&[TokenId])) -> Result<(), SamplerError> {
fun(self.previous_tokens);
Ok(())
}
Expand Down
Loading