
Commit

Allow additionally loading states.
cryscan committed May 21, 2024
1 parent e17377d commit 34f0279
Showing 8 changed files with 200 additions and 30 deletions.
7 changes: 3 additions & 4 deletions README.md
@@ -13,9 +13,8 @@
[![All Contributors](https://img.shields.io/badge/all_contributors-7-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->



[English](README.md) | [中文](README_zh.md)
[![en](https://img.shields.io/badge/lang-en-red.svg)](README.md)
[![zh](https://img.shields.io/badge/lang-zh-blue.svg)](README.zh.md)

---

@@ -144,7 +143,7 @@ The API service starts at port 65530, and the data input and output format follows the
* `/api/oai/v1/embeddings`
* `/api/oai/embeddings`

The following is an example of ai00 invocation based on Python and an out of the box tool class implementation
The following is an out-of-the-box example of Ai00 API invocation in Python:

```python
import openai
```
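The example block above is collapsed in the diff view. As a hedged sketch, such an invocation boils down to an HTTP POST against one of the endpoints listed above, shown here with only the Python standard library; the payload fields are illustrative assumptions following the OpenAI chat format, not guarantees about every parameter Ai00 accepts:

```python
import json
import urllib.request

# Ai00 serves an OpenAI-compatible API on port 65530 by default;
# the chat-completions endpoint is listed in the README above.
url = "http://localhost:65530/api/oai/v1/chat/completions"

# Illustrative payload in the OpenAI chat format.
payload = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 256,
    "stream": False,
}

request = urllib.request.Request(
    url,
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)

# Actually sending the request requires a running Ai00 server:
# with urllib.request.urlopen(request) as response:
#     reply = json.load(response)
#     print(reply["choices"][0]["message"]["content"])
```

The same request works through the `openai` client library by pointing its base URL at `http://localhost:65530/api/oai/v1`.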
18 changes: 7 additions & 11 deletions README_zh.md → README.zh.md
@@ -12,9 +12,8 @@
[![All Contributors](https://img.shields.io/badge/all_contributors-5-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->



[English](README.md) | [中文](README_zh.md)
[![en](https://img.shields.io/badge/lang-en-blue.svg)](README.md)
[![zh](https://img.shields.io/badge/lang-zh-red.svg)](README.zh.md)

<div align="left">

@@ -41,11 +40,9 @@

### ⭕ Model Download and Conversion

If you build from source, you must [download the model](https://huggingface.co/BlinkDL) and place it in `assets/models` at build time.
You can download the official RWKV World series models from HuggingFace and convert them with the provided `convert_safetensors.py`.
If you do not want to install Python, you can also get a dependency-free converter from [`web-rwkv`](https://github.com/cryscan/web-rwkv/releases).
You must [download the model](https://huggingface.co/BlinkDL) and place it in `assets/models`.

You can download already-converted V4 models here: [V5](https://huggingface.co/cgisky/AI00_RWKV_V5) or [V6](https://huggingface.co/cgisky/ai00_rwkv_x060)
You can download already-converted models here: [V5](https://huggingface.co/cgisky/AI00_RWKV_V5) or [V6](https://huggingface.co/cgisky/ai00_rwkv_x060)


## Installation, Compilation, and Usage
@@ -63,6 +60,7 @@
```bash
./ai00_rwkv_server
```

5. Open your browser and visit the WebUI at
[`https://localhost:65530`](https://localhost:65530)

@@ -77,7 +75,6 @@
cd ai00_rwkv_server
```


3. After [downloading the model](https://huggingface.co/cgisky/RWKV-safetensors-fp16), place it under
`assets/models/`, e.g. `assets/models/RWKV-x060-World-3B-v2-20240228-ctx4096.st`

@@ -87,7 +84,6 @@
cargo build --release
```


5. After compilation finishes, run

```bash
@@ -128,7 +124,7 @@

## 📙 Currently Available APIs

The API service starts on port 65530; data input and output formats follow the OpenAI API specification.
The API service starts on port 65530; data input and output formats follow the OpenAI API specification.

- `/api/oai/v1/models`
- `/api/oai/models`
@@ -139,7 +135,7 @@ The API service starts on port 65530; data input and output formats follow the Openai API
- `/api/oai/v1/embeddings`
- `/api/oai/embeddings`

The following is an example of ai00 invocation based on Python, with an out-of-the-box utility class implementation
The following is an out-of-the-box example of Ai00 API invocation in Python:

```python
import openai
```
59 changes: 59 additions & 0 deletions crates/ai00-core/README.md
@@ -0,0 +1,59 @@
# Ai00-Core

This is the core library of the [Ai00](https://github.com/Ai00-X/ai00_server) server. It provides the following functionalities:

- Model/LoRA/initial state loading with automatic version detection;
- Samplers;
- BNF integration;
- State caching;
- Session management.

The purpose of this crate is to expose a stateless native inference API.

## Guide

The first step is to start the runtime, an async task that handles all background work (e.g., model runtime, caches, session queue):

```rust
use ai00_core::model_route;

let (sender, receiver) = flume::unbounded::<ThreadRequest>();
tokio::spawn(model_route(receiver));
```

Users can then communicate with the runtime by sending `ThreadRequest`s, commands that ask the runtime to perform various operations. Its definition:

```rust
pub enum ThreadRequest {
/// Acquire a list of current available adapters.
Adapter(Sender<AdapterList>),
/// Get the current runtime info.
Info(Sender<RuntimeInfo>),
/// Request the runtime to complement a prompt.
Generate {
request: Box<GenerateRequest>,
tokenizer: Arc<Tokenizer>,
sender: Sender<Token>,
},
/// Reload the runtime with custom config.
Reload {
request: Box<ReloadRequest>,
sender: Option<Sender<bool>>,
},
/// Unload the runtime.
Unload,
/// Additionally load an initial state.
StateLoad {
request: reload::State,
sender: Option<Sender<bool>>,
},
/// Unload an initial state given its id.
StateUnload(StateId),
/// Save the current model with config.
Save {
request: SaveRequest,
sender: Sender<bool>,
},
}
```
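The command pattern above is an enum of requests, each carrying its own reply channel. Purely to illustrate that request/reply shape outside of Rust, here is a Python sketch with `queue.Queue` standing in for the `flume` channels; all names are stand-ins, not part of the ai00-core API:

```python
import queue
import threading
from dataclasses import dataclass

# Stand-ins for two ThreadRequest variants. A request that expects an
# answer carries its own reply queue, mirroring the `Sender<...>`
# fields in the real enum.
@dataclass
class Info:
    reply: "queue.Queue[str]"

class Unload:
    pass

def model_route(requests: queue.Queue) -> None:
    # Serve commands until an Unload request arrives.
    while True:
        request = requests.get()
        if isinstance(request, Info):
            request.reply.put("runtime info")
        elif isinstance(request, Unload):
            break

requests_q: queue.Queue = queue.Queue()
route = threading.Thread(target=model_route, args=(requests_q,))
route.start()

# Ask the "runtime" for its info, then wait on the per-request channel.
reply: queue.Queue = queue.Queue()
requests_q.put(Info(reply))
print(reply.get())  # prints "runtime info"

requests_q.put(Unload())
route.join()
```

The one-reply-queue-per-request design is what lets many concurrent callers share a single runtime task without their answers getting mixed up.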
85 changes: 81 additions & 4 deletions crates/ai00-core/src/lib.rs
@@ -88,7 +88,7 @@ pub enum ThreadRequest {
Adapter(Sender<AdapterList>),
/// Get the current runtime info.
Info(Sender<RuntimeInfo>),
/// Request the server to complement a prompt.
/// Request the runtime to complement a prompt.
Generate {
request: Box<GenerateRequest>,
tokenizer: Arc<Tokenizer>,
@@ -101,6 +101,13 @@
},
/// Unload the runtime.
Unload,
/// Additionally load an initial state.
StateLoad {
request: reload::State,
sender: Option<Sender<bool>>,
},
/// Unload an initial state given its id.
StateUnload(StateId),
/// Save the current model with config.
Save {
request: SaveRequest,
@@ -449,7 +456,7 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
let sender = {
let (sender, receiver) = flume::unbounded();
let env = env.clone();
tokio::task::spawn(crate::run::run(receiver, env));
tokio::spawn(crate::run::run(receiver, env));
sender
};

@@ -566,7 +573,7 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
}
Err(err) => {
callback(false);
log::error!("reload model failed: {}", err);
log::error!("load runtime failed: {}", err);
}
};
});
@@ -576,7 +583,7 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
tokio::spawn(async move {
let mut env = env.write().await;
let env = std::mem::take(&mut *env);
log::info!("model unloaded");
log::info!("runtime unloaded");

let context = match env {
Environment::Loaded(runtime) => runtime.context().clone(),
@@ -586,6 +593,76 @@ pub async fn model_route(receiver: Receiver<ThreadRequest>) -> Result<()> {
context.device.poll(Maintain::Wait);
});
}
ThreadRequest::StateLoad { request, sender } => {
let env = env.clone();
let load = async move {
let env = env.read().await;
let Environment::Loaded(runtime) = &*env else {
bail!("runtime not loaded")
};

let reload::State {
path,
name,
id,
default,
} = request;
let name = match name {
Some(name) => name,
None => match path.file_name() {
Some(name) => name.to_string_lossy().to_string(),
None => bail!("failed to parse state name"),
},
};
let file = File::open(&path).await?;
let data = unsafe { Mmap::map(&file)? };

let context = runtime.context();
let info = runtime.info();
let model = SafeTensors::deserialize(&data)?;
match load_init_state(context, info, model).await {
Ok(data) => {
let state = InitState {
name,
id,
data,
default,
};
log::info!("{:#?}", state);
runtime.load_init_state(state).await;
}
Err(err) => log::warn!("initial state not loaded: {}", err),
};
Ok(())
};
let callback = move |result: bool| {
if let Some(sender) = sender {
let _ = sender.send(result);
}
};
tokio::spawn(async move {
match load.await {
Ok(_) => {
callback(true);
log::info!("state loaded")
}
Err(err) => {
callback(false);
log::error!("load state failed: {}", err);
}
};
});
}
ThreadRequest::StateUnload(id) => {
let env = env.clone();
tokio::spawn(async move {
let env = env.read().await;
let Environment::Loaded(runtime) = &*env else {
return;
};
runtime.unload_init_state(id).await;
});
}
ThreadRequest::Generate {
request,
tokenizer,
16 changes: 16 additions & 0 deletions crates/ai00-core/src/run.rs
@@ -481,6 +481,22 @@ impl Runtime {
states
}

pub async fn load_init_state(&self, state: InitState) {
let mut caches = self.caches.lock().await;
caches.backed.insert(
state.id,
Cache {
state: Some(state),
cache: Trie::new(),
},
);
}

pub async fn unload_init_state(&self, id: StateId) {
let mut caches = self.caches.lock().await;
caches.backed.remove(&id);
}

pub async fn serialize_model(&self, path: PathBuf) -> Result<()> {
let model = self.model.clone();
let handle = tokio::task::spawn_blocking(move || {
2 changes: 1 addition & 1 deletion crates/ai00-server/src/api/mod.rs
@@ -12,7 +12,7 @@ pub mod oai;

pub use adapter::adapters;
pub use file::{dir, load_config, models, save_config, unzip};
pub use model::{info, load, save, state, unload};
pub use model::{info, load, load_state, save, state, unload};

pub async fn try_request_info(sender: Sender<ThreadRequest>) -> Result<RuntimeInfo> {
let (info_sender, info_receiver) = flume::unbounded();
40 changes: 31 additions & 9 deletions crates/ai00-server/src/api/model.rs
@@ -111,6 +111,37 @@ pub async fn load(depot: &mut Depot, req: &mut Request) -> StatusCode {
}
}

/// `/api/models/unload`.
#[handler]
pub async fn unload(depot: &mut Depot) -> StatusCode {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let _ = sender.send(ThreadRequest::Unload);
while try_request_info(sender.clone()).await.is_ok() {}
StatusCode::OK
}

/// `/api/models/load_init_state`.
#[handler]
pub async fn load_state(depot: &mut Depot, req: &mut Request) -> StatusCode {
let ThreadState { sender, path } = depot.obtain::<ThreadState>().unwrap();
let (result_sender, result_receiver) = flume::unbounded();
let mut request: ai00_core::reload::State = req.parse_body().await.unwrap();

request.path = match build_path(path, &request.path) {
Ok(path) => path,
Err(_) => return StatusCode::NOT_FOUND,
};

let _ = sender.send(ThreadRequest::StateLoad {
request,
sender: Some(result_sender),
});
match result_receiver.recv_async().await.unwrap() {
true => StatusCode::OK,
false => StatusCode::INTERNAL_SERVER_ERROR,
}
}

/// `/api/models/save`.
#[handler]
pub async fn save(depot: &mut Depot, req: &mut Request) -> StatusCode {
@@ -133,12 +164,3 @@ pub async fn save(depot: &mut Depot, req: &mut Request) -> StatusCode {
false => StatusCode::INTERNAL_SERVER_ERROR,
}
}

/// `/api/models/unload`.
#[handler]
pub async fn unload(depot: &mut Depot) -> StatusCode {
let ThreadState { sender, .. } = depot.obtain::<ThreadState>().unwrap();
let _ = sender.send(ThreadRequest::Unload);
while try_request_info(sender.clone()).await.is_ok() {}
StatusCode::OK
}
3 changes: 2 additions & 1 deletion crates/ai00-server/src/main.rs
@@ -119,7 +119,7 @@ async fn main() {
log::info!("{}\tversion: {}", bin_name, version);

let (sender, receiver) = flume::unbounded::<ThreadRequest>();
tokio::task::spawn(model_route(receiver));
tokio::spawn(model_route(receiver));

let (listen, config) = {
let path = args
@@ -204,6 +204,7 @@
.push(Router::with_path("/models/save").post(api::save))
.push(Router::with_path("/models/load").post(api::load))
.push(Router::with_path("/models/unload").get(api::unload))
.push(Router::with_path("/models/state/load").post(api::load_state))
.push(Router::with_path("/models/state").get(api::state))
.push(Router::with_path("/models/list").get(api::models))
.push(Router::with_path("/files/unzip").post(api::unzip))
