Skip to content

Commit

Permalink
Better api documents.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Jun 8, 2024
1 parent 37439ac commit d490858
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 35 deletions.
2 changes: 1 addition & 1 deletion crates/ai00-server/src/api/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct AuthResponse {
message: Option<String>,
}

/// Exchange `appkey` and `app_secret` with the authorization token.
/// Exchange `key` and `app_secret` with the authorization token.
#[endpoint(
responses(
(status_code = 200, description = "Exchange the token successfully.", body = AuthResponse),
Expand Down
6 changes: 3 additions & 3 deletions crates/ai00-server/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ pub mod file;
pub mod model;
pub mod oai;

pub use adapter::adapters;
pub use file::{dir, load_config, models, save_config, unzip};
pub use model::{info, load, load_state, save, state, unload};
// pub use adapter::adapters;
// pub use file::{dir, load_config, models, save_config, unzip};
// pub use model::{info, load, load_state, save, state, unload};

pub async fn try_request_info(sender: Sender<ThreadRequest>) -> Result<RuntimeInfo> {
let (info_sender, info_receiver) = flume::unbounded();
Expand Down
35 changes: 32 additions & 3 deletions crates/ai00-server/src/api/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct InitStateInfo {
name: String,
}

/// Report the current runtime info.
#[handler]
pub async fn info(depot: &mut Depot) -> Json<InfoResponse> {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
Expand All @@ -43,6 +44,8 @@ pub async fn info(depot: &mut Depot) -> Json<InfoResponse> {
})
}

/// Report the current runtime info every half second.
///
/// `/api/models/state`.
#[handler]
pub async fn state(depot: &mut Depot, res: &mut Response) {
Expand Down Expand Up @@ -76,8 +79,16 @@ pub async fn state(depot: &mut Depot, res: &mut Response) {
salvo::sse::stream(res, stream);
}

/// Load a runtime with models, LoRA, initial states, etc.
///
/// `/api/models/load`.
#[handler]
#[endpoint(
responses(
(status_code = 200, description = "Load the initial state successfully."),
(status_code = 404, description = "Cannot locate the file requested."),
(status_code = 500, description = "Server thread exited."),
)
)]
pub async fn load(depot: &mut Depot, req: &mut Request) -> StatusCode {
let ThreadState { sender, path } = depot.obtain::<ThreadState>().unwrap();
let (result_sender, result_receiver) = flume::unbounded();
Expand Down Expand Up @@ -111,6 +122,8 @@ pub async fn load(depot: &mut Depot, req: &mut Request) -> StatusCode {
}
}

/// Unload the current runtime.
///
/// `/api/models/unload`.
#[handler]
pub async fn unload(depot: &mut Depot) -> StatusCode {
Expand All @@ -120,8 +133,16 @@ pub async fn unload(depot: &mut Depot) -> StatusCode {
StatusCode::OK
}

/// Load an initial state from the path.
///
/// `/api/models/state/load`.
#[handler]
#[endpoint(
responses(
(status_code = 200, description = "Load the initial state successfully."),
(status_code = 404, description = "Cannot locate the file requested."),
(status_code = 500, description = "Server thread exited."),
)
)]
pub async fn load_state(depot: &mut Depot, req: &mut Request) -> StatusCode {
let ThreadState { sender, path } = depot.obtain::<ThreadState>().unwrap();
let (result_sender, result_receiver) = flume::unbounded();
Expand All @@ -142,8 +163,16 @@ pub async fn load_state(depot: &mut Depot, req: &mut Request) -> StatusCode {
}
}

/// Save the current model as a prefab.
///
/// `/api/models/save`.
#[handler]
#[endpoint(
responses(
(status_code = 200, description = "Save the model successfully."),
(status_code = 404, description = "Cannot locate the file requested."),
(status_code = 500, description = "Server thread exited."),
)
)]
pub async fn save(depot: &mut Depot, req: &mut Request) -> StatusCode {
let ThreadState { sender, path } = depot.obtain::<ThreadState>().unwrap();
let (result_sender, result_receiver) = flume::unbounded();
Expand Down
10 changes: 5 additions & 5 deletions crates/ai00-server/src/api/oai/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,11 @@ async fn respond_stream(depot: &mut Depot, request: ChatRequest, res: &mut Respo

/// Generate chat completions with context.
#[endpoint(
responses(
(status_code = 200, description = "Generate one response if `stream` is false.", body = ChatResponse),
(status_code = 201, description = "Generate SSE response if `stream` is true. `StatusCode` should be 200.", body = PartialChatResponse)
)
)]
responses(
(status_code = 200, description = "Generate one response if `stream` is false.", body = ChatResponse),
(status_code = 201, description = "Generate SSE response if `stream` is true. `StatusCode` should be 200.", body = PartialChatResponse)
)
)]
pub async fn chat_completions(depot: &mut Depot, req: JsonBody<ChatRequest>, res: &mut Response) {
let request = req.0;
match request.stream {
Expand Down
10 changes: 5 additions & 5 deletions crates/ai00-server/src/api/oai/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,11 @@ async fn respond_stream(depot: &mut Depot, request: CompletionRequest, res: &mut

/// Generate completions for the given text.
#[endpoint(
responses(
(status_code = 200, description = "Generate one response if `stream` is false.", body = CompletionResponse),
(status_code = 201, description = "Generate SSE response if `stream` is true. `StatusCode` should be 200.", body = PartialCompletionResponse)
)
)]
responses(
(status_code = 200, description = "Generate one response if `stream` is false.", body = CompletionResponse),
(status_code = 201, description = "Generate SSE response if `stream` is true. `StatusCode` should be 200.", body = PartialCompletionResponse)
)
)]
pub async fn completions(depot: &mut Depot, req: JsonBody<CompletionRequest>, res: &mut Response) {
let request = req.0;
match request.stream {
Expand Down
10 changes: 5 additions & 5 deletions crates/ai00-server/src/api/oai/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;

pub mod chat;
pub mod choose;
pub mod completion;
pub mod embedding;
pub mod info;
mod chat;
mod choose;
mod completion;
mod embedding;
mod info;

pub use chat::chat_completions;
pub use choose::chooses;
Expand Down
26 changes: 13 additions & 13 deletions crates/ai00-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,19 +199,19 @@ async fn main() {
.force_passed(listen.force_pass.unwrap_or_default());

let api_router = Router::with_hoop(auth_handler)
.push(Router::with_path("/adapters").get(api::adapters))
.push(Router::with_path("/models/info").get(api::info))
.push(Router::with_path("/models/save").post(api::save))
.push(Router::with_path("/models/load").post(api::load))
.push(Router::with_path("/models/unload").get(api::unload))
.push(Router::with_path("/models/state/load").post(api::load_state))
.push(Router::with_path("/models/state").get(api::state))
.push(Router::with_path("/models/list").get(api::models))
.push(Router::with_path("/files/unzip").post(api::unzip))
.push(Router::with_path("/files/dir").post(api::dir))
.push(Router::with_path("/files/ls").post(api::dir))
.push(Router::with_path("/files/config/load").post(api::load_config))
.push(Router::with_path("/files/config/save").post(api::save_config))
.push(Router::with_path("/adapters").get(api::adapter::adapters))
.push(Router::with_path("/models/info").get(api::model::info))
.push(Router::with_path("/models/save").post(api::model::save))
.push(Router::with_path("/models/load").post(api::model::load))
.push(Router::with_path("/models/unload").get(api::model::unload))
.push(Router::with_path("/models/state/load").post(api::model::load_state))
.push(Router::with_path("/models/state").get(api::model::state))
.push(Router::with_path("/models/list").get(api::file::models))
.push(Router::with_path("/files/unzip").post(api::file::unzip))
.push(Router::with_path("/files/dir").post(api::file::dir))
.push(Router::with_path("/files/ls").post(api::file::dir))
.push(Router::with_path("/files/config/load").post(api::file::load_config))
.push(Router::with_path("/files/config/save").post(api::file::save_config))
.push(Router::with_path("/oai/models").get(api::oai::models))
.push(Router::with_path("/oai/v1/models").get(api::oai::models))
.push(Router::with_path("/oai/completions").post(api::oai::completions))
Expand Down

0 comments on commit d490858

Please sign in to comment.