-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ab4221c
commit 02789f7
Showing
10 changed files
with
341 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
use anyhow::{anyhow, Result}; | ||
use candle_core::Device; | ||
use clap::Parser; | ||
use serde::Serialize; | ||
use std::str::FromStr; | ||
|
||
#[derive(Debug, Parser)] | ||
#[command(version, about, long_about = None)] | ||
pub struct Cli { | ||
/// The host to listen on. | ||
#[arg(long, default_value = "0.0.0.0:3000")] | ||
host: String, | ||
|
||
/// The Hugging Face repository id of the model to be loaded. | ||
#[arg(short, long)] | ||
model: String, | ||
|
||
/// The name of the pipeline to be served. | ||
#[arg(short, long)] | ||
pipeline: Pipeline, | ||
|
||
/// The device to run the pipeline on. | ||
#[arg(short, long, value_parser = parse_device, default_value = "cpu")] | ||
device: DeviceOption, | ||
} | ||
|
||
impl Cli { | ||
pub fn host(&self) -> &str { | ||
&self.host | ||
} | ||
|
||
pub fn model(&self) -> &str { | ||
&self.model | ||
} | ||
|
||
pub fn pipeline(&self) -> &Pipeline { | ||
&self.pipeline | ||
} | ||
|
||
/// Get the [`candle_core::Device`] corresponding to the selected device option. | ||
/// | ||
/// # Errors | ||
/// | ||
/// Returns an error if the requested device is not available. | ||
pub fn device(&self) -> Result<Device> { | ||
match self.device { | ||
DeviceOption::Cuda(device_id) if cfg!(feature = "cuda") => { | ||
Ok(Device::new_cuda(device_id)?) | ||
} | ||
DeviceOption::Metal if cfg!(feature = "metal") => Ok(Device::new_metal(0)?), | ||
DeviceOption::Cpu => Ok(Device::Cpu), | ||
_ => Err(anyhow!("Requested device is not available")), | ||
} | ||
} | ||
} | ||
|
||
#[derive(Debug, Parser, Clone, Serialize, clap::ValueEnum)] | ||
#[serde(rename_all = "kebab-case")] | ||
pub enum Pipeline { | ||
FeatureExtraction, | ||
FillMask, | ||
TextClassification, | ||
TextGeneration, | ||
TokenClassification, | ||
ZeroShotClassification, | ||
} | ||
|
||
#[derive(Debug, Clone, clap::ValueEnum)] | ||
pub enum DeviceOption { | ||
Cpu, | ||
Metal, | ||
#[value(skip)] | ||
Cuda(usize), | ||
} | ||
|
||
impl FromStr for DeviceOption { | ||
type Err = anyhow::Error; | ||
|
||
fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
match s { | ||
"cpu" => Ok(DeviceOption::Cpu), | ||
"metal" => Ok(DeviceOption::Metal), | ||
s if s.starts_with("cuda:") => { | ||
let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?; | ||
Ok(DeviceOption::Cuda(id)) | ||
} | ||
_ => Err(anyhow!("Invalid device option: {}", s)), | ||
} | ||
} | ||
} | ||
|
||
fn parse_device(s: &str) -> Result<DeviceOption, anyhow::Error> { | ||
DeviceOption::from_str(s) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,81 +1,38 @@ | ||
mod cli; | ||
mod routes; | ||
|
||
use anyhow::{anyhow, Result}; | ||
use axum::{routing::get, Router}; | ||
use anyhow::Result; | ||
use axum::Router; | ||
use clap::Parser; | ||
use serde::Serialize; | ||
use std::str::FromStr; | ||
|
||
#[derive(Debug, Parser)] | ||
#[command(version, about, long_about = None)] | ||
pub struct Cli { | ||
/// The host to listen on. | ||
#[arg(long, default_value = "0.0.0.0:3000")] | ||
host: String, | ||
|
||
/// The Hugging Face repository id of the model to be loaded. | ||
#[arg(short, long)] | ||
model: String, | ||
|
||
/// The name of the pipeline to be served. | ||
#[arg(short, long)] | ||
pipeline: Pipeline, | ||
|
||
/// The device to run the pipeline on. | ||
#[arg(short, long, value_parser = parse_device, default_value = "cpu")] | ||
device: DeviceOption, | ||
} | ||
|
||
#[derive(Debug, Parser, Clone, Serialize, clap::ValueEnum)] | ||
#[serde(rename_all = "kebab-case")] | ||
pub enum Pipeline { | ||
FeatureExtraction, | ||
FillMask, | ||
TextClassification, | ||
TextGeneration, | ||
TokenClassification, | ||
ZeroShotClassification, | ||
} | ||
|
||
#[derive(Debug, Clone, clap::ValueEnum)] | ||
pub enum DeviceOption { | ||
Cpu, | ||
Metal, | ||
#[value(skip)] | ||
Cuda(usize), | ||
} | ||
|
||
impl FromStr for DeviceOption { | ||
type Err = anyhow::Error; | ||
|
||
fn from_str(s: &str) -> Result<Self, Self::Err> { | ||
match s { | ||
"cpu" => Ok(DeviceOption::Cpu), | ||
"metal" => Ok(DeviceOption::Metal), | ||
s if s.starts_with("cuda:") => { | ||
let id = s.strip_prefix("cuda:").unwrap().parse::<usize>()?; | ||
Ok(DeviceOption::Cuda(id)) | ||
} | ||
_ => Err(anyhow!("Invalid device option: {}", s)), | ||
} | ||
} | ||
} | ||
|
||
fn parse_device(s: &str) -> Result<DeviceOption, anyhow::Error> { | ||
DeviceOption::from_str(s) | ||
} | ||
use crate::cli::{Cli, Pipeline}; | ||
use crate::routes::{ | ||
feature_extraction, fill_mask, text_classification, text_generation, token_classification, | ||
zero_shot_classification, | ||
}; | ||
|
||
#[tokio::main] | ||
async fn main() { | ||
let args = Cli::parse(); | ||
|
||
// Create a new router | ||
let app = Router::new().route("/", get(root)); | ||
async fn main() -> Result<()> { | ||
// initialize tracing | ||
tracing_subscriber::fmt::init(); | ||
|
||
let listener = tokio::net::TcpListener::bind(args.host).await.unwrap(); | ||
axum::serve(listener, app).await.unwrap(); | ||
} | ||
// Parse the command line arguments | ||
let args = Cli::parse(); | ||
|
||
async fn root() -> &'static str { | ||
"Hello, World!" | ||
// Initialize the router based on the pipeline | ||
let inference_router = match args.pipeline() { | ||
Pipeline::FeatureExtraction => feature_extraction::router(&args)?, | ||
Pipeline::FillMask => fill_mask::router(&args)?, | ||
Pipeline::TextClassification => text_classification::router(&args)?, | ||
Pipeline::TextGeneration => text_generation::router(&args)?, | ||
Pipeline::TokenClassification => token_classification::router(&args)?, | ||
Pipeline::ZeroShotClassification => zero_shot_classification::router(&args)?, | ||
}; | ||
let router = Router::new().nest("/", inference_router); | ||
|
||
tracing::info!("Listening on {}", args.host()); | ||
let listener = tokio::net::TcpListener::bind(args.host()).await.unwrap(); | ||
axum::serve(listener, router).await.unwrap(); | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use anyhow::Result; | ||
use axum::{routing::post, Router}; | ||
use candle_holder_pipelines::FeatureExtractionPipeline; | ||
use std::sync::Arc; | ||
|
||
use crate::cli::Cli; | ||
|
||
pub fn router(args: &Cli) -> Result<Router> { | ||
let model = args.model(); | ||
let device = args.device()?; | ||
|
||
tracing::info!( | ||
"Loading feature extraction pipeline for model '{}' on device {:?}", | ||
model, | ||
device | ||
); | ||
|
||
let pipeline = Arc::new(FeatureExtractionPipeline::new( | ||
&args.model(), | ||
&args.device()?, | ||
None, | ||
None, | ||
)?); | ||
|
||
Ok(Router::new() | ||
.route("/", post(inference)) | ||
.with_state(pipeline)) | ||
} | ||
|
||
/// Placeholder handler for `POST /`; it always responds with a fixed body.
async fn inference() -> &'static str {
    const PLACEHOLDER_BODY: &str = "inference";
    PLACEHOLDER_BODY
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use anyhow::Result; | ||
use axum::{routing::post, Router}; | ||
use candle_holder_pipelines::FillMaskPipeline; | ||
use std::sync::Arc; | ||
|
||
use crate::cli::Cli; | ||
|
||
pub fn router(args: &Cli) -> Result<Router> { | ||
let model = args.model(); | ||
let device = args.device()?; | ||
|
||
tracing::info!( | ||
"Loading fill mask pipeline for model '{}' on device {:?}", | ||
model, | ||
device | ||
); | ||
|
||
let pipeline = Arc::new(FillMaskPipeline::new( | ||
&args.model(), | ||
&args.device()?, | ||
None, | ||
None, | ||
)?); | ||
|
||
Ok(Router::new() | ||
.route("/", post(inference)) | ||
.with_state(pipeline)) | ||
} | ||
|
||
/// Placeholder handler for `POST /`; it always responds with a fixed body.
async fn inference() -> &'static str {
    const PLACEHOLDER_BODY: &str = "inference";
    PLACEHOLDER_BODY
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
pub mod feature_extraction; | ||
pub mod fill_mask; | ||
pub mod text_classification; | ||
pub mod text_generation; | ||
pub mod token_classification; | ||
pub mod zero_shot_classification; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use anyhow::Result; | ||
use axum::{routing::post, Router}; | ||
use candle_holder_pipelines::TextClassificationPipeline; | ||
use std::sync::Arc; | ||
|
||
use crate::cli::Cli; | ||
|
||
pub fn router(args: &Cli) -> Result<Router> { | ||
let model = args.model(); | ||
let device = args.device()?; | ||
|
||
tracing::info!( | ||
"Loading text classification pipeline for model '{}' on device {:?}", | ||
model, | ||
device | ||
); | ||
|
||
let pipeline = Arc::new(TextClassificationPipeline::new( | ||
&args.model(), | ||
&args.device()?, | ||
None, | ||
None, | ||
)?); | ||
|
||
Ok(Router::new() | ||
.route("/", post(inference)) | ||
.with_state(pipeline)) | ||
} | ||
|
||
/// Placeholder handler for `POST /`; it always responds with a fixed body.
async fn inference() -> &'static str {
    const PLACEHOLDER_BODY: &str = "inference";
    PLACEHOLDER_BODY
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use anyhow::Result; | ||
use axum::{routing::post, Router}; | ||
use candle_holder_pipelines::TextGenerationPipeline; | ||
use std::sync::Arc; | ||
|
||
use crate::cli::Cli; | ||
|
||
pub fn router(args: &Cli) -> Result<Router> { | ||
let model = args.model(); | ||
let device = args.device()?; | ||
|
||
tracing::info!( | ||
"Loading text generation pipeline for model '{}' on device {:?}", | ||
model, | ||
device | ||
); | ||
|
||
let pipeline = Arc::new(TextGenerationPipeline::new( | ||
&args.model(), | ||
&args.device()?, | ||
None, | ||
None, | ||
)?); | ||
|
||
Ok(Router::new() | ||
.route("/", post(inference)) | ||
.with_state(pipeline)) | ||
} | ||
|
||
/// Placeholder handler for `POST /`; it always responds with a fixed body.
async fn inference() -> &'static str {
    const PLACEHOLDER_BODY: &str = "inference";
    PLACEHOLDER_BODY
}
Oops, something went wrong.