diff --git a/candle-holder-serve/Cargo.toml b/candle-holder-serve/Cargo.toml
index dc05995..2ce1679 100644
--- a/candle-holder-serve/Cargo.toml
+++ b/candle-holder-serve/Cargo.toml
@@ -14,5 +14,25 @@ clap = { workspace = true }
 axum = { workspace = true }
 tokio = { workspace = true }
 serde = { workspace = true }
+serde_json = { workspace = true }
+thiserror = { workspace = true }
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true }
+candle-holder = { path = "../candle-holder", version = "0.1.0" }
+candle-holder-models = { path = "../candle-holder-models", version = "0.1.0", features = [
+    "tokenizers",
+] }
+candle-holder-tokenizers = { path = "../candle-holder-tokenizers", version = "0.1.0" }
+candle-holder-pipelines = { path = "../candle-holder-pipelines", version = "0.1.0" }
+candle-core = { workspace = true }
+candle-nn = { workspace = true }
+accelerate-src = { workspace = true, optional = true }
+intel-mkl-src = { workspace = true, optional = true }
+
+[features]
+metal = ["candle-core/metal", "candle-holder-models/metal"]
+cuda = ["candle-core/cuda", "candle-holder-models/cuda"]
+cudnn = ["candle-core/cudnn", "candle-holder-models/cudnn"]
+accelerate = ["dep:accelerate-src", "candle-holder-models/accelerate"]
+mkl = ["dep:intel-mkl-src", "candle-holder-models/mkl"]
+flash-attn = ["candle-holder-models/flash-attn"]
diff --git a/candle-holder-serve/src/cli.rs b/candle-holder-serve/src/cli.rs
new file mode 100644
index 0000000..d118fc3
--- /dev/null
+++ b/candle-holder-serve/src/cli.rs
@@ -0,0 +1,94 @@
+use anyhow::{anyhow, Result};
+use candle_core::Device;
+use clap::Parser;
+use serde::Serialize;
+use std::str::FromStr;
+
+#[derive(Debug, Parser)]
+#[command(version, about, long_about = None)]
+pub struct Cli {
+    /// The host to listen on.
+    #[arg(long, default_value = "0.0.0.0:3000")]
+    host: String,
+
+    /// The Hugging Face repository id of the model to be loaded.
+    #[arg(short, long)]
+    model: String,
+
+    /// The name of the pipeline to be served.
+    #[arg(short, long)]
+    pipeline: Pipeline,
+
+    /// The device to run the pipeline on.
+    #[arg(short, long, value_parser = parse_device, default_value = "cpu")]
+    device: DeviceOption,
+}
+
+impl Cli {
+    pub fn host(&self) -> &str {
+        &self.host
+    }
+
+    pub fn model(&self) -> &str {
+        &self.model
+    }
+
+    pub fn pipeline(&self) -> &Pipeline {
+        &self.pipeline
+    }
+
+    /// Get the [`candle_core::Device`] corresponding to the selected device option.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the requested device is not available.
+    pub fn device(&self) -> Result<Device> {
+        match self.device {
+            DeviceOption::Cuda(device_id) if cfg!(feature = "cuda") => {
+                Ok(Device::new_cuda(device_id)?)
+            }
+            DeviceOption::Metal if cfg!(feature = "metal") => Ok(Device::new_metal(0)?),
+            DeviceOption::Cpu => Ok(Device::Cpu),
+            _ => Err(anyhow!("Requested device is not available")),
+        }
+    }
+}
+
+#[derive(Debug, Parser, Clone, Serialize, clap::ValueEnum)]
+#[serde(rename_all = "kebab-case")]
+pub enum Pipeline {
+    FeatureExtraction,
+    FillMask,
+    TextClassification,
+    TextGeneration,
+    TokenClassification,
+    ZeroShotClassification,
+}
+
+#[derive(Debug, Clone, clap::ValueEnum)]
+pub enum DeviceOption {
+    Cpu,
+    Metal,
+    #[value(skip)]
+    Cuda(usize),
+}
+
+impl FromStr for DeviceOption {
+    type Err = anyhow::Error;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "cpu" => Ok(DeviceOption::Cpu),
+            "metal" => Ok(DeviceOption::Metal),
+            s if s.starts_with("cuda:") => {
+                let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?;
+                Ok(DeviceOption::Cuda(id))
+            }
+            _ => Err(anyhow!("Invalid device option: {}", s)),
+        }
+    }
+}
+
+fn parse_device(s: &str) -> Result<DeviceOption> {
+    DeviceOption::from_str(s)
+}
diff --git a/candle-holder-serve/src/main.rs b/candle-holder-serve/src/main.rs
index 1d83b4c..6227033 100644
--- a/candle-holder-serve/src/main.rs
+++ b/candle-holder-serve/src/main.rs
@@ -1,81 +1,38 @@
+mod cli;
 mod routes;
 
-use anyhow::{anyhow, Result};
-use axum::{routing::get, Router};
+use anyhow::Result;
+use axum::Router;
 use clap::Parser;
-use serde::Serialize;
-use std::str::FromStr;
 
-#[derive(Debug, Parser)]
-#[command(version, about, long_about = None)]
-pub struct Cli {
-    /// The host to listen on.
-    #[arg(long, default_value = "0.0.0.0:3000")]
-    host: String,
-
-    /// The Hugging Face repository id of the model to be loaded.
-    #[arg(short, long)]
-    model: String,
-
-    /// The name of the pipeline to be served.
-    #[arg(short, long)]
-    pipeline: Pipeline,
-
-    /// The device to run the pipeline on.
-    #[arg(short, long, value_parser = parse_device, default_value = "cpu")]
-    device: DeviceOption,
-}
-
-#[derive(Debug, Parser, Clone, Serialize, clap::ValueEnum)]
-#[serde(rename_all = "kebab-case")]
-pub enum Pipeline {
-    FeatureExtraction,
-    FillMask,
-    TextClassification,
-    TextGeneration,
-    TokenClassification,
-    ZeroShotClassification,
-}
-
-#[derive(Debug, Clone, clap::ValueEnum)]
-pub enum DeviceOption {
-    Cpu,
-    Metal,
-    #[value(skip)]
-    Cuda(usize),
-}
-
-impl FromStr for DeviceOption {
-    type Err = anyhow::Error;
-
-    fn from_str(s: &str) -> Result<Self, Self::Err> {
-        match s {
-            "cpu" => Ok(DeviceOption::Cpu),
-            "metal" => Ok(DeviceOption::Metal),
-            s if s.starts_with("cuda:") => {
-                let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?;
-                Ok(DeviceOption::Cuda(id))
-            }
-            _ => Err(anyhow!("Invalid device option: {}", s)),
-        }
-    }
-}
-
-fn parse_device(s: &str) -> Result<DeviceOption> {
-    DeviceOption::from_str(s)
-}
+use crate::cli::{Cli, Pipeline};
+use crate::routes::{
+    feature_extraction, fill_mask, text_classification, text_generation, token_classification,
+    zero_shot_classification,
+};
 
 #[tokio::main]
-async fn main() {
-    let args = Cli::parse();
-
-    // Create a new router
-    let app = Router::new().route("/", get(root));
+async fn main() -> Result<()> {
+    // initialize tracing
+    tracing_subscriber::fmt::init();
 
-    let listener = tokio::net::TcpListener::bind(args.host).await.unwrap();
-    axum::serve(listener, app).await.unwrap();
-}
+    // Parse the command line arguments
+    let args = Cli::parse();
 
-async fn root() -> &'static str {
-    "Hello, World!"
+    // Initialize the router based on the pipeline
+    let inference_router = match args.pipeline() {
+        Pipeline::FeatureExtraction => feature_extraction::router(&args)?,
+        Pipeline::FillMask => fill_mask::router(&args)?,
+        Pipeline::TextClassification => text_classification::router(&args)?,
+        Pipeline::TextGeneration => text_generation::router(&args)?,
+        Pipeline::TokenClassification => token_classification::router(&args)?,
+        Pipeline::ZeroShotClassification => zero_shot_classification::router(&args)?,
+    };
+    let router = Router::new().nest("/", inference_router);
+
+    tracing::info!("Listening on {}", args.host());
+    let listener = tokio::net::TcpListener::bind(args.host()).await?;
+    axum::serve(listener, router).await?;
+
+    Ok(())
 }
diff --git a/candle-holder-serve/src/routes/feature_extraction.rs b/candle-holder-serve/src/routes/feature_extraction.rs
new file mode 100644
index 0000000..11b1d1c
--- /dev/null
+++ b/candle-holder-serve/src/routes/feature_extraction.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use axum::{routing::post, Router};
+use candle_holder_pipelines::FeatureExtractionPipeline;
+use std::sync::Arc;
+
+use crate::cli::Cli;
+
+pub fn router(args: &Cli) -> Result<Router> {
+    let model = args.model();
+    let device = args.device()?;
+
+    tracing::info!(
+        "Loading feature extraction pipeline for model '{}' on device {:?}",
+        model,
+        device
+    );
+
+    let pipeline = Arc::new(FeatureExtractionPipeline::new(
+        model,
+        &device,
+        None,
+        None,
+    )?);
+
+    Ok(Router::new()
+        .route("/", post(inference))
+        .with_state(pipeline))
+}
+
+async fn inference() -> &'static str {
+    "inference"
+}
diff --git a/candle-holder-serve/src/routes/fill_mask.rs b/candle-holder-serve/src/routes/fill_mask.rs
new file mode 100644
index 0000000..c534ab1
--- /dev/null
+++ b/candle-holder-serve/src/routes/fill_mask.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use axum::{routing::post, Router};
+use candle_holder_pipelines::FillMaskPipeline;
+use std::sync::Arc;
+
+use crate::cli::Cli;
+
+pub fn router(args: &Cli) -> Result<Router> {
+    let model = args.model();
+    let device = args.device()?;
+
+    tracing::info!(
+        "Loading fill mask pipeline for model '{}' on device {:?}",
+        model,
+        device
+    );
+
+    let pipeline = Arc::new(FillMaskPipeline::new(
+        model,
+        &device,
+        None,
+        None,
+    )?);
+
+    Ok(Router::new()
+        .route("/", post(inference))
+        .with_state(pipeline))
+}
+
+async fn inference() -> &'static str {
+    "inference"
+}
diff --git a/candle-holder-serve/src/routes/mod.rs b/candle-holder-serve/src/routes/mod.rs
new file mode 100644
index 0000000..ab90a8b
--- /dev/null
+++ b/candle-holder-serve/src/routes/mod.rs
@@ -0,0 +1,6 @@
+pub mod feature_extraction;
+pub mod fill_mask;
+pub mod text_classification;
+pub mod text_generation;
+pub mod token_classification;
+pub mod zero_shot_classification;
diff --git a/candle-holder-serve/src/routes/text_classification.rs b/candle-holder-serve/src/routes/text_classification.rs
new file mode 100644
index 0000000..225c4fe
--- /dev/null
+++ b/candle-holder-serve/src/routes/text_classification.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use axum::{routing::post, Router};
+use candle_holder_pipelines::TextClassificationPipeline;
+use std::sync::Arc;
+
+use crate::cli::Cli;
+
+pub fn router(args: &Cli) -> Result<Router> {
+    let model = args.model();
+    let device = args.device()?;
+
+    tracing::info!(
+        "Loading text classification pipeline for model '{}' on device {:?}",
+        model,
+        device
+    );
+
+    let pipeline = Arc::new(TextClassificationPipeline::new(
+        model,
+        &device,
+        None,
+        None,
+    )?);
+
+    Ok(Router::new()
+        .route("/", post(inference))
+        .with_state(pipeline))
+}
+
+async fn inference() -> &'static str {
+    "inference"
+}
diff --git a/candle-holder-serve/src/routes/text_generation.rs b/candle-holder-serve/src/routes/text_generation.rs
new file mode 100644
index 0000000..fdf8944
--- /dev/null
+++ b/candle-holder-serve/src/routes/text_generation.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use axum::{routing::post, Router};
+use candle_holder_pipelines::TextGenerationPipeline;
+use std::sync::Arc;
+
+use crate::cli::Cli;
+
+pub fn router(args: &Cli) -> Result<Router> {
+    let model = args.model();
+    let device = args.device()?;
+
+    tracing::info!(
+        "Loading text generation pipeline for model '{}' on device {:?}",
+        model,
+        device
+    );
+
+    let pipeline = Arc::new(TextGenerationPipeline::new(
+        model,
+        &device,
+        None,
+        None,
+    )?);
+
+    Ok(Router::new()
+        .route("/", post(inference))
+        .with_state(pipeline))
+}
+
+async fn inference() -> &'static str {
+    "inference"
+}
diff --git a/candle-holder-serve/src/routes/token_classification.rs b/candle-holder-serve/src/routes/token_classification.rs
new file mode 100644
index 0000000..78aeeb5
--- /dev/null
+++ b/candle-holder-serve/src/routes/token_classification.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use axum::{routing::post, Router};
+use candle_holder_pipelines::TokenClassificationPipeline;
+use std::sync::Arc;
+
+use crate::cli::Cli;
+
+pub fn router(args: &Cli) -> Result<Router> {
+    let model = args.model();
+    let device = args.device()?;
+
+    tracing::info!(
+        "Loading token classification pipeline for model '{}' on device {:?}",
+        model,
+        device
+    );
+
+    let pipeline = Arc::new(TokenClassificationPipeline::new(
+        model,
+        &device,
+        None,
+        None,
+    )?);
+
+    Ok(Router::new()
+        .route("/", post(inference))
+        .with_state(pipeline))
+}
+
+async fn inference() -> &'static str {
+    "inference"
+}
diff --git a/candle-holder-serve/src/routes/zero_shot_classification.rs b/candle-holder-serve/src/routes/zero_shot_classification.rs
new file mode 100644
index 0000000..ee66bf4
--- /dev/null
+++ b/candle-holder-serve/src/routes/zero_shot_classification.rs
@@ -0,0 +1,32 @@
+use anyhow::Result;
+use axum::{routing::post, Router};
+use candle_holder_pipelines::ZeroShotClassificationPipeline;
+use std::sync::Arc;
+
+use crate::cli::Cli;
+
+pub fn router(args: &Cli) -> Result<Router> {
+    let model = args.model();
+    let device = args.device()?;
+
+    tracing::info!(
+        "Loading zero shot classification pipeline for model '{}' on device {:?}",
+        model,
+        device
+    );
+
+    let pipeline = Arc::new(ZeroShotClassificationPipeline::new(
+        model,
+        &device,
+        None,
+        None,
+    )?);
+
+    Ok(Router::new()
+        .route("/", post(inference))
+        .with_state(pipeline))
+}
+
+async fn inference() -> &'static str {
+    "inference"
+}
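
Note (not part of the diff): the `inference` handlers above are placeholders that return a static string. A minimal sketch of how one of them could evolve, assuming hypothetical `TextClassificationRequest`/`TextClassificationResponse` JSON shapes and leaving the actual pipeline call out because the `candle-holder-pipelines` inference API is not shown in this change:

    use std::sync::Arc;

    use axum::{extract::State, http::StatusCode, Json};
    use candle_holder_pipelines::TextClassificationPipeline;
    use serde::{Deserialize, Serialize};

    /// Hypothetical request body for the text classification route.
    #[derive(Debug, Deserialize)]
    struct TextClassificationRequest {
        inputs: Vec<String>,
    }

    /// Hypothetical response body; field names are illustrative only.
    #[derive(Debug, Serialize)]
    struct TextClassificationResponse {
        labels: Vec<String>,
    }

    /// One way the `inference` stub could go: pull the shared pipeline out of
    /// the router state and exchange JSON instead of a static string.
    async fn inference(
        State(_pipeline): State<Arc<TextClassificationPipeline>>,
        Json(_request): Json<TextClassificationRequest>,
    ) -> Result<Json<TextClassificationResponse>, StatusCode> {
        // The actual inference call is omitted because the exact
        // `candle-holder-pipelines` method and signature are not part of this
        // diff; plugging it in (ideally via `tokio::task::spawn_blocking`,
        // since inference is synchronous and compute-bound) is the missing piece.
        Err(StatusCode::NOT_IMPLEMENTED)
    }

Because the state type stays `Arc<TextClassificationPipeline>`, a handler with this signature would drop into the existing `.route("/", post(inference)).with_state(pipeline)` wiring unchanged.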