Skip to content

Commit

Permalink
need to make JsValue Sendable
Browse files Browse the repository at this point in the history
  • Loading branch information
sachiniyer committed Jan 27, 2024
1 parent 3fd5414 commit 941f32c
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 346 deletions.
5 changes: 4 additions & 1 deletion site/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ readme = "README.md"
repository = "https://github.com/sachiniyer/mnist-wasm.git"
version = "0.1.0"

[lib]
crate-type = ["cdylib", "rlib"]

[dependencies]
model = { path = "../model" }
# yew = { git = "https://github.com/yewstack/yew/", features = ["csr"] }
Expand All @@ -19,7 +22,7 @@ postcard = "1.0"
wasm-bindgen = "0.2"
getrandom = { version = "0.2", features = ["js"] }
wasm-logger = "0.2.0"
wasm-bindgen-futures = "0.4"
wasm-bindgen-futures = "0.4.40"
futures = "0.3"
web-sys = "0.3"
js-sys = "0.3"
Expand Down
34 changes: 17 additions & 17 deletions site/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use model::util::{Data, DataInfo, DataSingle, Weights};
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsValue;
use wasm_bindgen_futures::JsFuture;
use web_sys::{Request, RequestInit, RequestMode};
use web_sys::{Request, RequestInit, RequestMode, Response};
use reqwest::Client;

const API_URL: &str = "http://127.0.0.0:8000";

pub struct Sendable<T: ?Sized>(pub Box<T>);

pub async fn get_weights() -> Weights {
let client = Client::new();
serde_json::from_str(
Expand Down Expand Up @@ -37,23 +40,20 @@ pub async fn get_sample() -> DataSingle {
data.data.get(0).unwrap().clone()
}

pub async fn get_block(block_size: usize) -> JsFuture {
let mut opts = RequestInit::new();
opts.method("POST");
opts.mode(RequestMode::Cors);
opts.body(Some(&JsValue::from_str(&serde_json::to_string(&DataInfo { block: block_size }).unwrap())));
let request = Request::new_with_str_and_init(
format!("{}/datablock", API_URL).as_str(),
&opts,
)
.unwrap();
request
.headers()
.set("Accept", "application/json")
.unwrap();
pub async fn get_block(block_size: usize) -> Sendable<Result<JsValue, JsValue>> {
let data_info = DataInfo { block: block_size };
let serialized_data_info = serde_json::to_string(&data_info).map_err(|e| JsValue::from_str(&e.to_string())).unwrap();
// let mut opts = RequestInit::new();
// opts.method("POST");
// opts.mode(RequestMode::Cors);
// opts.body(Some(&JsValue::from_str(&serialized_data_info)));

let request = Request::new(Method::POST, &format!("{}/datablock", API_URL))
.mode(RequestMode::Cors)
.header("Content-Type", "application/json")
.header("Accept", "application/json");

let window = web_sys::window().unwrap();
// return a future that is fulfilled when the request is complete
JsFuture::from(window.fetch_with_request(&request))
}

pub async fn send_weights(weights: Weights) {
Expand Down
56 changes: 25 additions & 31 deletions site/src/home.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
use crate::{api::{get_weights, send_weights, weights_delete},
queue::use_queue,
counter::use_counter,
model_agent::{ControlSignal, ModelReactor},
Grid};
use model::{
util,
util::{Data, Weights},
util::Weights,
Model,
};
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
use std::sync::{Arc, Mutex};
use wasm_bindgen::JsCast;
use wasm_bindgen_futures::spawn_local;
use web_sys::{EventTarget, HtmlInputElement};
use yew::{function_component, functional::use_effect, html, prelude::*, Html};
use yew::{function_component, html, prelude::*, Html};
use yew_agent::reactor::{use_reactor_bridge, ReactorEvent};

#[function_component(Home)]
Expand All @@ -31,46 +26,45 @@ pub fn home() -> Html {
let learning_rate_handle = use_state(|| 0.035);
let accuracy_handle = use_state(|| 0.0);
let cache_size_handle = use_state(|| 5);
// can I change this to state as well?
let local_train_toggle = Arc::new(Mutex::new(false));
let data_cache_pipe = use_queue(VecDeque::<Data>::from(vec![]));
let data_cache_futures = use_counter(0);
let data_cache_external = use_state(|| 0);
let data_caching = use_state(||0);
let data_cached = use_state(|| 0);

let model_handle = use_state(|| {
Model::new(
(util::random_dist(784, 128), util::random_dist(128, 10)),
(*learning_rate_handle, *learning_rate_handle),
)
});

let iter_handle_response = iter_handle.clone();
let train_loss_handle_response = train_loss_handle.clone();
let accuracy_handle_response = accuracy_handle.clone();
let data_cache_external_response = data_cache_external.clone();
let data_caching_response = data_caching.clone();
let data_cached_response = data_cached.clone();
let learning_rate_handle_response = learning_rate_handle.clone();
let model_handle_response = model_handle.clone();

let block_size_handle_model = block_size_handle.clone();

let model_sub = use_reactor_bridge::<ModelReactor, _>(move |event| match event {
ReactorEvent::Output(status) => {
iter_handle_response.set(status.iteration);
train_loss_handle_response.set(status.loss);
accuracy_handle_response.set(status.acc);
data_cache_external_response.set(status.data_len);
data_cached_response.set(status.data_len);
data_caching_response.set(status.data_futures_len);
block_size_handle_model.set(status.batch_size);
learning_rate_handle_response.set(status.lrate);
model_handle_response.set(Model::new(
status.weights.weights,
(*learning_rate_handle_response, *learning_rate_handle_response),
));
}
_ => (),
});

let data_cache_pipe_effect = data_cache_pipe.clone();
let model_sub_effect = model_sub.clone();
use_effect(move || {
if data_cache_pipe_effect.len() > 0 {
spawn_local(async move {
let data = data_cache_pipe_effect.pop_front().unwrap();
});
}
});

let model_handle = use_state(|| {
Model::new(
(util::random_dist(784, 128), util::random_dist(128, 10)),
(*learning_rate_handle, *learning_rate_handle),
)
});

let infer_callback = {
let inference_handler = inference_handler.clone();
Expand Down Expand Up @@ -363,8 +357,8 @@ pub fn home() -> Html {
<p id="trainloss">{ format!("Loss: {}", *train_loss_handle) }</p>
<p id="acc">{ format!("Accuracy: {}", *accuracy_handle) }</p>
<p id="training"> { format!("Training: {}", *local_train_toggle.lock().unwrap()) }</p>
<p id="cached">{ format!("Caching: {}", *data_cache_futures + data_cache_pipe.len()) }</p>
<p id="cached">{ format!("Cached: {}", *data_cache_external) }</p>
<p id="cached">{ format!("Caching: {}", *data_caching) }</p>
<p id="cached">{ format!("Cached: {}", *data_cached) }</p>
</div>
</div>
</div>
Expand Down
138 changes: 0 additions & 138 deletions site/src/hooks/counter.rs

This file was deleted.

2 changes: 0 additions & 2 deletions site/src/hooks/mod.rs

This file was deleted.

Loading

0 comments on commit 941f32c

Please sign in to comment.