Skip to content

Commit

Permalink
Add model warmup
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Sep 8, 2024
1 parent 5c2f496 commit 5c41161
Show file tree
Hide file tree
Showing 8 changed files with 54 additions and 8 deletions.
2 changes: 1 addition & 1 deletion candle-holder-models/src/generation/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ pub fn generate<'a, M: PreTrainedModel + ?Sized>(
mut token_streamer: Option<Box<dyn TokenStreamer<'a> + 'a>>,
seed: Option<u64>,
) -> Result<Vec<GenerateOutput>> {
let num_return_sequences = generation_config.get_num_return_sequences();
let num_return_sequences = generation_config.get_num_return_sequences().max(1);
let mut input_ids = input_ids.repeat((num_return_sequences, 1))?;
let mut output = input_ids.to_vec2::<u32>()?;
let input_ids_dims = input_ids.dims2()?;
Expand Down
7 changes: 6 additions & 1 deletion candle-holder-serve/src/router_macro.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#[macro_export]
macro_rules! generate_router {
($pipeline:ident, $request:ident, $response:ident, $process_fn:expr) => {
($pipeline:ident, $request:ident, $response:ident, $process_fn:expr, $warm_up_fn:expr) => {
use anyhow::Result;
use axum::{routing::post, Router};
use std::sync::Arc;
Expand All @@ -27,6 +27,11 @@ macro_rules! generate_router {

let pipeline = Arc::new($pipeline::new(&args.model(), &args.device()?, dtype, None)?);

tracing::info!("Warming up the model...");
$warm_up_fn(&pipeline).unwrap_or_else(|e| {
tracing::error!("Failed to warm up the model: {}", e);
});

let (tx, rx) =
mpsc::channel::<InferenceTask<$request, Result<$response, ErrorResponse>>>(32);

Expand Down
11 changes: 10 additions & 1 deletion candle-holder-serve/src/routes/feature_extraction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ generate_router!(
FeatureExtractionPipeline,
FeatureExtractionInferenceRequest,
FeatureExtractionInferenceResponse,
process_feature_extraction
process_feature_extraction,
feature_extraction_warm_up
);

pub(crate) fn process_feature_extraction(
Expand Down Expand Up @@ -83,3 +84,11 @@ pub(crate) fn process_feature_extraction(
}
}
}

/// Runs one throwaway inference so the feature-extraction model is fully
/// initialized (weights loaded, kernels compiled) before the server starts
/// accepting requests. The output is discarded; only the side effect matters.
pub(crate) fn feature_extraction_warm_up(pipeline: &FeatureExtractionPipeline) -> Result<()> {
    const WARM_UP_TEXT: &str =
        "This is a sample text to warm up the feature extraction pipeline.";
    pipeline.run(WARM_UP_TEXT, None)?;
    Ok(())
}
8 changes: 7 additions & 1 deletion candle-holder-serve/src/routes/fill_mask.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ generate_router!(
FillMaskPipeline,
FillMaskInferenceRequest,
FillMaskInferenceResponse,
process_feature_extraction
process_feature_extraction,
fill_mask_warm_up
);

pub(crate) fn process_feature_extraction(
Expand Down Expand Up @@ -120,3 +121,8 @@ pub(crate) fn process_feature_extraction(
}
}
}

/// Warms up the fill-mask pipeline with a single sentence containing a
/// `[MASK]` token, forcing one full forward pass before traffic is served.
/// The prediction itself is discarded.
pub(crate) fn fill_mask_warm_up(pipeline: &FillMaskPipeline) -> Result<()> {
    const WARM_UP_TEXT: &str = "Hello, my name is [MASK].";
    pipeline.run(WARM_UP_TEXT, None)?;
    Ok(())
}
8 changes: 7 additions & 1 deletion candle-holder-serve/src/routes/text_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ generate_router!(
TextClassificationPipeline,
TextClassificationInferenceRequest,
TextClassificationInferenceResponse,
process_text_classification
process_text_classification,
text_classification_warm_up
);

pub(crate) fn process_text_classification(
Expand Down Expand Up @@ -84,3 +85,8 @@ pub(crate) fn process_text_classification(
}
}
}

/// Performs a single dummy classification to initialize the model before the
/// server begins handling requests; the result is intentionally ignored.
pub(crate) fn text_classification_warm_up(pipeline: &TextClassificationPipeline) -> Result<()> {
    const WARM_UP_TEXT: &str = "warm up";
    pipeline.run(WARM_UP_TEXT, None)?;
    Ok(())
}
8 changes: 7 additions & 1 deletion candle-holder-serve/src/routes/text_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ generate_router!(
TextGenerationPipeline,
TextGenerationInferenceRequest,
TextGenerationInferenceResponse,
process_text_generation
process_text_generation,
text_generation_warm_up
);

pub(crate) fn process_text_generation(
Expand Down Expand Up @@ -126,3 +127,8 @@ pub(crate) fn process_text_generation(
}
}
}

/// Runs one generation pass on a short prompt so the model is initialized
/// before serving. Output is discarded.
/// NOTE(review): `None` generation params means the pipeline's defaults decide
/// how long this warm-up generates — confirm the default is short.
pub(crate) fn text_generation_warm_up(pipeline: &TextGenerationPipeline) -> Result<()> {
    const WARM_UP_PROMPT: &str = "warm up";
    pipeline.run(WARM_UP_PROMPT, None)?;
    Ok(())
}
8 changes: 7 additions & 1 deletion candle-holder-serve/src/routes/token_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ generate_router!(
TokenClassificationPipeline,
TokenClassificationInferenceRequest,
TokenClassificationInferenceResponse,
process_token_classification
process_token_classification,
token_classification_warm_up
);

pub(crate) fn process_token_classification(
Expand Down Expand Up @@ -127,3 +128,8 @@ pub(crate) fn process_token_classification(
}
}
}

/// Issues a single dummy token-classification request to initialize the model
/// ahead of real traffic; the predictions are thrown away.
pub(crate) fn token_classification_warm_up(pipeline: &TokenClassificationPipeline) -> Result<()> {
    const WARM_UP_TEXT: &str = "warm up";
    pipeline.run(WARM_UP_TEXT, None)?;
    Ok(())
}
10 changes: 9 additions & 1 deletion candle-holder-serve/src/routes/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ generate_router!(
ZeroShotClassificationPipeline,
ZeroShotClassificationInferenceRequest,
ZeroShotClassificationInferenceResponse,
process_zero_shot_classification
process_zero_shot_classification,
zero_shot_classification_warm_up
);

pub(crate) fn process_zero_shot_classification(
Expand Down Expand Up @@ -121,3 +122,10 @@ pub(crate) fn process_zero_shot_classification(
}
}
}

/// Warms up the zero-shot pipeline with one inference against a single
/// candidate label, then discards the scores.
pub(crate) fn zero_shot_classification_warm_up(
    pipeline: &ZeroShotClassificationPipeline,
) -> Result<()> {
    const WARM_UP_TEXT: &str = "Hello, world!";
    let candidate_labels = vec!["world"];
    pipeline.run(WARM_UP_TEXT, candidate_labels, None)?;
    Ok(())
}

0 comments on commit 5c41161

Please sign in to comment.