diff --git a/site/Cargo.toml b/site/Cargo.toml index ac70d35..a050814 100644 --- a/site/Cargo.toml +++ b/site/Cargo.toml @@ -19,7 +19,7 @@ yew = { path = "../yew/packages/yew", features = ["csr"] } # yew-agent = { git = "https://github.com/yewstack/yew/" } yew-agent = { path = "../yew/packages/yew-agent" } postcard = "1.0" -wasm-bindgen = "0.2" +wasm-bindgen = "0.2.90" getrandom = { version = "0.2", features = ["js"] } wasm-logger = "0.2.0" wasm-bindgen-futures = "0.4.40" diff --git a/site/src/api.rs b/site/src/api.rs index c0579ec..c101034 100644 --- a/site/src/api.rs +++ b/site/src/api.rs @@ -1,8 +1,4 @@ use model::util::{Data, DataInfo, DataSingle, Weights}; -use wasm_bindgen::prelude::*; -use wasm_bindgen::JsValue; -use wasm_bindgen_futures::JsFuture; -use web_sys::{Request, RequestInit, RequestMode, Response}; use reqwest::Client; const API_URL: &str = "http://127.0.0.0:8000"; @@ -40,20 +36,21 @@ pub async fn get_sample() -> DataSingle { data.data.get(0).unwrap().clone() } -pub async fn get_block(block_size: usize) -> Sendable> { - let data_info = DataInfo { block: block_size }; - let serialized_data_info = serde_json::to_string(&data_info).map_err(|e| JsValue::from_str(&e.to_string())).unwrap(); - // let mut opts = RequestInit::new(); - // opts.method("POST"); - // opts.mode(RequestMode::Cors); - // opts.body(Some(&JsValue::from_str(&serialized_data_info))); - - let request = Request::new(Method::POST, &format!("{}/datablock", API_URL)) - .mode(RequestMode::Cors) - .header("Content-Type", "application/json") - .header("Accept", "application/json"); - - let window = web_sys::window().unwrap(); +pub async fn get_block(block_size: usize) -> Data { + let client = Client::new(); + serde_json::from_str( + &client + .post(format!("{}/datablock", API_URL)) + .json(&DataInfo { block: block_size }) + .header("Content-Type", "application/json") + .header("Accept", "application/json") + .send() + .await + .unwrap() + .text() + .await + .unwrap(), + ).unwrap() } pub async fn send_weights(weights: Weights) { diff --git a/site/src/model_agent.rs b/site/src/model_agent.rs index 7ffc529..b130424 100644 --- a/site/src/model_agent.rs +++ b/site/src/model_agent.rs @@ -1,4 +1,4 @@ -use futures::{FutureExt, SinkExt, StreamExt, future::Future, select}; +use futures::{FutureExt, SinkExt, StreamExt}; use model::{ util::{random_dist, train_handler_wrapper, Data, Weights}, Model, @@ -11,8 +11,8 @@ use std::{ }; use yew::platform::time::sleep; use yew_agent::prelude::*; -use crate::api::{get_block, Sendable}; -use wasm_bindgen::JsValue; +use crate::api::get_block; +use wasm_bindgen_futures::spawn_local; #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum ControlSignal { @@ -38,14 +38,10 @@ pub struct ResponseSignal { pub cache_size: usize, } -struct SendableFuture(Box, Sendable>> + Send>); - -unsafe impl Send for SendableFuture {} - - pub struct ModelData { data_vec: Arc>>, - data_futures: Arc>>, + data_given: i64, + data_taken: i64, training: bool, batch_size: usize, lrate: f64, @@ -62,7 +58,8 @@ impl ModelData { fn new() -> Self { Self { data_vec: Arc::new(Mutex::new(VecDeque::new())), - data_futures: Arc::new(Mutex::new(Vec::new())), + data_given: 0, + data_taken: 0, training: false, batch_size: 128, lrate: 0.01, @@ -93,7 +90,7 @@ impl ModelData { batch_size: self.batch_size, lrate: self.lrate, data_len: self.data_vec.lock().unwrap().len(), - data_futures_len: self.data_futures.lock().unwrap().len(), + data_futures_len: (self.data_given - self.data_taken) as usize, iteration: self.iteration, cache_size: self.cache_size, } @@ -103,8 +100,10 @@ impl ModelData { let mut data = self.data_vec.lock().unwrap().pop_front(); while data.is_some() && data.as_ref().unwrap().data.len() != self.batch_size { data = self.data_vec.lock().unwrap().pop_front(); + self.data_taken += 1; } if data.is_some() { + self.data_taken += 1; let (loss, acc) = train_handler_wrapper(&data.unwrap(), &mut self.model, self.batch_size); self.loss = loss; self.acc = acc; @@ -113,58 +112,20 @@ impl ModelData { } } -// // 2. Function to loop over the promises and call a polling function -// async fn loop_and_poll_requests(urls: Vec<&str>) -> Vec { -// let mut futures = vec![]; - -// for url in urls { -// let future = make_request(url).fuse(); // `.fuse()` allows us to poll a future multiple times -// futures.push(future); -// } - -// let mut results = Vec::new(); -// for future in futures { -// match poll_promise(future).await { -// Ok(result) => results.push(result), -// Err(e) => web_sys::console::log_1(&format!("Error: {:?}", e).into()), -// } -// } - -// results -// } - async fn cache_data(&mut self) { - self.add_futures().await; - - let mut data_futures = self.data_futures.lock().unwrap(); - - for i in 0..data_futures.len() { - let mut data_future = data_futures.get_mut(i).unwrap(); - } - - } - - async fn add_futures(&mut self) { - let difference = self.cache_size - self.data_futures.lock().unwrap().len() - self.data_vec.lock().unwrap().len(); - for _ in 0..difference { - let data_future = SendableFuture(Box::new(get_block(self.batch_size))); - self.data_futures.lock().unwrap().push(data_future); - } + let missing = (self.cache_size as i64 - self.data_vec.lock().unwrap().len() as i64) - (self.data_given - self.data_taken); + self.add_futures(missing).await; } - async fn poll_future(future: F) -> Result - where - F: std::future::Future, - { - let wait_time = Duration::from_millis(10); - let future = future.fuse(); - let timeout = sleep(wait_time).fuse(); - - futures::pin_mut!(future, timeout); - - select! { - result = future => Ok(result), - _ = timeout => Err(()), + async fn add_futures(&mut self, missing: i64) { + for _ in 0..missing { + self.data_given += 1; + let data_vec_handle = self.data_vec.clone(); + let batch_size = self.batch_size; + spawn_local(async move { + let data = get_block(batch_size).await; + data_vec_handle.lock().unwrap().push_back(data); + }); } }