@@ -20,15 +20,15 @@ impl<G: GenServer> Clone for GenServerHandle<G> {
2020}
2121
2222impl < G : GenServer > GenServerHandle < G > {
23- pub ( crate ) fn new ( mut initial_state : G :: State ) -> Self {
23+ pub ( crate ) fn new ( initial_state : G :: State ) -> Self {
2424 let ( tx, mut rx) = mpsc:: channel :: < GenServerInMsg < G > > ( ) ;
2525 let handle = GenServerHandle { tx } ;
2626 let mut gen_server: G = GenServer :: new ( ) ;
2727 let handle_clone = handle. clone ( ) ;
2828 // Ignore the JoinHandle for now. Maybe we'll use it in the future
2929 let _join_handle = rt:: spawn ( async move {
3030 if gen_server
31- . run ( & handle, & mut rx, & mut initial_state)
31+ . run ( & handle, & mut rx, initial_state)
3232 . await
3333 . is_err ( )
3434 {
@@ -38,7 +38,7 @@ impl<G: GenServer> GenServerHandle<G> {
3838 handle_clone
3939 }
4040
41- pub ( crate ) fn new_blocking ( mut initial_state : G :: State ) -> Self {
41+ pub ( crate ) fn new_blocking ( initial_state : G :: State ) -> Self {
4242 let ( tx, mut rx) = mpsc:: channel :: < GenServerInMsg < G > > ( ) ;
4343 let handle = GenServerHandle { tx } ;
4444 let mut gen_server: G = GenServer :: new ( ) ;
@@ -47,7 +47,7 @@ impl<G: GenServer> GenServerHandle<G> {
4747 let _join_handle = rt:: spawn_blocking ( || {
4848 rt:: block_on ( async move {
4949 if gen_server
50- . run ( & handle, & mut rx, & mut initial_state)
50+ . run ( & handle, & mut rx, initial_state)
5151 . await
5252 . is_err ( )
5353 {
@@ -70,34 +70,34 @@ impl<G: GenServer> GenServerHandle<G> {
7070 } ) ?;
7171 match oneshot_rx. await {
7272 Ok ( result) => result,
73- Err ( _) => Err ( GenServerError :: ServerError ) ,
73+ Err ( _) => Err ( GenServerError :: Server ) ,
7474 }
7575 }
7676
7777 pub async fn cast ( & mut self , message : G :: CastMsg ) -> Result < ( ) , GenServerError > {
7878 self . tx
7979 . send ( GenServerInMsg :: Cast { message } )
80- . map_err ( |_error| GenServerError :: ServerError )
80+ . map_err ( |_error| GenServerError :: Server )
8181 }
8282}
8383
84- pub enum GenServerInMsg < A : GenServer > {
84+ pub enum GenServerInMsg < G : GenServer > {
8585 Call {
86- sender : oneshot:: Sender < Result < A :: OutMsg , GenServerError > > ,
87- message : A :: CallMsg ,
86+ sender : oneshot:: Sender < Result < G :: OutMsg , GenServerError > > ,
87+ message : G :: CallMsg ,
8888 } ,
8989 Cast {
90- message : A :: CastMsg ,
90+ message : G :: CastMsg ,
9191 } ,
9292}
9393
94- pub enum CallResponse < U > {
95- Reply ( U ) ,
96- Stop ( U ) ,
94+ pub enum CallResponse < G : GenServer > {
95+ Reply ( G :: State , G :: OutMsg ) ,
96+ Stop ( G :: OutMsg ) ,
9797}
9898
99- pub enum CastResponse {
100- NoReply ,
99+ pub enum CastResponse < G : GenServer > {
100+ NoReply ( G :: State ) ,
101101 Stop ,
102102}
103103
@@ -109,7 +109,7 @@ where
109109 type CastMsg : Send + Sized ;
110110 type OutMsg : Send + Sized ;
111111 type State : Clone + Send ;
112- type Error : Debug ;
112+ type Error : Debug + Send ;
113113
114114 fn new ( ) -> Self ;
115115
@@ -130,25 +130,46 @@ where
130130 & mut self ,
131131 handle : & GenServerHandle < Self > ,
132132 rx : & mut mpsc:: Receiver < GenServerInMsg < Self > > ,
133- state : & mut Self :: State ,
133+ state : Self :: State ,
134134 ) -> impl Future < Output = Result < ( ) , GenServerError > > + Send {
135135 async {
136- self . main_loop ( handle, rx, state) . await ?;
137- Ok ( ( ) )
136+ match self . init ( handle, state) . await {
137+ Ok ( new_state) => {
138+ self . main_loop ( handle, rx, new_state) . await ?;
139+ Ok ( ( ) )
140+ }
141+ Err ( err) => {
142+ tracing:: error!( "Initialization failed: {err:?}" ) ;
143+ Err ( GenServerError :: Initialization )
144+ }
145+ }
138146 }
139147 }
140148
149+ /// Initialization function. It's called before main loop. It
150+ /// can be overrided on implementations in case initial steps are
151+ /// required.
152+ fn init (
153+ & mut self ,
154+ _handle : & GenServerHandle < Self > ,
155+ state : Self :: State ,
156+ ) -> impl Future < Output = Result < Self :: State , Self :: Error > > + Send {
157+ async { Ok ( state) }
158+ }
159+
141160 fn main_loop (
142161 & mut self ,
143162 handle : & GenServerHandle < Self > ,
144163 rx : & mut mpsc:: Receiver < GenServerInMsg < Self > > ,
145- state : & mut Self :: State ,
164+ mut state : Self :: State ,
146165 ) -> impl Future < Output = Result < ( ) , GenServerError > > + Send {
147166 async {
148167 loop {
149- if !self . receive ( handle, rx, state) . await ? {
168+ let ( new_state, cont) = self . receive ( handle, rx, state) . await ?;
169+ if !cont {
150170 break ;
151171 }
172+ state = new_state;
152173 }
153174 tracing:: trace!( "Stopping GenServer" ) ;
154175 Ok ( ( ) )
@@ -159,81 +180,88 @@ where
159180 & mut self ,
160181 handle : & GenServerHandle < Self > ,
161182 rx : & mut mpsc:: Receiver < GenServerInMsg < Self > > ,
162- state : & mut Self :: State ,
163- ) -> impl std:: future:: Future < Output = Result < bool , GenServerError > > + Send {
164- async {
183+ state : Self :: State ,
184+ ) -> impl std:: future:: Future < Output = Result < ( Self :: State , bool ) , GenServerError > > + Send {
185+ async move {
165186 let message = rx. recv ( ) . await ;
166187
167188 // Save current state in case of a rollback
168189 let state_clone = state. clone ( ) ;
169190
170- let ( keep_running, error ) = match message {
191+ let ( keep_running, new_state ) = match message {
171192 Some ( GenServerInMsg :: Call { sender, message } ) => {
172- let ( keep_running, error , response) =
193+ let ( keep_running, new_state , response) =
173194 match AssertUnwindSafe ( self . handle_call ( message, handle, state) )
174195 . catch_unwind ( )
175196 . await
176197 {
177198 Ok ( response) => match response {
178- CallResponse :: Reply ( response) => ( true , None , Ok ( response) ) ,
179- CallResponse :: Stop ( response) => ( false , None , Ok ( response) ) ,
199+ CallResponse :: Reply ( new_state, response) => {
200+ ( true , new_state, Ok ( response) )
201+ }
202+ CallResponse :: Stop ( response) => ( false , state_clone, Ok ( response) ) ,
180203 } ,
181- Err ( error) => ( true , Some ( error) , Err ( GenServerError :: CallbackError ) ) ,
204+ Err ( error) => {
205+ tracing:: trace!(
206+ "Error in callback, reverting state - Error: '{error:?}'"
207+ ) ;
208+ ( true , state_clone, Err ( GenServerError :: Callback ) )
209+ }
182210 } ;
183211 // Send response back
184212 if sender. send ( response) . is_err ( ) {
185213 tracing:: trace!(
186214 "GenServer failed to send response back, client must have died"
187215 )
188216 } ;
189- ( keep_running, error )
217+ ( keep_running, new_state )
190218 }
191219 Some ( GenServerInMsg :: Cast { message } ) => {
192220 match AssertUnwindSafe ( self . handle_cast ( message, handle, state) )
193221 . catch_unwind ( )
194222 . await
195223 {
196224 Ok ( response) => match response {
197- CastResponse :: NoReply => ( true , None ) ,
198- CastResponse :: Stop => ( false , None ) ,
225+ CastResponse :: NoReply ( new_state ) => ( true , new_state ) ,
226+ CastResponse :: Stop => ( false , state_clone ) ,
199227 } ,
200- Err ( error) => ( true , Some ( error) ) ,
228+ Err ( error) => {
229+ tracing:: trace!(
230+ "Error in callback, reverting state - Error: '{error:?}'"
231+ ) ;
232+ ( true , state_clone)
233+ }
201234 }
202235 }
203236 None => {
204237 // Channel has been closed; won't receive further messages. Stop the server.
205- ( false , None )
238+ ( false , state )
206239 }
207240 } ;
208- if let Some ( error) = error {
209- tracing:: trace!( "Error in callback, reverting state - Error: '{error:?}'" ) ;
210- // Restore initial state (ie. dismiss any change)
211- * state = state_clone;
212- } ;
213- Ok ( keep_running)
241+ Ok ( ( new_state, keep_running) )
214242 }
215243 }
216244
217245 fn handle_call (
218246 & mut self ,
219247 message : Self :: CallMsg ,
220248 handle : & GenServerHandle < Self > ,
221- state : & mut Self :: State ,
222- ) -> impl std:: future:: Future < Output = CallResponse < Self :: OutMsg > > + Send ;
249+ state : Self :: State ,
250+ ) -> impl std:: future:: Future < Output = CallResponse < Self > > + Send ;
223251
224252 fn handle_cast (
225253 & mut self ,
226254 message : Self :: CastMsg ,
227255 handle : & GenServerHandle < Self > ,
228- state : & mut Self :: State ,
229- ) -> impl std:: future:: Future < Output = CastResponse > + Send ;
256+ state : Self :: State ,
257+ ) -> impl std:: future:: Future < Output = CastResponse < Self > > + Send ;
230258}
231259
232260#[ cfg( test) ]
233261mod tests {
234262 use super :: * ;
235263 use crate :: tasks:: send_after;
236- use std:: { process :: exit , thread, time:: Duration } ;
264+ use std:: { thread, time:: Duration } ;
237265 struct BadlyBehavedTask ;
238266
239267 #[ derive( Clone ) ]
@@ -261,17 +289,17 @@ mod tests {
261289 & mut self ,
262290 _: Self :: CallMsg ,
263291 _: & GenServerHandle < Self > ,
264- _: & mut Self :: State ,
265- ) -> CallResponse < Self :: OutMsg > {
292+ _: Self :: State ,
293+ ) -> CallResponse < Self > {
266294 CallResponse :: Stop ( ( ) )
267295 }
268296
269297 async fn handle_cast (
270298 & mut self ,
271299 _: Self :: CastMsg ,
272300 _: & GenServerHandle < Self > ,
273- _: & mut Self :: State ,
274- ) -> CastResponse {
301+ _: Self :: State ,
302+ ) -> CastResponse < Self > {
275303 rt:: sleep ( Duration :: from_millis ( 20 ) ) . await ;
276304 thread:: sleep ( Duration :: from_secs ( 2 ) ) ;
277305 CastResponse :: Stop
@@ -300,10 +328,13 @@ mod tests {
300328 & mut self ,
301329 message : Self :: CallMsg ,
302330 _: & GenServerHandle < Self > ,
303- state : & mut Self :: State ,
304- ) -> CallResponse < Self :: OutMsg > {
331+ state : Self :: State ,
332+ ) -> CallResponse < Self > {
305333 match message {
306- InMessage :: GetCount => CallResponse :: Reply ( OutMsg :: Count ( state. count ) ) ,
334+ InMessage :: GetCount => {
335+ let count = state. count ;
336+ CallResponse :: Reply ( state, OutMsg :: Count ( count) )
337+ }
307338 InMessage :: Stop => CallResponse :: Stop ( OutMsg :: Count ( state. count ) ) ,
308339 }
309340 }
@@ -312,12 +343,12 @@ mod tests {
312343 & mut self ,
313344 _: Self :: CastMsg ,
314345 handle : & GenServerHandle < Self > ,
315- state : & mut Self :: State ,
316- ) -> CastResponse {
346+ mut state : Self :: State ,
347+ ) -> CastResponse < Self > {
317348 state. count += 1 ;
318349 println ! ( "{:?}: good still alive" , thread:: current( ) . id( ) ) ;
319350 send_after ( Duration :: from_millis ( 100 ) , handle. to_owned ( ) , ( ) ) ;
320- CastResponse :: NoReply
351+ CastResponse :: NoReply ( state )
321352 }
322353 }
323354
0 commit comments