From 5c2f496214c2e6b217e17ba44ffe941a25427ee0 Mon Sep 17 00:00:00 2001
From: gabrielmbmb
Date: Sun, 8 Sep 2024 09:23:28 +0200
Subject: [PATCH] Add text generation route

---
 .../src/generation/sampling.rs     |   4 +-
 .../src/routes/text_generation.rs  | 152 ++++++++++++++----
 2 files changed, 126 insertions(+), 30 deletions(-)

diff --git a/candle-holder-models/src/generation/sampling.rs b/candle-holder-models/src/generation/sampling.rs
index 6d793c9..1810759 100644
--- a/candle-holder-models/src/generation/sampling.rs
+++ b/candle-holder-models/src/generation/sampling.rs
@@ -8,7 +8,7 @@ use crate::generation::config::GenerationConfig;
 
 /// An enum containing the available sampling strategies for selecting the next token id of a
 /// sequence using the outputs logits of an auto-regressive model.
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
 pub enum SamplingStrategy {
     /// Greedy sampling selects the token with the highest probability.
     Greedy,
@@ -104,7 +104,7 @@ impl LogitSampler {
     ///
     /// The ID of the sampled token.
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
-        if self.temperature.is_none() {
+        if self.strategy == SamplingStrategy::Greedy || self.temperature.is_none() {
             return Ok(logits.argmax(D::Minus1)?.to_scalar()?);
         }
 
diff --git a/candle-holder-serve/src/routes/text_generation.rs b/candle-holder-serve/src/routes/text_generation.rs
index fdf8944..7d8e6f0 100644
--- a/candle-holder-serve/src/routes/text_generation.rs
+++ b/candle-holder-serve/src/routes/text_generation.rs
@@ -1,32 +1,128 @@
-use anyhow::Result;
-use axum::{routing::post, Router};
+use candle_holder_models::{GenerationConfig, GenerationParams};
 use candle_holder_pipelines::TextGenerationPipeline;
-use std::sync::Arc;
-
-use crate::cli::Cli;
-
-pub fn router(args: &Cli) -> Result<Router> {
-    let model = args.model();
-    let device = args.device()?;
-
-    tracing::info!(
-        "Loading text generation pipeline for model '{}' on device {:?}",
-        model,
-        device
-    );
-
-    let pipeline = Arc::new(TextGenerationPipeline::new(
-        &args.model(),
-        &args.device()?,
-        None,
-        None,
-    )?);
-
-    Ok(Router::new()
-        .route("/", post(inference))
-        .with_state(pipeline))
+use serde::{Deserialize, Serialize};
+
+use crate::generate_router;
+
+#[derive(Debug, Clone, Deserialize)]
+struct TextGenerationInferenceParams {
+    #[serde(default)]
+    top_k: Option<usize>,
+    #[serde(default)]
+    top_p: Option<f32>,
+    #[serde(default)]
+    temperature: Option<f64>,
+    #[serde(default)]
+    repetition_penalty: Option<f64>,
+    #[serde(default)]
+    max_new_tokens: Option<usize>,
+    #[serde(default)]
+    do_sample: Option<bool>,
+    #[serde(default)]
+    num_return_sequences: Option<usize>,
+}
+
+impl Default for TextGenerationInferenceParams {
+    fn default() -> Self {
+        Self {
+            top_k: None,
+            top_p: None,
+            temperature: None,
+            repetition_penalty: None,
+            max_new_tokens: Some(50),
+            do_sample: Some(true),
+            num_return_sequences: Some(1),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(untagged)]
+enum Inputs {
+    Single(String),
+    Multiple(Vec<String>),
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub(crate) struct TextGenerationInferenceRequest {
+    inputs: Inputs,
+    parameters: Option<TextGenerationInferenceParams>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub(crate) struct TextGenerationResult {
+    generated_text: String,
 }
 
-async fn inference() -> &'static str {
-    "inference"
+#[derive(Debug, Clone, Serialize)]
+#[serde(untagged)]
+pub(crate) enum TextGenerationInferenceResponse {
+    Single(Vec<TextGenerationResult>),
+    Multiple(Vec<Vec<TextGenerationResult>>),
+}
+
+generate_router!(
+    TextGenerationPipeline,
+    TextGenerationInferenceRequest,
+    TextGenerationInferenceResponse,
+    process_text_generation
+);
+
+pub(crate) fn process_text_generation(
+    pipeline: &TextGenerationPipeline,
+    request: TextGenerationInferenceRequest,
+) -> Result<TextGenerationInferenceResponse, ErrorResponse> {
+    let params = request.parameters.unwrap_or_default();
+    let generation_params = GenerationParams {
+        generation_config: Some(GenerationConfig {
+            top_k: params.top_k,
+            top_p: params.top_p,
+            temperature: params.temperature.unwrap_or(0.7),
+            repetition_penalty: params.repetition_penalty,
+            max_new_tokens: params.max_new_tokens,
+            do_sample: params.do_sample.unwrap_or(false),
+            num_return_sequences: params.num_return_sequences.unwrap_or(1).max(1),
+            ..Default::default()
+        }),
+        tokenizer: None,
+        stopping_criteria: None,
+        token_streamer: None,
+        seed: None,
+    };
+
+    match request.inputs {
+        Inputs::Single(text) => {
+            let outputs = pipeline.run(text, Some(generation_params)).map_err(|e| {
+                tracing::error!("Failed to run pipeline: {}", e);
+                ErrorResponse::new(500, "Failed to process request")
+            })?;
+            let results = outputs
+                .into_iter()
+                .map(|output| TextGenerationResult {
+                    generated_text: output.get_text().unwrap_or("".to_string()),
+                })
+                .collect();
+            Ok(TextGenerationInferenceResponse::Single(results))
+        }
+        Inputs::Multiple(texts) => {
+            let outputs = pipeline
+                .run_batch(texts, Some(generation_params))
+                .map_err(|e| {
+                    tracing::error!("Failed to run pipeline: {}", e);
+                    ErrorResponse::new(500, "Failed to process request")
+                })?;
+            let results: Vec<Vec<TextGenerationResult>> = outputs
+                .into_iter()
+                .map(|output| {
+                    output
+                        .into_iter()
+                        .map(|output| TextGenerationResult {
+                            generated_text: output.get_text().unwrap_or("".to_string()),
+                        })
+                        .collect()
+                })
+                .collect();
+            Ok(TextGenerationInferenceResponse::Multiple(results))
+        }
+    }
 }
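
The serde derives above imply the JSON contract for the new route: "inputs" is either a string or an array of strings (the untagged Inputs enum), "parameters" is optional, and the response is a list of {"generated_text": ...} objects, nested once more for batched inputs. The following is a minimal client sketch of that contract, not part of the patch; the base URL http://localhost:3000 and the "/" route path are assumptions, since how candle-holder-serve mounts this router is not shown here.

    // Illustrative client sketch for the text generation route.
    // Assumed (not taken from the patch): server at http://localhost:3000, route at "/".
    use serde_json::json;

    #[tokio::main]
    async fn main() -> Result<(), Box<dyn std::error::Error>> {
        let client = reqwest::Client::new();

        // Single input: `inputs` is a plain string (Inputs::Single); parameters are optional.
        let body = json!({
            "inputs": "Once upon a time",
            "parameters": {
                "temperature": 0.8,
                "max_new_tokens": 64,
                "do_sample": true,
                "num_return_sequences": 2
            }
        });
        let res: serde_json::Value = client
            .post("http://localhost:3000/")
            .json(&body)
            .send()
            .await?
            .json()
            .await?;
        // Expected shape (untagged Single variant): [{"generated_text": "..."}, ...]
        println!("{res}");

        // Batched input: `inputs` is an array of strings (Inputs::Multiple);
        // the response is then an array of arrays of results.
        let batch = json!({ "inputs": ["Hello", "Bonjour"] });
        let res: serde_json::Value = client
            .post("http://localhost:3000/")
            .json(&batch)
            .send()
            .await?
            .json()
            .await?;
        println!("{res}");

        Ok(())
    }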