From 0f0e1db2d47fb3445c4fd35ba4688869f53ffaca Mon Sep 17 00:00:00 2001
From: cryscan
Date: Tue, 25 Jul 2023 14:05:57 +0800
Subject: [PATCH] Add `/models` API.

---
NOTE(review): this patch was reconstructed from a whitespace-mangled copy.
Line structure was restored from the diff markers, and the generic type
parameters stripped by the mangling (`Vec<ModelChoice>`, `State<ThreadState>`,
`Json<ModelResponse>`) were restored so the new `src/models.rs` compiles.
The author's email (also angle-bracket-delimited) was lost — TODO confirm.
Blank-line placement inside README hunk context is a best-effort guess —
verify against the target files before `git am`.

 README.md     |  2 ++
 README_jp.md  |  2 ++
 README_zh.md  |  2 ++
 src/main.rs   |  8 +++++++-
 src/models.rs | 26 ++++++++++++++++++++++++++
 5 files changed, 39 insertions(+), 1 deletion(-)
 create mode 100644 src/models.rs

diff --git a/README.md b/README.md
index 3c6c94d3..18229783 100644
--- a/README.md
+++ b/README.md
@@ -98,6 +98,8 @@ QQ Group for communication: 30920262
 
 The API service starts at port 3000, and the data input and output format follow the Openai API specification.
 
+* `/v1/models`
+* `/models`
 * `/v1/chat/completions`
 * `/chat/completions`
 * `/v1/completions`
diff --git a/README_jp.md b/README_jp.md
index f7b1844f..0179c1ee 100644
--- a/README_jp.md
+++ b/README_jp.md
@@ -93,6 +93,8 @@ OpenAIのChatGPT APIインターフェースと互換性があります。
 
 APIサービスは3000ポートで開始され、データ入力と出力の形式はOpenai APIの規格に従います。
 
+* `/v1/models`
+* `/models`
 * `/v1/chat/completions`
 * `/chat/completions`
 * `/v1/completions`
diff --git a/README_zh.md b/README_zh.md
index 31aafce5..f5924282 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -104,6 +104,8 @@
 API 服务开启于 3000 端口, 数据输入已经输出格式遵循Openai
 API 规范。
 
+- `/v1/models`
+- `/models`
 - `/v1/chat/completions`
 - `/chat/completions`
 - `/v1/completions`
diff --git a/src/main.rs b/src/main.rs
index 873e2fa9..43b6cb95 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,5 +1,8 @@
 use anyhow::Result;
-use axum::{routing::post, Router};
+use axum::{
+    routing::{get, post},
+    Router,
+};
 use clap::Parser;
 use flume::Receiver;
 use itertools::Itertools;
@@ -21,6 +24,7 @@ use web_rwkv::{BackedModelState, Environment, Model, Tokenizer};
 mod chat;
 mod completion;
 mod embedding;
+mod models;
 mod sampler;
 
 use crate::{
@@ -339,6 +343,8 @@ async fn main() -> Result<()> {
     std::thread::spawn(move || model_task(env, model_path, tokenizer_path, receiver));
 
     let app = Router::new()
+        .route("/models", get(models::models))
+        .route("/v1/models", get(models::models))
         .route("/completions", post(completion::completions))
         .route("/v1/completions", post(completion::completions))
         .route("/chat/completions", post(chat::chat_completions))
diff --git a/src/models.rs b/src/models.rs
new file mode 100644
index 00000000..ff0b7881
--- /dev/null
+++ b/src/models.rs
@@ -0,0 +1,26 @@
+use axum::{extract::State, Json};
+use serde::Serialize;
+
+use crate::ThreadState;
+
+#[derive(Debug, Serialize)]
+pub struct ModelChoice {
+    pub object: String,
+    pub id: String,
+}
+
+#[derive(Debug, Serialize)]
+pub struct ModelResponse {
+    pub data: Vec<ModelChoice>,
+}
+
+pub async fn models(
+    State(ThreadState { model_name, .. }): State<ThreadState>,
+) -> Json<ModelResponse> {
+    Json(ModelResponse {
+        data: vec![ModelChoice {
+            object: "models".into(),
+            id: model_name,
+        }],
+    })
+}