
Add routers
gabrielmbmb committed Sep 3, 2024
1 parent ab4221c commit 02789f7
Showing 10 changed files with 341 additions and 72 deletions.
20 changes: 20 additions & 0 deletions candle-holder-serve/Cargo.toml
@@ -14,5 +14,25 @@ clap = { workspace = true }
axum = { workspace = true }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
candle-holder = { path = "../candle-holder", version = "0.1.0" }
candle-holder-models = { path = "../candle-holder-models", version = "0.1.0", features = [
"tokenizers",
] }
candle-holder-tokenizers = { path = "../candle-holder-tokenizers", version = "0.1.0" }
candle-holder-pipelines = { path = "../candle-holder-pipelines", version = "0.1.0" }
candle-core = { workspace = true }
candle-nn = { workspace = true }
accelerate-src = { workspace = true, optional = true }
intel-mkl-src = { workspace = true, optional = true }

[features]
metal = ["candle-core/metal", "candle-holder-models/metal"]
cuda = ["candle-core/cuda", "candle-holder-models/cuda"]
cudnn = ["candle-core/cudnn", "candle-holder-models/cudnn"]
accelerate = ["dep:accelerate-src", "candle-holder-models/accelerate"]
mkl = ["dep:intel-mkl-src", "candle-holder-models/mkl"]
flash-attn = ["candle-holder-models/flash-attn"]
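Each of these features simply forwards to the matching backend feature in candle-core and candle-holder-models, so a single flag selects the acceleration backend for the whole server. As an example (assuming the crate is built from the workspace root and the package is named candle-holder-serve), a CUDA build would be invoked as `cargo build -p candle-holder-serve --features cuda`.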
94 changes: 94 additions & 0 deletions candle-holder-serve/src/cli.rs
@@ -0,0 +1,94 @@
use anyhow::{anyhow, Result};
use candle_core::Device;
use clap::Parser;
use serde::Serialize;
use std::str::FromStr;

#[derive(Debug, Parser)]
#[command(version, about, long_about = None)]
pub struct Cli {
/// The host to listen on.
#[arg(long, default_value = "0.0.0.0:3000")]
host: String,

/// The Hugging Face repository id of the model to be loaded.
#[arg(short, long)]
model: String,

/// The name of the pipeline to be served.
#[arg(short, long)]
pipeline: Pipeline,

/// The device to run the pipeline on.
#[arg(short, long, value_parser = parse_device, default_value = "cpu")]
device: DeviceOption,
}

impl Cli {
pub fn host(&self) -> &str {
&self.host
}

pub fn model(&self) -> &str {
&self.model
}

pub fn pipeline(&self) -> &Pipeline {
&self.pipeline
}

/// Get the [`candle_core::Device`] corresponding to the selected device option.
///
/// # Errors
///
/// Returns an error if the requested device is not available.
pub fn device(&self) -> Result<Device> {
match self.device {
DeviceOption::Cuda(device_id) if cfg!(feature = "cuda") => {
Ok(Device::new_cuda(device_id)?)
}
DeviceOption::Metal if cfg!(feature = "metal") => Ok(Device::new_metal(0)?),
DeviceOption::Cpu => Ok(Device::Cpu),
_ => Err(anyhow!("Requested device is not available")),
}
}
}

#[derive(Debug, Clone, Serialize, clap::ValueEnum)]
#[serde(rename_all = "kebab-case")]
pub enum Pipeline {
FeatureExtraction,
FillMask,
TextClassification,
TextGeneration,
TokenClassification,
ZeroShotClassification,
}

#[derive(Debug, Clone, clap::ValueEnum)]
pub enum DeviceOption {
Cpu,
Metal,
#[value(skip)]
Cuda(usize),
}

impl FromStr for DeviceOption {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"cpu" => Ok(DeviceOption::Cpu),
"metal" => Ok(DeviceOption::Metal),
s if s.starts_with("cuda:") => {
let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?;
Ok(DeviceOption::Cuda(id))
}
_ => Err(anyhow!("Invalid device option: {}", s)),
}
}
}

fn parse_device(s: &str) -> Result<DeviceOption, anyhow::Error> {
DeviceOption::from_str(s)
}
101 changes: 29 additions & 72 deletions candle-holder-serve/src/main.rs
@@ -1,81 +1,38 @@
mod cli;
mod routes;

use anyhow::{anyhow, Result};
use axum::{routing::get, Router};
use anyhow::Result;
use axum::Router;
use clap::Parser;
use serde::Serialize;
use std::str::FromStr;

#[derive(Debug, Parser)]
#[command(version, about, long_about = None)]
pub struct Cli {
/// The host to listen on.
#[arg(long, default_value = "0.0.0.0:3000")]
host: String,

/// The Hugging Face repository id of the model to be loaded.
#[arg(short, long)]
model: String,

/// The name of the pipeline to be served.
#[arg(short, long)]
pipeline: Pipeline,

/// The device to run the pipeline on.
#[arg(short, long, value_parser = parse_device, default_value = "cpu")]
device: DeviceOption,
}

#[derive(Debug, Parser, Clone, Serialize, clap::ValueEnum)]
#[serde(rename_all = "kebab-case")]
pub enum Pipeline {
FeatureExtraction,
FillMask,
TextClassification,
TextGeneration,
TokenClassification,
ZeroShotClassification,
}

#[derive(Debug, Clone, clap::ValueEnum)]
pub enum DeviceOption {
Cpu,
Metal,
#[value(skip)]
Cuda(usize),
}

impl FromStr for DeviceOption {
type Err = anyhow::Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"cpu" => Ok(DeviceOption::Cpu),
"metal" => Ok(DeviceOption::Metal),
s if s.starts_with("cuda:") => {
let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?;
Ok(DeviceOption::Cuda(id))
}
_ => Err(anyhow!("Invalid device option: {}", s)),
}
}
}

fn parse_device(s: &str) -> Result<DeviceOption, anyhow::Error> {
DeviceOption::from_str(s)
}
use crate::cli::{Cli, Pipeline};
use crate::routes::{
feature_extraction, fill_mask, text_classification, text_generation, token_classification,
zero_shot_classification,
};

#[tokio::main]
async fn main() {
let args = Cli::parse();

// Create a new router
let app = Router::new().route("/", get(root));
async fn main() -> Result<()> {
// Initialize tracing
tracing_subscriber::fmt::init();

let listener = tokio::net::TcpListener::bind(args.host).await.unwrap();
axum::serve(listener, app).await.unwrap();
}
// Parse the command line arguments
let args = Cli::parse();

async fn root() -> &'static str {
"Hello, World!"
// Initialize the router based on the pipeline
let inference_router = match args.pipeline() {
Pipeline::FeatureExtraction => feature_extraction::router(&args)?,
Pipeline::FillMask => fill_mask::router(&args)?,
Pipeline::TextClassification => text_classification::router(&args)?,
Pipeline::TextGeneration => text_generation::router(&args)?,
Pipeline::TokenClassification => token_classification::router(&args)?,
Pipeline::ZeroShotClassification => zero_shot_classification::router(&args)?,
};
let router = Router::new().nest("/", inference_router);

tracing::info!("Listening on {}", args.host());
let listener = tokio::net::TcpListener::bind(args.host()).await?;
axum::serve(listener, router).await?;

Ok(())
}
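One caveat on the router composition: axum 0.7 removed support for nesting at the root path, so Router::new().nest("/", inference_router) panics at runtime on that version, with merge as the recommended replacement. If the workspace pins axum 0.7 or newer, the equivalent line would be:

// Mounts the inference routes at the root without Router::nest("/", ...),
// which panics on axum >= 0.7.
let router = Router::new().merge(inference_router);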
32 changes: 32 additions & 0 deletions candle-holder-serve/src/routes/feature_extraction.rs
@@ -0,0 +1,32 @@
use anyhow::Result;
use axum::{routing::post, Router};
use candle_holder_pipelines::FeatureExtractionPipeline;
use std::sync::Arc;

use crate::cli::Cli;

pub fn router(args: &Cli) -> Result<Router> {
let model = args.model();
let device = args.device()?;

tracing::info!(
"Loading feature extraction pipeline for model '{}' on device {:?}",
model,
device
);

let pipeline = Arc::new(FeatureExtractionPipeline::new(
model,
&device,
None,
None,
)?);

Ok(Router::new()
.route("/", post(inference))
.with_state(pipeline))
}

async fn inference() -> &'static str {
"inference"
}
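All six routers currently answer with a placeholder string, but the Arc'd pipeline is already threaded through as router state, so a real handler only needs axum's State extractor plus JSON (de)serialization. A hypothetical sketch for this file — the request/response shapes and the pipeline's run method are assumptions, not the crate's actual API:

use axum::{extract::State, Json};
use serde::{Deserialize, Serialize};

// Hypothetical request/response types; the real shapes depend on the
// pipeline's input and output.
#[derive(Deserialize)]
struct InferenceRequest {
inputs: Vec<String>,
}

#[derive(Serialize)]
struct InferenceResponse {
embeddings: Vec<Vec<f32>>,
}

async fn inference(
State(pipeline): State<Arc<FeatureExtractionPipeline>>,
Json(request): Json<InferenceRequest>,
) -> Json<InferenceResponse> {
// `run` stands in for whatever inference entry point the pipeline
// exposes; error handling is elided for brevity.
let embeddings = pipeline.run(request.inputs).expect("inference failed");
Json(InferenceResponse { embeddings })
}

Since the pipeline call is compute-bound, wrapping it in tokio::task::spawn_blocking would keep the async executor responsive under load.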
32 changes: 32 additions & 0 deletions candle-holder-serve/src/routes/fill_mask.rs
@@ -0,0 +1,32 @@
use anyhow::Result;
use axum::{routing::post, Router};
use candle_holder_pipelines::FillMaskPipeline;
use std::sync::Arc;

use crate::cli::Cli;

pub fn router(args: &Cli) -> Result<Router> {
let model = args.model();
let device = args.device()?;

tracing::info!(
"Loading fill mask pipeline for model '{}' on device {:?}",
model,
device
);

let pipeline = Arc::new(FillMaskPipeline::new(
model,
&device,
None,
None,
)?);

Ok(Router::new()
.route("/", post(inference))
.with_state(pipeline))
}

async fn inference() -> &'static str {
"inference"
}
6 changes: 6 additions & 0 deletions candle-holder-serve/src/routes/mod.rs
@@ -0,0 +1,6 @@
pub mod feature_extraction;
pub mod fill_mask;
pub mod text_classification;
pub mod text_generation;
pub mod token_classification;
pub mod zero_shot_classification;
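The six route modules differ only in the pipeline type and the log message, so the per-module boilerplate is a natural candidate for a declarative macro. A sketch of the idea (hypothetical, not part of this commit):

// Hypothetical macro expanding to the router() function that each
// module currently writes by hand; an `inference` handler must exist
// in the module that invokes it.
macro_rules! pipeline_router {
($pipeline:ty, $name:literal) => {
pub fn router(args: &crate::cli::Cli) -> anyhow::Result<axum::Router> {
let device = args.device()?;
tracing::info!(
"Loading {} pipeline for model '{}' on device {:?}",
$name,
args.model(),
device
);
let pipeline = std::sync::Arc::new(<$pipeline>::new(
args.model(),
&device,
None,
None,
)?);
Ok(axum::Router::new()
.route("/", axum::routing::post(inference))
.with_state(pipeline))
}
};
}

// Usage in a module, e.g.:
// pipeline_router!(candle_holder_pipelines::FeatureExtractionPipeline, "feature extraction");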
32 changes: 32 additions & 0 deletions candle-holder-serve/src/routes/text_classification.rs
@@ -0,0 +1,32 @@
use anyhow::Result;
use axum::{routing::post, Router};
use candle_holder_pipelines::TextClassificationPipeline;
use std::sync::Arc;

use crate::cli::Cli;

pub fn router(args: &Cli) -> Result<Router> {
let model = args.model();
let device = args.device()?;

tracing::info!(
"Loading text classification pipeline for model '{}' on device {:?}",
model,
device
);

let pipeline = Arc::new(TextClassificationPipeline::new(
model,
&device,
None,
None,
)?);

Ok(Router::new()
.route("/", post(inference))
.with_state(pipeline))
}

async fn inference() -> &'static str {
"inference"
}
32 changes: 32 additions & 0 deletions candle-holder-serve/src/routes/text_generation.rs
@@ -0,0 +1,32 @@
use anyhow::Result;
use axum::{routing::post, Router};
use candle_holder_pipelines::TextGenerationPipeline;
use std::sync::Arc;

use crate::cli::Cli;

pub fn router(args: &Cli) -> Result<Router> {
let model = args.model();
let device = args.device()?;

tracing::info!(
"Loading text generation pipeline for model '{}' on device {:?}",
model,
device
);

let pipeline = Arc::new(TextGenerationPipeline::new(
model,
&device,
None,
None,
)?);

Ok(Router::new()
.route("/", post(inference))
.with_state(pipeline))
}

async fn inference() -> &'static str {
"inference"
}
