diff --git a/site/index.html b/site/index.html index e9957fe..ea9bc56 100644 --- a/site/index.html +++ b/site/index.html @@ -14,7 +14,6 @@ - diff --git a/site/src/bin/worker_data.rs b/site/src/bin/worker_data.rs deleted file mode 100644 index 8e5c828..0000000 --- a/site/src/bin/worker_data.rs +++ /dev/null @@ -1,6 +0,0 @@ -use site::data_agent::{DataTask, Postcard}; -use yew_agent::Registrable; - -fn main() { - DataTask::registrar().encoding::().register(); -} diff --git a/site/src/data_agent.rs b/site/src/data_agent.rs deleted file mode 100644 index c1bb74d..0000000 --- a/site/src/data_agent.rs +++ /dev/null @@ -1,36 +0,0 @@ -use crate::api::get_block; -use js_sys::Uint8Array; -use model::util::Data; -use serde::{Deserialize, Serialize}; -use wasm_bindgen::JsValue; -use yew_agent::prelude::*; -use yew_agent::Codec; - -pub struct Postcard; - -impl Codec for Postcard { - fn encode(input: I) -> JsValue - where - I: Serialize, - { - let data_json = serde_json::to_string(&input).expect("can't serialize a worker message"); - let data = data_json.as_bytes(); - let data = Uint8Array::from(data); - JsValue::from(data) - } - - fn decode(input: JsValue) -> O - where - O: for<'de> Deserialize<'de>, - { - let data = Uint8Array::from(input); - let data = data.to_vec(); - let data_json = String::from_utf8(data).expect("can't deserialize a worker message"); - serde_json::from_str(&data_json).expect("can't deserialize a worker message") - } -} - -#[oneshot] -pub async fn DataTask(n: usize) -> Data { - get_block(n).await -} diff --git a/site/src/get_data.rs b/site/src/get_data.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/site/src/get_data.rs @@ -0,0 +1 @@ + diff --git a/site/src/home.rs b/site/src/home.rs index 8edf003..9ac3f35 100644 --- a/site/src/home.rs +++ b/site/src/home.rs @@ -1,7 +1,6 @@ use crate::{api::{get_weights, send_weights, weights_delete}, queue::use_queue, counter::use_counter, - data_agent::DataTask, model_agent::{ControlSignal, ModelReactor}, Grid}; use model::{ @@ -14,12 +13,10 @@ use std::{ sync::{Arc, Mutex}, }; use wasm_bindgen::JsCast; +use wasm_bindgen_futures::spawn_local; use web_sys::{EventTarget, HtmlInputElement}; use yew::{function_component, functional::use_effect, html, prelude::*, Html}; -use yew_agent::{ - oneshot::use_oneshot_runner, - reactor::{use_reactor_bridge, ReactorEvent}, -}; +use yew_agent::reactor::{use_reactor_bridge, ReactorEvent}; #[function_component(Home)] pub fn home() -> Html { @@ -35,7 +32,6 @@ pub fn home() -> Html { let accuracy_handle = use_state(|| 0.0); let cache_size_handle = use_state(|| 5); let local_train_toggle = Arc::new(Mutex::new(false)); - let data_task = use_oneshot_runner::(); let data_cache_pipe = use_queue(VecDeque::::from(vec![])); let data_cache_futures = use_counter(0); let data_cache_external = use_state(|| 0); @@ -47,10 +43,6 @@ pub fn home() -> Html { let learning_rate_handle_response = learning_rate_handle.clone(); let block_size_handle_model = block_size_handle.clone(); - let cache_size_handle_model = cache_size_handle.clone(); - let data_cache_pipe_model = data_cache_pipe.clone(); - let data_cache_futures_model = data_cache_futures.clone(); - let data_agent = data_task.clone(); let model_sub = use_reactor_bridge::(move |event| match event { ReactorEvent::Output(status) => { iter_handle_response.set(status.iteration); @@ -59,26 +51,6 @@ pub fn home() -> Html { data_cache_external_response.set(status.data_len); block_size_handle_model.set(status.batch_size); learning_rate_handle_response.set(status.lrate); - let block_size_handle = block_size_handle_model.clone(); - let data_cache_futures = data_cache_futures_model.clone(); - let data_cache_pipe = data_cache_pipe_model.clone(); - let data_agent = data_agent.clone(); - let currently_cached = - *data_cache_futures + data_cache_pipe.len() + status.data_len; - for _ in 0..(*cache_size_handle_model - currently_cached) { - web_sys::console::log_1(&"caching".into()); - let block_size_handle = block_size_handle.clone(); - let data_cache_futures = data_cache_futures.clone(); - let data_cache_pipe = data_cache_pipe.clone(); - let data_agent = data_agent.clone(); - wasm_bindgen_futures::spawn_local(async move { - data_cache_futures.increase(); - let data = data_agent.run(*block_size_handle).await; - data_cache_pipe.push_back(data); - data_cache_futures.decrease(); - }); - - } } _ => (), }); @@ -87,9 +59,8 @@ pub fn home() -> Html { let model_sub_effect = model_sub.clone(); use_effect(move || { if data_cache_pipe_effect.len() > 0 { - wasm_bindgen_futures::spawn_local(async move { + spawn_local(async move { let data = data_cache_pipe_effect.pop_front().unwrap(); - model_sub_effect.send(ControlSignal::AddData(data)); }); } }); @@ -107,7 +78,7 @@ pub fn home() -> Html { Callback::from(move |grid: [[bool; 28]; 28]| { let inference_handler = inference_handler.clone(); let model = model.clone(); - wasm_bindgen_futures::spawn_local(async move { + spawn_local(async move { let grid_infer = grid .iter() .flatten() @@ -161,7 +132,7 @@ pub fn home() -> Html { let grid = (*grid_component_handler).clone(); let loss_handle = loss_handle.clone(); let model_handle = model_handle.clone(); - wasm_bindgen_futures::spawn_local(async move { + spawn_local(async move { let mut model = (*model_handle).clone(); let grid_train = grid .iter() @@ -192,7 +163,7 @@ pub fn home() -> Html { Callback::from(move |_| { let model_handle = model_handle.clone(); let learning_rate_handle = learning_rate_handle.clone(); - wasm_bindgen_futures::spawn_local(async move { + spawn_local(async move { let weights = get_weights().await; let new_model = Model::new( weights.weights, @@ -211,7 +182,7 @@ pub fn home() -> Html { let model_handle = model_handle.clone(); Callback::from(move |_| { let model_handle = model_handle.clone(); - wasm_bindgen_futures::spawn_local(async move { + spawn_local(async move { let model = (*model_handle).clone(); let weights = model.export_weights(); send_weights(Weights { weights }).await; @@ -229,7 +200,7 @@ pub fn home() -> Html { Callback::from(move |_| { let model_handle = model_handle.clone(); let learning_rate_handle = learning_rate_handle.clone(); - wasm_bindgen_futures::spawn_local(async move { + spawn_local(async move { weights_delete().await; let weights = get_weights().await; let new_model = Model::new( diff --git a/site/src/lib.rs b/site/src/lib.rs index 4ac89b6..1d358b5 100644 --- a/site/src/lib.rs +++ b/site/src/lib.rs @@ -1,5 +1,7 @@ +use wasm_logger::{init, Config}; use yew::prelude::*; -use yew_agent::{oneshot::OneshotProvider, reactor::ReactorProvider}; +use yew_agent::reactor::ReactorProvider; + pub mod api; pub mod grid; pub mod home; @@ -11,17 +13,13 @@ use hooks::{counter, queue}; pub mod model_agent; use model_agent::ModelReactor; -pub mod data_agent; -use data_agent::{DataTask, Postcard}; #[function_component] pub fn App() -> Html { - wasm_logger::init(wasm_logger::Config::default()); + init(Config::default()); html! { - path="/worker_data.js"> path="/worker_model.js"> > - > } } diff --git a/site/src/model_agent.rs b/site/src/model_agent.rs index c7d4102..8eb7a9e 100644 --- a/site/src/model_agent.rs +++ b/site/src/model_agent.rs @@ -7,10 +7,11 @@ use serde::{Deserialize, Serialize}; use std::{ collections::VecDeque, sync::{Arc, Mutex}, - time::Duration, + time::Duration }; use yew::platform::time::sleep; use yew_agent::prelude::*; +use crate::api::get_block; #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum ControlSignal { @@ -20,7 +21,7 @@ pub enum ControlSignal { SetWeights(Weights), SetBatchSize(usize), SetLearningRate(i64), - AddData(Data), + SetCacheSize(usize), } #[derive(Debug, Serialize, Deserialize)] @@ -31,79 +32,130 @@ pub struct ResponseSignal { pub batch_size: usize, pub lrate: f64, pub data_len: usize, + pub data_futures_len: usize, pub iteration: usize, + pub cache_size: usize, } -#[reactor] -pub async fn ModelReactor(mut scope: ReactorScope) { - web_sys::console::log_1(&"Model agent started".into()); - async fn respond( - scope: &mut ReactorScope, - weights: (Vec>, Vec>), - loss: f64, - acc: f64, - batch_size: usize, - lrate: f64, - iteration: usize, - data: &VecDeque, - ) { - scope - .send(ResponseSignal { - weights: Weights { weights }, - loss, - acc, - batch_size, - lrate, - data_len: data.len(), - iteration, - }) - .await - .unwrap(); +pub struct ModelData { + data_vec: Arc>>, + data_futures: Vec, + training: bool, + batch_size: usize, + lrate: f64, + loss: f64, + acc: f64, + iteration: usize, + cache_size: usize, + model: Model, + send_status: bool, + +} + +impl ModelData { + fn new() -> Self { + Self { + data_vec: Arc::new(Mutex::new(VecDeque::new())), + data_futures: Vec::new(), + training: false, + batch_size: 128, + lrate: 0.01, + loss: 0.0, + acc: 0.0, + iteration: 0, + cache_size: 5, + model: Model::new( + (random_dist(784, 128), random_dist(128, 10)), + (0.01, 0.01), + ), + send_status: true, + } } - let data_vec: Arc>> = Arc::new(Mutex::new(VecDeque::new())); - let mut training = false; - let mut batch_size: usize = 128; - let lrate: f64 = 0.01; - let mut loss = 0.0; - let mut acc = 0.0; - let mut iteration = 0; - let mut model = Model::new( - (random_dist(784, 128), random_dist(128, 10)), - (lrate, lrate), - ); - let mut send_status = true; + fn execute(&mut self) { + if self.training { + self.train(); + } + self.cache_data(); + } - loop { - if data_vec.lock().unwrap().len() != batch_size { - data_vec.lock().unwrap().clear(); + fn train(&mut self) { + let mut data = self.data_vec.lock().unwrap().pop_front(); + while data.is_some() && data.as_ref().unwrap().data.len() != self.batch_size { + data = self.data_vec.lock().unwrap().pop_front(); + } + if data.is_some() { + let (loss, acc) = train_handler_wrapper(&data.unwrap(), &mut self.model, self.batch_size); + self.loss = loss; + self.acc = acc; + self.iteration += 1; + self.send_status = true; } + } + + fn cache_data(&mut self) { + } - if send_status { - respond( - &mut scope, - model.export_weights(), - loss, - acc, - batch_size, - lrate, - iteration, - &data_vec.clone().lock().unwrap(), - ) - .await; - send_status = false; + fn add_futures(&mut self) { + if self.data_futures.len() + self.data_vec.lock().unwrap().len() < self.cache_size { + let mut data_vec = self.data_vec.lock().unwrap(); + for _ in 0..self.cache_size - self.data_futures.len() - data_vec.len() { + let data = get_block(self.batch_size); + } + } + } + + fn respond(&mut self) -> ResponseSignal { + ResponseSignal { + weights: Weights{ weights: self.model.export_weights()}, + loss: self.loss, + acc: self.acc, + batch_size: self.batch_size, + lrate: self.lrate, + data_len: self.data_vec.lock().unwrap().len(), + data_futures_len: self.data_futures.len(), + iteration: self.iteration, + cache_size: self.cache_size, } + } - if training { - if !data_vec.lock().unwrap().is_empty() { - (loss, acc) = train_handler_wrapper( - &data_vec.lock().unwrap().pop_front().unwrap(), - &mut model, - batch_size, - ); - iteration += 1; - } - send_status = true; + fn set_send_status(&mut self, status: bool) { + self.send_status = status; + } + + fn set_training(&mut self, status: bool) { + self.training = status; + if status { + self.iteration = 0; + } + } + + fn set_weights(&mut self, weights: Weights) { + self.model = Model::new(weights.weights, (self.lrate, self.lrate)); + } + + fn set_batch_size(&mut self, batch_size: usize) { + self.batch_size = batch_size; + } + + fn set_learning_rate(&mut self, lrate: f64) { + self.lrate = lrate; + self.model = Model::new(self.model.export_weights(), (lrate, lrate)); + } + + fn set_cache_size(&mut self, cache_size: usize) { + self.cache_size = cache_size; + } +} + +#[reactor] +pub async fn ModelReactor(mut scope: ReactorScope) { + web_sys::console::log_1(&"Model agent started".into()); + let mut data = ModelData::new(); + loop { + data.execute(); + if data.send_status { + scope.send(data.respond()).await.unwrap(); } futures::select! { c = scope.next() => { @@ -111,33 +163,31 @@ pub async fn ModelReactor(mut scope: ReactorScope match c { ControlSignal::Start => { web_sys::console::log_1(&"Starting training".into()); - iteration = 0; - training = true; + data.set_training(true); } ControlSignal::Stop => { web_sys::console::log_1(&"Stopping training".into()); - training = false; + data.set_training(false); } ControlSignal::GetStatus => { web_sys::console::log_1(&"Sending status".into()); - send_status = true; + data.set_send_status(true); } ControlSignal::SetWeights(w) => { web_sys::console::log_1(&"Setting weights".into()); - model = Model::new(w.weights, (lrate, lrate)); + data.set_weights(w); } ControlSignal::SetBatchSize(b) => { web_sys::console::log_1(&"Setting batch size".into()); - batch_size = b; + data.set_batch_size(b); } ControlSignal::SetLearningRate(l) => { web_sys::console::log_1(&"Setting learning rate".into()); - let l = l as f64; - model = Model::new(model.export_weights(), (l, l)); + data.set_learning_rate(l as f64); } - ControlSignal::AddData(d) => { - web_sys::console::log_1(&"Adding data".into()); - data_vec.lock().unwrap().push_back(d); + ControlSignal::SetCacheSize(c) => { + web_sys::console::log_1(&"Setting cache size".into()); + data.set_cache_size(c); } }; } else {