@@ -187,13 +187,15 @@ def __init__(self, server_context: "ServerInvocationContext", name, serde) -> No
187187
188188 def value (self ) -> RestateDurableFuture [Any ]:
189189 handle = self .server_context .vm .sys_get_promise (self .name )
190+ update_restate_context_is_replaying (self .server_context .vm )
190191 return self .server_context .create_future (handle , self .serde )
191192
192193 def resolve (self , value : Any ) -> Awaitable [None ]:
193194 vm : VMWrapper = self .server_context .vm
194195 assert self .serde is not None
195196 value_buffer = self .serde .serialize (value )
196197 handle = vm .sys_complete_promise_success (self .name , value_buffer )
198+ update_restate_context_is_replaying (self .server_context .vm )
197199
198200 async def await_point ():
199201 if not self .server_context .vm .is_completed (handle ):
@@ -206,6 +208,7 @@ def reject(self, message: str, code: int = 500) -> Awaitable[None]:
206208 vm : VMWrapper = self .server_context .vm
207209 py_failure = Failure (code = code , message = message )
208210 handle = vm .sys_complete_promise_failure (self .name , py_failure )
211+ update_restate_context_is_replaying (self .server_context .vm )
209212
210213 async def await_point ():
211214 if not self .server_context .vm .is_completed (handle ):
@@ -217,6 +220,7 @@ async def await_point():
217220 def peek (self ) -> Awaitable [Any | None ]:
218221 vm : VMWrapper = self .server_context .vm
219222 handle = vm .sys_peek_promise (self .name )
223+ update_restate_context_is_replaying (self .server_context .vm )
220224 serde = self .serde
221225 assert serde is not None
222226 return self .server_context .create_future (handle , serde )
@@ -263,6 +267,12 @@ def cancel(self):
263267 for task in to_cancel :
264268 task .cancel ()
265269
270+ restate_context_is_replaying = contextvars .ContextVar ('restate_context_is_replaying' , default = False )
271+
272+ def update_restate_context_is_replaying (vm : VMWrapper ):
273+ """Update the context var 'restate_context_is_replaying'. This should be called after each vm.sys_*"""
274+ restate_context_is_replaying .set (vm .is_replaying ())
275+
266276# pylint: disable=R0902
267277class ServerInvocationContext (ObjectContext ):
268278 """This class implements the context for the restate framework based on the server."""
@@ -289,13 +299,16 @@ def __init__(self,
289299
290300 async def enter (self ):
291301 """Invoke the user code."""
302+ update_restate_context_is_replaying (self .vm )
292303 try :
293304 in_buffer = self .invocation .input_buffer
294305 out_buffer = await invoke_handler (handler = self .handler , ctx = self , in_buffer = in_buffer )
306+ restate_context_is_replaying .set (False )
295307 self .vm .sys_write_output_success (bytes (out_buffer ))
296308 self .vm .sys_end ()
297309 except TerminalError as t :
298310 failure = Failure (code = t .status_code , message = t .message )
311+ restate_context_is_replaying .set (False )
299312 self .vm .sys_write_output_failure (failure )
300313 self .vm .sys_end ()
301314 # pylint: disable=W0718
@@ -341,6 +354,7 @@ async def leave(self):
341354
342355 def on_attempt_finished (self ):
343356 """Notify the attempt finished event."""
357+ restate_context_is_replaying .set (False )
344358 self .request_finished_event .set ()
345359 try :
346360 self .tasks .cancel ()
@@ -446,25 +460,31 @@ def get(self, name: str,
446460 type_hint : Optional [typing .Type [T ]] = None
447461 ) -> Awaitable [Optional [T ]]:
448462 handle = self .vm .sys_get_state (name )
463+ update_restate_context_is_replaying (self .vm )
449464 if isinstance (serde , DefaultSerde ):
450465 serde = serde .with_maybe_type (type_hint )
451466 return self .create_future (handle , serde ) # type: ignore
452467
453468 def state_keys (self ) -> Awaitable [List [str ]]:
454- return self .create_future (self .vm .sys_get_state_keys ())
469+ handle = self .vm .sys_get_state_keys ()
470+ update_restate_context_is_replaying (self .vm )
471+ return self .create_future (handle )
455472
456473 def set (self , name : str , value : T , serde : Serde [T ] = DefaultSerde ()) -> None :
457474 """Set the value associated with the given name."""
458475 if isinstance (serde , DefaultSerde ):
459476 serde = serde .with_maybe_type (type (value ))
460477 buffer = serde .serialize (value )
461478 self .vm .sys_set_state (name , bytes (buffer ))
479+ update_restate_context_is_replaying (self .vm )
462480
463481 def clear (self , name : str ) -> None :
464482 self .vm .sys_clear_state (name )
483+ update_restate_context_is_replaying (self .vm )
465484
466485 def clear_all (self ) -> None :
467486 self .vm .sys_clear_all_state ()
487+ update_restate_context_is_replaying (self .vm )
468488
469489 def request (self ) -> Request :
470490 return Request (
@@ -542,6 +562,7 @@ def run(self,
542562 serde = serde .with_maybe_type (type_hint )
543563
544564 handle = self .vm .sys_run (name )
565+ update_restate_context_is_replaying (self .vm )
545566
546567 if args is not None :
547568 noargs_action = functools .partial (action , * args )
@@ -566,6 +587,7 @@ def run_typed(
566587 options .type_hint = signature .return_annotation
567588 options .serde = options .serde .with_maybe_type (options .type_hint )
568589 handle = self .vm .sys_run (name )
590+ update_restate_context_is_replaying (self .vm )
569591
570592 func = functools .partial (action , * args , ** kwargs )
571593 self .run_coros_to_execute [handle ] = lambda : self .create_run_coroutine (handle , func , options .serde , options .max_attempts , options .max_retry_duration )
@@ -574,7 +596,9 @@ def run_typed(
574596 def sleep (self , delta : timedelta ) -> RestateDurableSleepFuture :
575597 # convert timedelta to milliseconds
576598 millis = int (delta .total_seconds () * 1000 )
577- return self .create_sleep_future (self .vm .sys_sleep (millis )) # type: ignore
599+ handle = self .vm .sys_sleep (millis )
600+ update_restate_context_is_replaying (self .vm )
601+ return self .create_sleep_future (handle ) # type: ignore
578602
579603 def do_call (self ,
580604 tpe : HandlerType [I , O ],
@@ -615,9 +639,11 @@ def do_raw_call(self,
615639 if send_delay :
616640 ms = int (send_delay .total_seconds () * 1000 )
617641 send_handle = self .vm .sys_send (service , handler , parameter , key , delay = ms , idempotency_key = idempotency_key , headers = headers_kvs )
642+ update_restate_context_is_replaying (self .vm )
618643 return ServerSendHandle (self , send_handle )
619644 if send :
620645 send_handle = self .vm .sys_send (service , handler , parameter , key , idempotency_key = idempotency_key , headers = headers_kvs )
646+ update_restate_context_is_replaying (self .vm )
621647 return ServerSendHandle (self , send_handle )
622648
623649 handle = self .vm .sys_call (service = service ,
@@ -626,6 +652,7 @@ def do_raw_call(self,
626652 key = key ,
627653 idempotency_key = idempotency_key ,
628654 headers = headers_kvs )
655+ update_restate_context_is_replaying (self .vm )
629656
630657 return self .create_call_future (handle = handle .result_handle ,
631658 invocation_id_handle = handle .invocation_id_handle ,
@@ -712,6 +739,7 @@ def awakeable(self,
712739 if isinstance (serde , DefaultSerde ):
713740 serde = serde .with_maybe_type (type_hint )
714741 name , handle = self .vm .sys_awakeable ()
742+ update_restate_context_is_replaying (self .vm )
715743 return name , self .create_future (handle , serde )
716744
717745 def resolve_awakeable (self ,
@@ -722,9 +750,11 @@ def resolve_awakeable(self,
722750 serde = serde .with_maybe_type (type (value ))
723751 buf = serde .serialize (value )
724752 self .vm .sys_resolve_awakeable (name , buf )
753+ update_restate_context_is_replaying (self .vm )
725754
726755 def reject_awakeable (self , name : str , failure_message : str , failure_code : int = 500 ) -> None :
727- return self .vm .sys_reject_awakeable (name , Failure (code = failure_code , message = failure_message ))
756+ self .vm .sys_reject_awakeable (name , Failure (code = failure_code , message = failure_message ))
757+ update_restate_context_is_replaying (self .vm )
728758
729759 def promise (self , name : str , serde : typing .Optional [Serde [T ]] = JsonSerde (), type_hint : Optional [typing .Type [T ]] = None ) -> DurablePromise [T ]:
730760 """Create a durable promise."""
@@ -740,6 +770,7 @@ def cancel_invocation(self, invocation_id: str):
740770 if invocation_id is None :
741771 raise ValueError ("invocation_id cannot be None" )
742772 self .vm .sys_cancel (invocation_id )
773+ update_restate_context_is_replaying (self .vm )
743774
744775 def attach_invocation (self , invocation_id : str , serde : Serde [T ] = DefaultSerde (),
745776 type_hint : Optional [typing .Type [T ]] = None
@@ -749,4 +780,5 @@ def attach_invocation(self, invocation_id: str, serde: Serde[T] = DefaultSerde()
749780 if isinstance (serde , DefaultSerde ):
750781 serde = serde .with_maybe_type (type_hint )
751782 handle = self .vm .attach_invocation (invocation_id )
783+ update_restate_context_is_replaying (self .vm )
752784 return self .create_future (handle , serde )
0 commit comments