Skip to content

Commit

Permalink
Use normal tokio tasks for core functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed May 21, 2024
1 parent aafcb43 commit 3228cc7
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 26 deletions.
3 changes: 1 addition & 2 deletions crates/ai00-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,15 +442,14 @@ async fn load_runtime(
Ok(runtime)
}

#[tokio::main]
pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
let env: Arc<RwLock<Environment>> = Default::default();
let queue: Arc<Mutex<Vec<GenerateContext>>> = Default::default();

let sender = {
let (sender, receiver) = flume::unbounded();
let env = env.clone();
tokio::task::spawn_blocking(move || crate::run::run(receiver, env));
tokio::task::spawn(crate::run::run(receiver, env));
sender
};

Expand Down
6 changes: 1 addition & 5 deletions crates/ai00-core/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,7 @@ impl Environment {
match runtime.queue(context).await.expect("queue task error") {
SlotResult::Success(batch) => log::info!("queued task at slot {batch}"),
SlotResult::Fault(batch) => log::info!("swapped task at slot {batch}"),
SlotResult::Failure(context) => {
log::info!("failed to queue task");
queue.push(*context);
}
SlotResult::Failure(context) => queue.push(*context),
SlotResult::Error(reason) => log::warn!("queue task failed: {}", reason),
}
}
Expand Down Expand Up @@ -973,7 +970,6 @@ impl Runtime {
}
}

#[tokio::main]
pub async fn run(receiver: Receiver<()>, env: Arc<RwLock<Environment>>) {
{
// this task constantly runs, cleaning up state cache
Expand Down
8 changes: 3 additions & 5 deletions crates/ai00-server/src/api/model.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::time::Duration;

use ai00_core::{
run::{InitState, StateId},
ReloadRequest, RuntimeInfo, SaveRequest, ThreadRequest,
Expand All @@ -10,7 +8,7 @@ use serde::Serialize;
use web_rwkv::runtime::model::ModelInfo;

use super::*;
use crate::{build_path, types::ThreadState};
use crate::{build_path, types::ThreadState, SLEEP};

#[derive(Debug, Clone, Serialize)]
pub struct InfoResponse {
Expand All @@ -33,7 +31,7 @@ pub async fn info(depot: &mut Depot) -> Json<InfoResponse> {
model,
states,
..
} = request_info(sender.to_owned(), Duration::from_millis(500)).await;
} = request_info(sender.to_owned(), SLEEP).await;
let states = states
.into_iter()
.map(|(id, InitState { name, .. })| InitStateInfo { id, name })
Expand All @@ -50,7 +48,7 @@ pub async fn info(depot: &mut Depot) -> Json<InfoResponse> {
pub async fn state(depot: &mut Depot, res: &mut Response) {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let (info_sender, info_receiver) = flume::unbounded();
let task = request_info_stream(sender.to_owned(), info_sender, Duration::from_millis(500));
let task = request_info_stream(sender.to_owned(), info_sender, SLEEP);
tokio::task::spawn(task);

let stream = info_receiver.into_stream().map(
Expand Down
7 changes: 4 additions & 3 deletions crates/ai00-server/src/api/oai/chat.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc};

use ai00_core::{
run::StateId, sampler::Sampler, FinishReason, GenerateRequest, ThreadRequest, Token,
Expand All @@ -15,6 +15,7 @@ use super::*;
use crate::{
api::request_info,
types::{Array, ThreadState},
SLEEP,
};

#[derive(Debug, Default, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, ToSchema)]
Expand Down Expand Up @@ -184,7 +185,7 @@ struct PartialChatResponse {

async fn respond_one(depot: &mut Depot, request: ChatRequest, res: &mut Response) {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let info = request_info(sender.clone(), Duration::from_millis(500)).await;
let info = request_info(sender.clone(), SLEEP).await;
let model_name = info.reload.model_path.to_string_lossy().into_owned();

let (token_sender, token_receiver) = flume::unbounded();
Expand Down Expand Up @@ -233,7 +234,7 @@ async fn respond_one(depot: &mut Depot, request: ChatRequest, res: &mut Response

async fn respond_stream(depot: &mut Depot, request: ChatRequest, res: &mut Response) {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let info = request_info(sender.clone(), Duration::from_millis(500)).await;
let info = request_info(sender.clone(), SLEEP).await;
let model_name = info.reload.model_path.to_string_lossy().into_owned();

let (token_sender, token_receiver) = flume::unbounded();
Expand Down
7 changes: 4 additions & 3 deletions crates/ai00-server/src/api/oai/completion.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashMap, sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc};

use ai00_core::{
run::StateId, FinishReason, GenerateRequest, ThreadRequest, Token, TokenCounter, MAX_TOKENS,
Expand All @@ -16,6 +16,7 @@ use super::*;
use crate::{
api::request_info,
types::{Array, ThreadState},
SLEEP,
};

#[derive(Debug, Deserialize, ToSchema)]
Expand Down Expand Up @@ -131,7 +132,7 @@ pub struct PartialCompletionResponse {

async fn respond_one(depot: &mut Depot, request: CompletionRequest, res: &mut Response) {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let info = request_info(sender.clone(), Duration::from_millis(500)).await;
let info = request_info(sender.clone(), SLEEP).await;
let model_name = info.reload.model_path.to_string_lossy().into_owned();

let (token_sender, token_receiver) = flume::unbounded();
Expand Down Expand Up @@ -177,7 +178,7 @@ async fn respond_one(depot: &mut Depot, request: CompletionRequest, res: &mut Re

async fn respond_stream(depot: &mut Depot, request: CompletionRequest, res: &mut Response) {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let info = request_info(sender.clone(), Duration::from_millis(500)).await;
let info = request_info(sender.clone(), SLEEP).await;
let model_name = info.reload.model_path.to_string_lossy().into_owned();

let (token_sender, token_receiver) = flume::unbounded();
Expand Down
5 changes: 2 additions & 3 deletions crates/ai00-server/src/api/oai/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::time::Duration;

use ai00_core::{GenerateRequest, ThreadRequest, Token, TokenCounter};
use futures_util::StreamExt;
use salvo::{
Expand All @@ -11,6 +9,7 @@ use serde::{Deserialize, Serialize};
use crate::{
api::request_info,
types::{Array, ThreadState},
SLEEP,
};

#[derive(Debug, Default, Clone, Deserialize, ToSchema, ToParameters)]
Expand Down Expand Up @@ -57,7 +56,7 @@ pub async fn embeddings(
) -> Json<EmbeddingResponse> {
let request = req.to_owned(); // req.parse_json::<EmbeddingRequest>().await.unwrap();
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let info = request_info(sender.clone(), Duration::from_millis(500)).await;
let info = request_info(sender.clone(), SLEEP).await;
let model_name = info.reload.model_path.to_string_lossy().into_owned();

let (token_sender, token_receiver) = flume::unbounded();
Expand Down
6 changes: 2 additions & 4 deletions crates/ai00-server/src/api/oai/info.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
use std::time::Duration;

use salvo::{
oapi::{ToResponse, ToSchema},
prelude::*,
};
use serde::Serialize;

use crate::{api::request_info, types::ThreadState};
use crate::{api::request_info, types::ThreadState, SLEEP};

#[derive(Debug, Serialize, ToSchema)]
struct ModelChoice {
Expand All @@ -23,7 +21,7 @@ pub struct ModelResponse {
#[endpoint]
pub async fn models(depot: &mut Depot) -> Json<ModelResponse> {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let info = request_info(sender.to_owned(), Duration::from_millis(500)).await;
let info = request_info(sender.to_owned(), SLEEP).await;
let model_name = info
.reload
.model_path
Expand Down
5 changes: 4 additions & 1 deletion crates/ai00-server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{
io::Cursor,
net::{IpAddr, Ipv4Addr, SocketAddr},
path::{Path, PathBuf},
time::Duration,
};

use ai00_core::{model_route, ThreadRequest};
Expand Down Expand Up @@ -30,6 +31,8 @@ mod api;
mod config;
mod types;

/// Shared 500 ms duration that the API handlers pass to `request_info` and
/// `request_info_stream`.
// NOTE(review): the name suggests a sleep/poll interval, but how the callee
// uses this value is not visible here — confirm against `request_info`.
const SLEEP: Duration = Duration::from_millis(500);

pub fn build_path(path: impl AsRef<Path>, name: impl AsRef<Path>) -> Result<PathBuf> {
let permitted = path.as_ref();
let name = name.as_ref();
Expand Down Expand Up @@ -116,7 +119,7 @@ async fn main() {
log::info!("{}\tversion: {}", bin_name, version);

let (sender, receiver) = flume::unbounded::<ThreadRequest>();
tokio::task::spawn_blocking(move || model_route(receiver));
tokio::task::spawn(model_route(receiver));

let (listen, config) = {
let path = args
Expand Down

0 comments on commit 3228cc7

Please sign in to comment.