diff --git a/candle-holder-models/src/generation/generate.rs b/candle-holder-models/src/generation/generate.rs
index 22dc8f3..ce78f44 100644
--- a/candle-holder-models/src/generation/generate.rs
+++ b/candle-holder-models/src/generation/generate.rs
@@ -40,7 +40,7 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(
     mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
     seed: Option<u64>,
 ) -> Result<Vec<GenerateOutput>> {
-    let num_return_sequences = generation_config.get_num_return_sequences();
+    let num_return_sequences = generation_config.get_num_return_sequences().max(1);
     let mut input_ids = input_ids.repeat((num_return_sequences, 1))?;
     let mut output = input_ids.to_vec2::<u32>()?;
     let input_ids_dims = input_ids.dims2()?;
diff --git a/candle-holder-serve/src/router_macro.rs b/candle-holder-serve/src/router_macro.rs
index 3054196..64c078d 100644
--- a/candle-holder-serve/src/router_macro.rs
+++ b/candle-holder-serve/src/router_macro.rs
@@ -1,6 +1,6 @@
 #[macro_export]
 macro_rules! generate_router {
-    ($pipeline:ident, $request:ident, $response:ident, $process_fn:expr) => {
+    ($pipeline:ident, $request:ident, $response:ident, $process_fn:expr, $warm_up_fn:expr) => {
        use anyhow::Result;
        use axum::{routing::post, Router};
        use std::sync::Arc;
@@ -27,6 +27,11 @@ macro_rules! generate_router {
        let pipeline = Arc::new($pipeline::new(&args.model(), &args.device()?, dtype, None)?);

+       tracing::info!("Warming up the model...");
+       $warm_up_fn(&pipeline).unwrap_or_else(|e| {
+           tracing::error!("Failed to warm up the model: {}", e);
+       });
+
        let (tx, rx) = mpsc::channel::<($request, oneshot::Sender<Result<$response>>)>(32);
diff --git a/candle-holder-serve/src/routes/feature_extraction.rs b/candle-holder-serve/src/routes/feature_extraction.rs
index 6217314..525cf58 100644
--- a/candle-holder-serve/src/routes/feature_extraction.rs
+++ b/candle-holder-serve/src/routes/feature_extraction.rs
@@ -53,7 +53,8 @@ generate_router!(
     FeatureExtractionPipeline,
     FeatureExtractionInferenceRequest,
     FeatureExtractionInferenceResponse,
-    process_feature_extraction
+    process_feature_extraction,
+    feature_extraction_warm_up
 );

 pub(crate) fn process_feature_extraction(
@@ -83,3 +84,11 @@ pub(crate) fn process_feature_extraction(
         }
     }
 }
+
+pub(crate) fn feature_extraction_warm_up(pipeline: &FeatureExtractionPipeline) -> Result<()> {
+    pipeline.run(
+        "This is a sample text to warm up the feature extraction pipeline.",
+        None,
+    )?;
+    Ok(())
+}
diff --git a/candle-holder-serve/src/routes/fill_mask.rs b/candle-holder-serve/src/routes/fill_mask.rs
index 75abda8..63ec3da 100644
--- a/candle-holder-serve/src/routes/fill_mask.rs
+++ b/candle-holder-serve/src/routes/fill_mask.rs
@@ -47,7 +47,8 @@ generate_router!(
     FillMaskPipeline,
     FillMaskInferenceRequest,
     FillMaskInferenceResponse,
-    process_feature_extraction
+    process_feature_extraction,
+    fill_mask_warm_up
 );

 pub(crate) fn process_feature_extraction(
@@ -120,3 +121,8 @@ pub(crate) fn process_feature_extraction(
         }
     }
 }
+
+pub(crate) fn fill_mask_warm_up(pipeline: &FillMaskPipeline) -> Result<()> {
+    pipeline.run("Hello, my name is [MASK].", None)?;
+    Ok(())
+}
diff --git a/candle-holder-serve/src/routes/text_classification.rs b/candle-holder-serve/src/routes/text_classification.rs
index 89b3776..6c0563d 100644
--- a/candle-holder-serve/src/routes/text_classification.rs
+++ b/candle-holder-serve/src/routes/text_classification.rs
@@ -45,7 +45,8 @@ generate_router!(
     TextClassificationPipeline,
     TextClassificationInferenceRequest,
     TextClassificationInferenceResponse,
-    process_text_classification
+    process_text_classification,
+    text_classification_warm_up
 );

 pub(crate) fn process_text_classification(
@@ -84,3 +85,8 @@ pub(crate) fn process_text_classification(
         }
     }
 }
+
+pub(crate) fn text_classification_warm_up(pipeline: &TextClassificationPipeline) -> Result<()> {
+    pipeline.run("warm up", None)?;
+    Ok(())
+}
diff --git a/candle-holder-serve/src/routes/text_generation.rs b/candle-holder-serve/src/routes/text_generation.rs
index 7d8e6f0..9bc005b 100644
--- a/candle-holder-serve/src/routes/text_generation.rs
+++ b/candle-holder-serve/src/routes/text_generation.rs
@@ -65,7 +65,8 @@ generate_router!(
     TextGenerationPipeline,
     TextGenerationInferenceRequest,
     TextGenerationInferenceResponse,
-    process_text_generation
+    process_text_generation,
+    text_generation_warm_up
 );

 pub(crate) fn process_text_generation(
@@ -126,3 +127,8 @@ pub(crate) fn process_text_generation(
         }
     }
 }
+
+pub(crate) fn text_generation_warm_up(pipeline: &TextGenerationPipeline) -> Result<()> {
+    pipeline.run("warm up", None)?;
+    Ok(())
+}
diff --git a/candle-holder-serve/src/routes/token_classification.rs b/candle-holder-serve/src/routes/token_classification.rs
index 3db1e09..ef72258 100644
--- a/candle-holder-serve/src/routes/token_classification.rs
+++ b/candle-holder-serve/src/routes/token_classification.rs
@@ -54,7 +54,8 @@ generate_router!(
     TokenClassificationPipeline,
     TokenClassificationInferenceRequest,
     TokenClassificationInferenceResponse,
-    process_token_classification
+    process_token_classification,
+    token_classification_warm_up
 );

 pub(crate) fn process_token_classification(
@@ -127,3 +128,8 @@ pub(crate) fn process_token_classification(
         }
     }
 }
+
+pub(crate) fn token_classification_warm_up(pipeline: &TokenClassificationPipeline) -> Result<()> {
+    pipeline.run("warm up", None)?;
+    Ok(())
+}
diff --git a/candle-holder-serve/src/routes/zero_shot_classification.rs b/candle-holder-serve/src/routes/zero_shot_classification.rs
index e85046b..a8d6258 100644
--- a/candle-holder-serve/src/routes/zero_shot_classification.rs
+++ b/candle-holder-serve/src/routes/zero_shot_classification.rs
@@ -65,7 +65,8 @@ generate_router!(
     ZeroShotClassificationPipeline,
     ZeroShotClassificationInferenceRequest,
     ZeroShotClassificationInferenceResponse,
-    process_zero_shot_classification
+    process_zero_shot_classification,
+    zero_shot_classification_warm_up
 );

 pub(crate) fn process_zero_shot_classification(
@@ -121,3 +122,10 @@ pub(crate) fn process_zero_shot_classification(
         }
     }
 }
+
+pub(crate) fn zero_shot_classification_warm_up(
+    pipeline: &ZeroShotClassificationPipeline,
+) -> Result<()> {
+    pipeline.run("Hello, world!", vec!["world"], None)?;
+    Ok(())
+}
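
Usage note: with the extra `$warm_up_fn` argument, every route must now pass a warm-up
function as the fifth argument to `generate_router!`. Below is a minimal sketch of what a
new route would look like under this scheme; `MyPipeline`, the request/response types, and
`process_my_pipeline` are hypothetical placeholders, not names from the diff. The warm-up
runs one throwaway inference at startup so that weight loading and any lazy initialization
happen before the server accepts its first real request; the macro logs (but does not
abort on) a warm-up failure.

    generate_router!(
        MyPipeline,             // pipeline type constructed by the generated router fn
        MyInferenceRequest,     // request type deserialized from the POST body
        MyInferenceResponse,    // response type serialized back to the client
        process_my_pipeline,    // existing per-request processing function
        my_pipeline_warm_up     // new fifth argument: startup warm-up hook
    );

    // Mirrors the warm-up functions added in the diff: run the pipeline once on a
    // dummy input and discard the output, returning only success or failure.
    pub(crate) fn my_pipeline_warm_up(pipeline: &MyPipeline) -> Result<()> {
        pipeline.run("warm up", None)?;
        Ok(())
    }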