Add text generation route
gabrielmbmb committed Sep 8, 2024
1 parent db0e4ff commit 5c2f496
Showing 2 changed files with 126 additions and 30 deletions.
4 changes: 2 additions & 2 deletions candle-holder-models/src/generation/sampling.rs
@@ -8,7 +8,7 @@ use crate::generation::config::GenerationConfig;
 
 /// An enum containing the available sampling strategies for selecting the next token id of a
 /// sequence using the output logits of an auto-regressive model.
-#[derive(Debug)]
+#[derive(Debug, PartialEq)]
 pub enum SamplingStrategy {
     /// Greedy sampling selects the token with the highest probability.
     Greedy,
@@ -104,7 +104,7 @@ impl LogitSampler {
     ///
     /// The ID of the sampled token.
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
-        if self.temperature.is_none() {
+        if self.strategy == SamplingStrategy::Greedy || self.temperature.is_none() {
             return Ok(logits.argmax(D::Minus1)?.to_scalar()?);
         }
 
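Note: with the fix above, a sampler configured with SamplingStrategy::Greedy now short-circuits to an argmax over the logits even when a temperature happens to be set, instead of relying solely on the temperature being absent. A minimal standalone sketch of that greedy path, assuming the candle_core crate that the Tensor and D::Minus1 calls in the diff come from, and deliberately not using the crate's LogitSampler (whose constructor is not shown in this diff):

use candle_core::{Device, Result, Tensor, D};

// Greedy decoding: pick the index of the highest logit, mirroring the
// `logits.argmax(D::Minus1)?.to_scalar()?` call in `LogitSampler::sample`.
fn greedy_next_token(logits: &Tensor) -> Result<u32> {
    logits.argmax(D::Minus1)?.to_scalar::<u32>()
}

fn main() -> Result<()> {
    // Toy logits over a 4-token vocabulary; index 1 has the highest score.
    let logits = Tensor::new(&[0.1f32, 2.5, -1.0, 0.7], &Device::Cpu)?;
    assert_eq!(greedy_next_token(&logits)?, 1);
    Ok(())
}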
152 changes: 124 additions & 28 deletions candle-holder-serve/src/routes/text_generation.rs
@@ -1,32 +1,128 @@
 use anyhow::Result;
-use axum::{routing::post, Router};
 use candle_holder_models::{GenerationConfig, GenerationParams};
 use candle_holder_pipelines::TextGenerationPipeline;
-use std::sync::Arc;
-
-use crate::cli::Cli;
-
-pub fn router(args: &Cli) -> Result<Router> {
-    let model = args.model();
-    let device = args.device()?;
-
-    tracing::info!(
-        "Loading text generation pipeline for model '{}' on device {:?}",
-        model,
-        device
-    );
-
-    let pipeline = Arc::new(TextGenerationPipeline::new(
-        &args.model(),
-        &args.device()?,
-        None,
-        None,
-    )?);
-
-    Ok(Router::new()
-        .route("/", post(inference))
-        .with_state(pipeline))
+use serde::{Deserialize, Serialize};
+
+use crate::generate_router;
+
+#[derive(Debug, Clone, Deserialize)]
+struct TextGenerationInferenceParams {
+    #[serde(default)]
+    top_k: Option<usize>,
+    #[serde(default)]
+    top_p: Option<f32>,
+    #[serde(default)]
+    temperature: Option<f64>,
+    #[serde(default)]
+    repetition_penalty: Option<f64>,
+    #[serde(default)]
+    max_new_tokens: Option<usize>,
+    #[serde(default)]
+    do_sample: Option<bool>,
+    #[serde(default)]
+    num_return_sequences: Option<usize>,
 }
 
+impl Default for TextGenerationInferenceParams {
+    fn default() -> Self {
+        Self {
+            top_k: None,
+            top_p: None,
+            temperature: None,
+            repetition_penalty: None,
+            max_new_tokens: Some(50),
+            do_sample: Some(true),
+            num_return_sequences: Some(1),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Deserialize)]
+#[serde(untagged)]
+enum Inputs {
+    Single(String),
+    Multiple(Vec<String>),
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub(crate) struct TextGenerationInferenceRequest {
+    inputs: Inputs,
+    parameters: Option<TextGenerationInferenceParams>,
+}
+
+#[derive(Debug, Clone, Serialize)]
+pub(crate) struct TextGenerationResult {
+    generated_text: String,
+}
+
-async fn inference() -> &'static str {
-    "inference"
+#[derive(Debug, Clone, Serialize)]
+#[serde(untagged)]
+pub(crate) enum TextGenerationInferenceResponse {
+    Single(Vec<TextGenerationResult>),
+    Multiple(Vec<Vec<TextGenerationResult>>),
 }
 
+generate_router!(
+    TextGenerationPipeline,
+    TextGenerationInferenceRequest,
+    TextGenerationInferenceResponse,
+    process_text_generation
+);
+
+pub(crate) fn process_text_generation(
+    pipeline: &TextGenerationPipeline,
+    request: TextGenerationInferenceRequest,
+) -> Result<TextGenerationInferenceResponse, ErrorResponse> {
+    let params = request.parameters.unwrap_or_default();
+    let generation_params = GenerationParams {
+        generation_config: Some(GenerationConfig {
+            top_k: params.top_k,
+            top_p: params.top_p,
+            temperature: params.temperature.unwrap_or(0.7),
+            repetition_penalty: params.repetition_penalty,
+            max_new_tokens: params.max_new_tokens,
+            do_sample: params.do_sample.unwrap_or(false),
+            num_return_sequences: params.num_return_sequences.unwrap_or(1).max(1),
+            ..Default::default()
+        }),
+        tokenizer: None,
+        stopping_criteria: None,
+        token_streamer: None,
+        seed: None,
+    };
+
+    match request.inputs {
+        Inputs::Single(text) => {
+            let outputs = pipeline.run(text, Some(generation_params)).map_err(|e| {
+                tracing::error!("Failed to run pipeline: {}", e);
+                ErrorResponse::new(500, "Failed to process request")
+            })?;
+            let results = outputs
+                .into_iter()
+                .map(|output| TextGenerationResult {
+                    generated_text: output.get_text().unwrap_or("".to_string()),
+                })
+                .collect();
+            Ok(TextGenerationInferenceResponse::Single(results))
+        }
+        Inputs::Multiple(texts) => {
+            let outputs = pipeline
+                .run_batch(texts, Some(generation_params))
+                .map_err(|e| {
+                    tracing::error!("Failed to run pipeline: {}", e);
+                    ErrorResponse::new(500, "Failed to process request")
+                })?;
+            let results: Vec<Vec<TextGenerationResult>> = outputs
+                .into_iter()
+                .map(|output| {
+                    output
+                        .into_iter()
+                        .map(|output| TextGenerationResult {
+                            generated_text: output.get_text().unwrap_or("".to_string()),
+                        })
+                        .collect()
+                })
+                .collect();
+            Ok(TextGenerationInferenceResponse::Multiple(results))
+        }
+    }
+}
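For reference, a small, self-contained sketch of the request shape the new route accepts. The types below are stand-alone copies modeled on the structs in this diff (the originals are crate-private), and the JSON bodies are illustrative examples rather than anything taken from the repository; the point is that the untagged Inputs enum lets the same JSON field carry either a single prompt or a batch of prompts:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum Inputs {
    Single(String),
    Multiple(Vec<String>),
}

#[derive(Debug, Deserialize)]
struct TextGenerationInferenceRequest {
    inputs: Inputs,
    // Generation parameters are elided in this sketch; the route uses a dedicated struct.
    #[serde(default)]
    parameters: Option<serde_json::Value>,
}

fn main() -> serde_json::Result<()> {
    // A plain string deserializes into `Inputs::Single`.
    let single: TextGenerationInferenceRequest =
        serde_json::from_str(r#"{ "inputs": "Once upon a time" }"#)?;
    // An array of strings deserializes into `Inputs::Multiple`.
    let batch: TextGenerationInferenceRequest = serde_json::from_str(
        r#"{ "inputs": ["Hello", "Bonjour"], "parameters": { "max_new_tokens": 32 } }"#,
    )?;
    println!("{:?}\n{:?}", single, batch);
    Ok(())
}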

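Similarly, because TextGenerationInferenceResponse is untagged, the route's JSON output is a flat array of results for a single prompt and an array of arrays for a batch. A stand-alone sketch, again mirroring rather than importing the crate-private types, of what that serialization looks like:

use serde::Serialize;

#[derive(Debug, Serialize)]
struct TextGenerationResult {
    generated_text: String,
}

#[derive(Debug, Serialize)]
#[serde(untagged)]
enum TextGenerationInferenceResponse {
    Single(Vec<TextGenerationResult>),
    Multiple(Vec<Vec<TextGenerationResult>>),
}

fn main() -> serde_json::Result<()> {
    let single = TextGenerationInferenceResponse::Single(vec![TextGenerationResult {
        generated_text: "Once upon a time there was".to_string(),
    }]);
    // Prints: [{"generated_text":"Once upon a time there was"}]
    println!("{}", serde_json::to_string(&single)?);

    let batch = TextGenerationInferenceResponse::Multiple(vec![
        vec![TextGenerationResult { generated_text: "Hello!".to_string() }],
        vec![TextGenerationResult { generated_text: "Bonjour !".to_string() }],
    ]);
    // Prints: [[{"generated_text":"Hello!"}],[{"generated_text":"Bonjour !"}]]
    println!("{}", serde_json::to_string(&batch)?);
    Ok(())
}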