Skip to content

Commit 52963da

Browse files
authored
Merge pull request #12 from lambdaclass/initialization
Improved initialization and state ownership
2 parents 28f7a7b + 620da61 commit 52963da

File tree

16 files changed

+389
-234
lines changed

16 files changed

+389
-234
lines changed

concurrency/src/tasks/error.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#[derive(Debug)]
22
pub enum GenServerError {
3-
CallbackError,
4-
ServerError,
3+
Callback,
4+
Initialization,
5+
Server,
56
}
67

78
impl<T> From<spawned_rt::tasks::mpsc::SendError<T>> for GenServerError {
89
fn from(_value: spawned_rt::tasks::mpsc::SendError<T>) -> Self {
9-
Self::ServerError
10+
Self::Server
1011
}
1112
}

concurrency/src/tasks/gen_server.rs

Lines changed: 86 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@ impl<G: GenServer> Clone for GenServerHandle<G> {
2020
}
2121

2222
impl<G: GenServer> GenServerHandle<G> {
23-
pub(crate) fn new(mut initial_state: G::State) -> Self {
23+
pub(crate) fn new(initial_state: G::State) -> Self {
2424
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
2525
let handle = GenServerHandle { tx };
2626
let mut gen_server: G = GenServer::new();
2727
let handle_clone = handle.clone();
2828
// Ignore the JoinHandle for now. Maybe we'll use it in the future
2929
let _join_handle = rt::spawn(async move {
3030
if gen_server
31-
.run(&handle, &mut rx, &mut initial_state)
31+
.run(&handle, &mut rx, initial_state)
3232
.await
3333
.is_err()
3434
{
@@ -38,7 +38,7 @@ impl<G: GenServer> GenServerHandle<G> {
3838
handle_clone
3939
}
4040

41-
pub(crate) fn new_blocking(mut initial_state: G::State) -> Self {
41+
pub(crate) fn new_blocking(initial_state: G::State) -> Self {
4242
let (tx, mut rx) = mpsc::channel::<GenServerInMsg<G>>();
4343
let handle = GenServerHandle { tx };
4444
let mut gen_server: G = GenServer::new();
@@ -47,7 +47,7 @@ impl<G: GenServer> GenServerHandle<G> {
4747
let _join_handle = rt::spawn_blocking(|| {
4848
rt::block_on(async move {
4949
if gen_server
50-
.run(&handle, &mut rx, &mut initial_state)
50+
.run(&handle, &mut rx, initial_state)
5151
.await
5252
.is_err()
5353
{
@@ -70,34 +70,34 @@ impl<G: GenServer> GenServerHandle<G> {
7070
})?;
7171
match oneshot_rx.await {
7272
Ok(result) => result,
73-
Err(_) => Err(GenServerError::ServerError),
73+
Err(_) => Err(GenServerError::Server),
7474
}
7575
}
7676

7777
pub async fn cast(&mut self, message: G::CastMsg) -> Result<(), GenServerError> {
7878
self.tx
7979
.send(GenServerInMsg::Cast { message })
80-
.map_err(|_error| GenServerError::ServerError)
80+
.map_err(|_error| GenServerError::Server)
8181
}
8282
}
8383

84-
pub enum GenServerInMsg<A: GenServer> {
84+
pub enum GenServerInMsg<G: GenServer> {
8585
Call {
86-
sender: oneshot::Sender<Result<A::OutMsg, GenServerError>>,
87-
message: A::CallMsg,
86+
sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
87+
message: G::CallMsg,
8888
},
8989
Cast {
90-
message: A::CastMsg,
90+
message: G::CastMsg,
9191
},
9292
}
9393

94-
pub enum CallResponse<U> {
95-
Reply(U),
96-
Stop(U),
94+
pub enum CallResponse<G: GenServer> {
95+
Reply(G::State, G::OutMsg),
96+
Stop(G::OutMsg),
9797
}
9898

99-
pub enum CastResponse {
100-
NoReply,
99+
pub enum CastResponse<G: GenServer> {
100+
NoReply(G::State),
101101
Stop,
102102
}
103103

@@ -109,7 +109,7 @@ where
109109
type CastMsg: Send + Sized;
110110
type OutMsg: Send + Sized;
111111
type State: Clone + Send;
112-
type Error: Debug;
112+
type Error: Debug + Send;
113113

114114
fn new() -> Self;
115115

@@ -130,25 +130,46 @@ where
130130
&mut self,
131131
handle: &GenServerHandle<Self>,
132132
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
133-
state: &mut Self::State,
133+
state: Self::State,
134134
) -> impl Future<Output = Result<(), GenServerError>> + Send {
135135
async {
136-
self.main_loop(handle, rx, state).await?;
137-
Ok(())
136+
match self.init(handle, state).await {
137+
Ok(new_state) => {
138+
self.main_loop(handle, rx, new_state).await?;
139+
Ok(())
140+
}
141+
Err(err) => {
142+
tracing::error!("Initialization failed: {err:?}");
143+
Err(GenServerError::Initialization)
144+
}
145+
}
138146
}
139147
}
140148

149+
/// Initialization function. It's called before main loop. It
150+
/// can be overrided on implementations in case initial steps are
151+
/// required.
152+
fn init(
153+
&mut self,
154+
_handle: &GenServerHandle<Self>,
155+
state: Self::State,
156+
) -> impl Future<Output = Result<Self::State, Self::Error>> + Send {
157+
async { Ok(state) }
158+
}
159+
141160
fn main_loop(
142161
&mut self,
143162
handle: &GenServerHandle<Self>,
144163
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
145-
state: &mut Self::State,
164+
mut state: Self::State,
146165
) -> impl Future<Output = Result<(), GenServerError>> + Send {
147166
async {
148167
loop {
149-
if !self.receive(handle, rx, state).await? {
168+
let (new_state, cont) = self.receive(handle, rx, state).await?;
169+
if !cont {
150170
break;
151171
}
172+
state = new_state;
152173
}
153174
tracing::trace!("Stopping GenServer");
154175
Ok(())
@@ -159,81 +180,88 @@ where
159180
&mut self,
160181
handle: &GenServerHandle<Self>,
161182
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
162-
state: &mut Self::State,
163-
) -> impl std::future::Future<Output = Result<bool, GenServerError>> + Send {
164-
async {
183+
state: Self::State,
184+
) -> impl std::future::Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
185+
async move {
165186
let message = rx.recv().await;
166187

167188
// Save current state in case of a rollback
168189
let state_clone = state.clone();
169190

170-
let (keep_running, error) = match message {
191+
let (keep_running, new_state) = match message {
171192
Some(GenServerInMsg::Call { sender, message }) => {
172-
let (keep_running, error, response) =
193+
let (keep_running, new_state, response) =
173194
match AssertUnwindSafe(self.handle_call(message, handle, state))
174195
.catch_unwind()
175196
.await
176197
{
177198
Ok(response) => match response {
178-
CallResponse::Reply(response) => (true, None, Ok(response)),
179-
CallResponse::Stop(response) => (false, None, Ok(response)),
199+
CallResponse::Reply(new_state, response) => {
200+
(true, new_state, Ok(response))
201+
}
202+
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
180203
},
181-
Err(error) => (true, Some(error), Err(GenServerError::CallbackError)),
204+
Err(error) => {
205+
tracing::trace!(
206+
"Error in callback, reverting state - Error: '{error:?}'"
207+
);
208+
(true, state_clone, Err(GenServerError::Callback))
209+
}
182210
};
183211
// Send response back
184212
if sender.send(response).is_err() {
185213
tracing::trace!(
186214
"GenServer failed to send response back, client must have died"
187215
)
188216
};
189-
(keep_running, error)
217+
(keep_running, new_state)
190218
}
191219
Some(GenServerInMsg::Cast { message }) => {
192220
match AssertUnwindSafe(self.handle_cast(message, handle, state))
193221
.catch_unwind()
194222
.await
195223
{
196224
Ok(response) => match response {
197-
CastResponse::NoReply => (true, None),
198-
CastResponse::Stop => (false, None),
225+
CastResponse::NoReply(new_state) => (true, new_state),
226+
CastResponse::Stop => (false, state_clone),
199227
},
200-
Err(error) => (true, Some(error)),
228+
Err(error) => {
229+
tracing::trace!(
230+
"Error in callback, reverting state - Error: '{error:?}'"
231+
);
232+
(true, state_clone)
233+
}
201234
}
202235
}
203236
None => {
204237
// Channel has been closed; won't receive further messages. Stop the server.
205-
(false, None)
238+
(false, state)
206239
}
207240
};
208-
if let Some(error) = error {
209-
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
210-
// Restore initial state (ie. dismiss any change)
211-
*state = state_clone;
212-
};
213-
Ok(keep_running)
241+
Ok((new_state, keep_running))
214242
}
215243
}
216244

217245
fn handle_call(
218246
&mut self,
219247
message: Self::CallMsg,
220248
handle: &GenServerHandle<Self>,
221-
state: &mut Self::State,
222-
) -> impl std::future::Future<Output = CallResponse<Self::OutMsg>> + Send;
249+
state: Self::State,
250+
) -> impl std::future::Future<Output = CallResponse<Self>> + Send;
223251

224252
fn handle_cast(
225253
&mut self,
226254
message: Self::CastMsg,
227255
handle: &GenServerHandle<Self>,
228-
state: &mut Self::State,
229-
) -> impl std::future::Future<Output = CastResponse> + Send;
256+
state: Self::State,
257+
) -> impl std::future::Future<Output = CastResponse<Self>> + Send;
230258
}
231259

232260
#[cfg(test)]
233261
mod tests {
234262
use super::*;
235263
use crate::tasks::send_after;
236-
use std::{process::exit, thread, time::Duration};
264+
use std::{thread, time::Duration};
237265
struct BadlyBehavedTask;
238266

239267
#[derive(Clone)]
@@ -261,17 +289,17 @@ mod tests {
261289
&mut self,
262290
_: Self::CallMsg,
263291
_: &GenServerHandle<Self>,
264-
_: &mut Self::State,
265-
) -> CallResponse<Self::OutMsg> {
292+
_: Self::State,
293+
) -> CallResponse<Self> {
266294
CallResponse::Stop(())
267295
}
268296

269297
async fn handle_cast(
270298
&mut self,
271299
_: Self::CastMsg,
272300
_: &GenServerHandle<Self>,
273-
_: &mut Self::State,
274-
) -> CastResponse {
301+
_: Self::State,
302+
) -> CastResponse<Self> {
275303
rt::sleep(Duration::from_millis(20)).await;
276304
thread::sleep(Duration::from_secs(2));
277305
CastResponse::Stop
@@ -300,10 +328,13 @@ mod tests {
300328
&mut self,
301329
message: Self::CallMsg,
302330
_: &GenServerHandle<Self>,
303-
state: &mut Self::State,
304-
) -> CallResponse<Self::OutMsg> {
331+
state: Self::State,
332+
) -> CallResponse<Self> {
305333
match message {
306-
InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)),
334+
InMessage::GetCount => {
335+
let count = state.count;
336+
CallResponse::Reply(state, OutMsg::Count(count))
337+
}
307338
InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)),
308339
}
309340
}
@@ -312,12 +343,12 @@ mod tests {
312343
&mut self,
313344
_: Self::CastMsg,
314345
handle: &GenServerHandle<Self>,
315-
state: &mut Self::State,
316-
) -> CastResponse {
346+
mut state: Self::State,
347+
) -> CastResponse<Self> {
317348
state.count += 1;
318349
println!("{:?}: good still alive", thread::current().id());
319350
send_after(Duration::from_millis(100), handle.to_owned(), ());
320-
CastResponse::NoReply
351+
CastResponse::NoReply(state)
321352
}
322353
}
323354

concurrency/src/threads/error.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
#[derive(Debug)]
22
pub enum GenServerError {
3-
CallbackError,
4-
ServerError,
3+
Callback,
4+
Initialization,
5+
Server,
56
}
67

78
impl<T> From<spawned_rt::threads::mpsc::SendError<T>> for GenServerError {
89
fn from(_value: spawned_rt::threads::mpsc::SendError<T>) -> Self {
9-
Self::ServerError
10+
Self::Server
1011
}
1112
}

0 commit comments

Comments
 (0)