Skip to content

Commit

Permalink
a ton of changes to bring data caching into the model itself - what I…
Browse files Browse the repository at this point in the history
… should have done from the beginning
  • Loading branch information
sachiniyer committed Nov 17, 2023
1 parent 6552b3a commit 2686d22
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 162 deletions.
1 change: 0 additions & 1 deletion site/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
<body class="min-h-screen">
<link data-trunk rel="rust" href="Cargo.toml" data-bin="app" data-type="main" data-weak-refs />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker_model" data-type="worker" data-weak-refs />
<link data-trunk rel="rust" href="Cargo.toml" data-bin="worker_data" data-type="worker" data-weak-refs />
</body>

</html>
6 changes: 0 additions & 6 deletions site/src/bin/worker_data.rs

This file was deleted.

36 changes: 0 additions & 36 deletions site/src/data_agent.rs

This file was deleted.

1 change: 1 addition & 0 deletions site/src/get_data.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

45 changes: 8 additions & 37 deletions site/src/home.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use crate::{api::{get_weights, send_weights, weights_delete},
queue::use_queue,
counter::use_counter,
data_agent::DataTask,
model_agent::{ControlSignal, ModelReactor},
Grid};
use model::{
Expand All @@ -14,12 +13,10 @@ use std::{
sync::{Arc, Mutex},
};
use wasm_bindgen::JsCast;
use wasm_bindgen_futures::spawn_local;
use web_sys::{EventTarget, HtmlInputElement};
use yew::{function_component, functional::use_effect, html, prelude::*, Html};
use yew_agent::{
oneshot::use_oneshot_runner,
reactor::{use_reactor_bridge, ReactorEvent},
};
use yew_agent::reactor::{use_reactor_bridge, ReactorEvent};

#[function_component(Home)]
pub fn home() -> Html {
Expand All @@ -35,7 +32,6 @@ pub fn home() -> Html {
let accuracy_handle = use_state(|| 0.0);
let cache_size_handle = use_state(|| 5);
let local_train_toggle = Arc::new(Mutex::new(false));
let data_task = use_oneshot_runner::<DataTask>();
let data_cache_pipe = use_queue(VecDeque::<Data>::from(vec![]));
let data_cache_futures = use_counter(0);
let data_cache_external = use_state(|| 0);
Expand All @@ -47,10 +43,6 @@ pub fn home() -> Html {
let learning_rate_handle_response = learning_rate_handle.clone();

let block_size_handle_model = block_size_handle.clone();
let cache_size_handle_model = cache_size_handle.clone();
let data_cache_pipe_model = data_cache_pipe.clone();
let data_cache_futures_model = data_cache_futures.clone();
let data_agent = data_task.clone();
let model_sub = use_reactor_bridge::<ModelReactor, _>(move |event| match event {
ReactorEvent::Output(status) => {
iter_handle_response.set(status.iteration);
Expand All @@ -59,26 +51,6 @@ pub fn home() -> Html {
data_cache_external_response.set(status.data_len);
block_size_handle_model.set(status.batch_size);
learning_rate_handle_response.set(status.lrate);
let block_size_handle = block_size_handle_model.clone();
let data_cache_futures = data_cache_futures_model.clone();
let data_cache_pipe = data_cache_pipe_model.clone();
let data_agent = data_agent.clone();
let currently_cached =
*data_cache_futures + data_cache_pipe.len() + status.data_len;
for _ in 0..(*cache_size_handle_model - currently_cached) {
web_sys::console::log_1(&"caching".into());
let block_size_handle = block_size_handle.clone();
let data_cache_futures = data_cache_futures.clone();
let data_cache_pipe = data_cache_pipe.clone();
let data_agent = data_agent.clone();
wasm_bindgen_futures::spawn_local(async move {
data_cache_futures.increase();
let data = data_agent.run(*block_size_handle).await;
data_cache_pipe.push_back(data);
data_cache_futures.decrease();
});

}
}
_ => (),
});
Expand All @@ -87,9 +59,8 @@ pub fn home() -> Html {
let model_sub_effect = model_sub.clone();
use_effect(move || {
if data_cache_pipe_effect.len() > 0 {
wasm_bindgen_futures::spawn_local(async move {
spawn_local(async move {
let data = data_cache_pipe_effect.pop_front().unwrap();
model_sub_effect.send(ControlSignal::AddData(data));
});
}
});
Expand All @@ -107,7 +78,7 @@ pub fn home() -> Html {
Callback::from(move |grid: [[bool; 28]; 28]| {
let inference_handler = inference_handler.clone();
let model = model.clone();
wasm_bindgen_futures::spawn_local(async move {
spawn_local(async move {
let grid_infer = grid
.iter()
.flatten()
Expand Down Expand Up @@ -161,7 +132,7 @@ pub fn home() -> Html {
let grid = (*grid_component_handler).clone();
let loss_handle = loss_handle.clone();
let model_handle = model_handle.clone();
wasm_bindgen_futures::spawn_local(async move {
spawn_local(async move {
let mut model = (*model_handle).clone();
let grid_train = grid
.iter()
Expand Down Expand Up @@ -192,7 +163,7 @@ pub fn home() -> Html {
Callback::from(move |_| {
let model_handle = model_handle.clone();
let learning_rate_handle = learning_rate_handle.clone();
wasm_bindgen_futures::spawn_local(async move {
spawn_local(async move {
let weights = get_weights().await;
let new_model = Model::new(
weights.weights,
Expand All @@ -211,7 +182,7 @@ pub fn home() -> Html {
let model_handle = model_handle.clone();
Callback::from(move |_| {
let model_handle = model_handle.clone();
wasm_bindgen_futures::spawn_local(async move {
spawn_local(async move {
let model = (*model_handle).clone();
let weights = model.export_weights();
send_weights(Weights { weights }).await;
Expand All @@ -229,7 +200,7 @@ pub fn home() -> Html {
Callback::from(move |_| {
let model_handle = model_handle.clone();
let learning_rate_handle = learning_rate_handle.clone();
wasm_bindgen_futures::spawn_local(async move {
spawn_local(async move {
weights_delete().await;
let weights = get_weights().await;
let new_model = Model::new(
Expand Down
10 changes: 4 additions & 6 deletions site/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use wasm_logger::{init, Config};
use yew::prelude::*;
use yew_agent::{oneshot::OneshotProvider, reactor::ReactorProvider};
use yew_agent::reactor::ReactorProvider;

pub mod api;
pub mod grid;
pub mod home;
Expand All @@ -11,17 +13,13 @@ use hooks::{counter, queue};

pub mod model_agent;
use model_agent::ModelReactor;
pub mod data_agent;
use data_agent::{DataTask, Postcard};

#[function_component]
pub fn App() -> Html {
wasm_logger::init(wasm_logger::Config::default());
init(Config::default());
html! {
<OneshotProvider<DataTask, Postcard> path="/worker_data.js">
<ReactorProvider<ModelReactor> path="/worker_model.js">
<Home />
</ReactorProvider<ModelReactor>>
</OneshotProvider<DataTask, Postcard>>
}
}
Loading

0 comments on commit 2686d22

Please sign in to comment.