Skip to content

Commit

Permalink
Streaming chat from Rust works
Browse files Browse the repository at this point in the history
  • Loading branch information
tachyonicbytes committed Jan 16, 2024
1 parent 710dbb7 commit 7b694d6
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
72 changes: 44 additions & 28 deletions examples/tauri-postgres/src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use std::{
};

use tokio::sync::mpsc;
use std::env;

/**
* Structures
Expand Down Expand Up @@ -100,6 +101,7 @@ struct TerminalState {

struct AsyncProcInputTx {
inner: AsyncMutex<mpsc::Sender<String>>,
flag: AsyncMutex<bool>,
}

/******************
Expand Down Expand Up @@ -348,23 +350,51 @@ async fn tauri_chat(connection: State<'_, DbConnection>, question: &str, context
}
}

fn chat_token<R: tauri::Runtime>(message: String, manager: &impl Manager<R>) {
eprintln!("rs2js");
eprintln!("{}", message);
manager
.emit_all("chatToken", message)
.unwrap();
}

#[tauri::command(rename_all = "snake_case")]
async fn tauri_async_chat(connection: State<'_, DbConnection>, question: &str, context: &str) -> Result<(), ()> {
async fn start_chat(
question: String,
context: String,
state: tauri::State<'_, AsyncProcInputTx>,
connection: tauri::State<'_, DbConnection>
) -> Result<(), String> {
eprintln!("js2rs");
eprintln!("{}", question);

let mut temp = connection.llama.lock().await;
let llama2 = temp.as_mut().unwrap();

let mut temp2 = connection.writer.lock().await;
let writer = temp2.as_mut();
let mut temp = connection.writer.lock().await;
let writer = temp.as_mut();

// reset the flag, because we answer a new question
*state.flag.lock().await = true;

let model = "llama2:latest".to_string();

let generation_request = GenerationRequest::new(model, question.to_string());
let mut stream = llama2.generate_stream(generation_request).await.unwrap();
while let Some(result) = stream.next().await {
let async_proc_input_tx = state.inner.lock().await;
let flag = state.flag.lock().await.clone();

if flag {
break;
}

match result {
Ok(response) => {
let content = response.response;
write!(writer, "{}", content).unwrap(); // yield content here!
async_proc_input_tx
.send(response.response)
.await
.map_err(|e| e.to_string());
}
Err(err) => {
panic!("STILL TESTING THIS");
Expand All @@ -375,28 +405,13 @@ async fn tauri_async_chat(connection: State<'_, DbConnection>, question: &str, c
Ok(())
}

use std::env;
#[tauri::command(rename_all = "snake_case")]
async fn stop_chat(state: tauri::State<'_, AsyncProcInputTx>) -> Result<(), String> {
eprintln!("stop_chat");

fn rs2js<R: tauri::Runtime>(message: String, manager: &impl Manager<R>) {
eprintln!("rs2js");
eprintln!("{}", message);
manager
.emit_all("rs2js", format!("rs: {}", message))
.unwrap();
}
*state.flag.lock().await = true;

#[tauri::command]
async fn js2rs(
message: String,
state: tauri::State<'_, AsyncProcInputTx>,
) -> Result<(), String> {
eprintln!("js2rs");
eprintln!("{}", message);
let async_proc_input_tx = state.inner.lock().await;
async_proc_input_tx
.send(message)
.await
.map_err(|e| e.to_string())
Ok(())
}

async fn async_process_model(
Expand Down Expand Up @@ -500,6 +515,7 @@ fn main() {
})
.manage(AsyncProcInputTx {
inner: AsyncMutex::new(async_proc_input_tx),
flag: AsyncMutex::new(false),
})
// .setup(|app| {
// // terminal
Expand All @@ -524,7 +540,7 @@ fn main() {
tauri::async_runtime::spawn(async move {
loop {
if let Some(output) = async_proc_output_rx.recv().await {
rs2js(output, &app_handle);
chat_token(output, &app_handle);
}
}
});
Expand All @@ -541,8 +557,8 @@ fn main() {
async_write_to_pty,
async_resize_pty,
send_recv_postgres_terminal,
tauri_async_chat,
js2rs,
start_chat,
stop_chat,
])
.on_window_event(move |event| match event.event() {
WindowEvent::Destroyed => {
Expand Down
4 changes: 2 additions & 2 deletions examples/tauri-postgres/src/pages/Chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@ function Chat() {
.map((issue: any) => `${issue.title}\n${issue.description}`)
.join("\n\n\n");
console.log("startChat", { question: question, context: context ?? "" });
invoke("startChat", { question: question, context: context ?? "" });
invoke("start_chat", { question: question, context: context ?? "" });
};

const stopChat = async () => {
setWorking(false);
console.log("stopChat");
invoke("stopChat");
invoke("stop_chat");
};

useEffect(() => {
Expand Down

0 comments on commit 7b694d6

Please sign in to comment.