Skip to content

Commit

Permalink
a few more changes, trying to count requests instead of sync
Browse files Browse the repository at this point in the history
  • Loading branch information
sachiniyer committed Jan 31, 2024
1 parent 941f32c commit 6100d7a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 79 deletions.
2 changes: 1 addition & 1 deletion site/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ yew = { path = "../yew/packages/yew", features = ["csr"] }
# yew-agent = { git = "https://github.com/yewstack/yew/" }
yew-agent = { path = "../yew/packages/yew-agent" }
postcard = "1.0"
wasm-bindgen = "0.2"
wasm-bindgen = "0.2.90"
getrandom = { version = "0.2", features = ["js"] }
wasm-logger = "0.2.0"
wasm-bindgen-futures = "0.4.40"
Expand Down
33 changes: 15 additions & 18 deletions site/src/api.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
use model::util::{Data, DataInfo, DataSingle, Weights};
use wasm_bindgen::prelude::*;
use wasm_bindgen::JsValue;
use wasm_bindgen_futures::JsFuture;
use web_sys::{Request, RequestInit, RequestMode, Response};
use reqwest::Client;

const API_URL: &str = "http://127.0.0.0:8000";
Expand Down Expand Up @@ -40,20 +36,21 @@ pub async fn get_sample() -> DataSingle {
data.data.get(0).unwrap().clone()
}

pub async fn get_block(block_size: usize) -> Sendable<Result<JsValue, JsValue>> {
let data_info = DataInfo { block: block_size };
let serialized_data_info = serde_json::to_string(&data_info).map_err(|e| JsValue::from_str(&e.to_string())).unwrap();
// let mut opts = RequestInit::new();
// opts.method("POST");
// opts.mode(RequestMode::Cors);
// opts.body(Some(&JsValue::from_str(&serialized_data_info)));

let request = Request::new(Method::POST, &format!("{}/datablock", API_URL))
.mode(RequestMode::Cors)
.header("Content-Type", "application/json")
.header("Accept", "application/json");

let window = web_sys::window().unwrap();
pub async fn get_block(block_size: usize) -> Data {
let client = Client::new();
serde_json::from_str(
&client
.post(format!("{}/datablock", API_URL))
.json(&DataInfo { block: block_size })
.header("Content-Type", "application/json")
.header("Accept", "application/json")
.send()
.await
.unwrap()
.text()
.await
.unwrap(),
).unwrap()
}

pub async fn send_weights(weights: Weights) {
Expand Down
81 changes: 21 additions & 60 deletions site/src/model_agent.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::{FutureExt, SinkExt, StreamExt, future::Future, select};
use futures::{FutureExt, SinkExt, StreamExt};
use model::{
util::{random_dist, train_handler_wrapper, Data, Weights},
Model,
Expand All @@ -11,8 +11,8 @@ use std::{
};
use yew::platform::time::sleep;
use yew_agent::prelude::*;
use crate::api::{get_block, Sendable};
use wasm_bindgen::JsValue;
use crate::api::get_block;
use wasm_bindgen_futures::spawn_local;

#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum ControlSignal {
Expand All @@ -38,14 +38,10 @@ pub struct ResponseSignal {
pub cache_size: usize,
}

struct SendableFuture(Box<dyn Future<Output = Result<Sendable<JsValue>, Sendable<JsValue>>> + Send>);

unsafe impl Send for SendableFuture {}


pub struct ModelData {
data_vec: Arc<Mutex<VecDeque<Data>>>,
data_futures: Arc<Mutex<Vec<SendableFuture>>>,
data_given: i64,
data_taken: i64,
training: bool,
batch_size: usize,
lrate: f64,
Expand All @@ -62,7 +58,8 @@ impl ModelData {
fn new() -> Self {
Self {
data_vec: Arc::new(Mutex::new(VecDeque::new())),
data_futures: Arc::new(Mutex::new(Vec::new())),
data_given: 0,
data_taken: 0,
training: false,
batch_size: 128,
lrate: 0.01,
Expand Down Expand Up @@ -93,7 +90,7 @@ impl ModelData {
batch_size: self.batch_size,
lrate: self.lrate,
data_len: self.data_vec.lock().unwrap().len(),
data_futures_len: self.data_futures.lock().unwrap().len(),
data_futures_len: (self.data_given - self.data_taken) as usize,
iteration: self.iteration,
cache_size: self.cache_size,
}
Expand All @@ -103,8 +100,10 @@ impl ModelData {
let mut data = self.data_vec.lock().unwrap().pop_front();
while data.is_some() && data.as_ref().unwrap().data.len() != self.batch_size {
data = self.data_vec.lock().unwrap().pop_front();
self.data_taken += 1;
}
if data.is_some() {
self.data_taken += 1;
let (loss, acc) = train_handler_wrapper(&data.unwrap(), &mut self.model, self.batch_size);
self.loss = loss;
self.acc = acc;
Expand All @@ -113,58 +112,20 @@ impl ModelData {
}
}

// // 2. Function to loop over the promises and call a polling function
// async fn loop_and_poll_requests(urls: Vec<&str>) -> Vec<JsValue> {
// let mut futures = vec![];

// for url in urls {
// let future = make_request(url).fuse(); // `.fuse()` allows us to poll a future multiple times
// futures.push(future);
// }

// let mut results = Vec::new();
// for future in futures {
// match poll_promise(future).await {
// Ok(result) => results.push(result),
// Err(e) => web_sys::console::log_1(&format!("Error: {:?}", e).into()),
// }
// }

// results
// }

async fn cache_data(&mut self) {
self.add_futures().await;

let mut data_futures = self.data_futures.lock().unwrap();

for i in 0..data_futures.len() {
let mut data_future = data_futures.get_mut(i).unwrap();
}

}

async fn add_futures(&mut self) {
let difference = self.cache_size - self.data_futures.lock().unwrap().len() - self.data_vec.lock().unwrap().len();
for _ in 0..difference {
let data_future = SendableFuture(Box::new(get_block(self.batch_size)));
self.data_futures.lock().unwrap().push(data_future);
}
let missing = (self.cache_size as i64 - self.data_vec.lock().unwrap().len() as i64) - (self.data_given - self.data_taken);
self.add_futures(missing).await;
}

async fn poll_future<F>(future: F) -> Result<Data, ()>
where
F: std::future::Future<Output = Data>,
{
let wait_time = Duration::from_millis(10);
let future = future.fuse();
let timeout = sleep(wait_time).fuse();

futures::pin_mut!(future, timeout);

select! {
result = future => Ok(result),
_ = timeout => Err(()),
async fn add_futures(&mut self, missing: i64) {
for _ in 0..missing {
self.data_given += 1;
let data_vec_handle = self.data_vec.clone();
let batch_size = self.batch_size;
spawn_local(async move {
let data = get_block(batch_size).await;
data_vec_handle.lock().unwrap().push_back(data);
});
}
}

Expand Down

0 comments on commit 6100d7a

Please sign in to comment.