@@ -6,6 +6,7 @@ use std::{
6
6
use futures_util:: StreamExt ;
7
7
use global_error:: { GlobalError , GlobalResult } ;
8
8
use serde:: { de:: DeserializeOwned , Serialize } ;
9
+ use tokio:: sync:: watch;
9
10
use tracing:: Instrument as _;
10
11
use uuid:: Uuid ;
11
12
@@ -69,6 +70,8 @@ pub struct WorkflowCtx {
69
70
loop_location : Option < Location > ,
70
71
71
72
msg_ctx : MessageCtx ,
73
+ /// Used to stop workflow execution by the worker.
74
+ stop : watch:: Receiver < ( ) > ,
72
75
}
73
76
74
77
impl WorkflowCtx {
@@ -79,6 +82,7 @@ impl WorkflowCtx {
79
82
config : rivet_config:: Config ,
80
83
conn : rivet_connection:: Connection ,
81
84
data : PulledWorkflowData ,
85
+ stop : watch:: Receiver < ( ) > ,
82
86
) -> GlobalResult < Self > {
83
87
let msg_ctx = MessageCtx :: new ( & conn, data. ray_id ) . await ?;
84
88
let event_history = Arc :: new ( data. events ) ;
@@ -105,6 +109,7 @@ impl WorkflowCtx {
105
109
loop_location : None ,
106
110
107
111
msg_ctx,
112
+ stop,
108
113
} )
109
114
}
110
115
@@ -135,6 +140,9 @@ impl WorkflowCtx {
135
140
pub ( crate ) async fn run ( mut self ) -> WorkflowResult < ( ) > {
136
141
tracing:: debug!( name=%self . name, id=%self . workflow_id, "running workflow" ) ;
137
142
143
+ // Check for stop before running
144
+ self . check_stop ( ) ?;
145
+
138
146
// Lookup workflow
139
147
let workflow = self . registry . get_workflow ( & self . name ) ?;
140
148
@@ -176,8 +184,10 @@ impl WorkflowCtx {
176
184
}
177
185
}
178
186
Err ( err) => {
187
+ let wake_immediate = err. wake_immediate ( ) ;
188
+
179
189
// Retry the workflow if its recoverable
180
- let deadline_ts = if let Some ( deadline_ts) = err. deadline_ts ( ) {
190
+ let wake_deadline_ts = if let Some ( deadline_ts) = err. deadline_ts ( ) {
181
191
Some ( deadline_ts)
182
192
} else {
183
193
None
@@ -217,8 +227,8 @@ impl WorkflowCtx {
217
227
. commit_workflow (
218
228
self . workflow_id ,
219
229
& self . name ,
220
- false ,
221
- deadline_ts ,
230
+ wake_immediate ,
231
+ wake_deadline_ts ,
222
232
wake_signals,
223
233
wake_sub_workflow,
224
234
& err_str,
@@ -423,6 +433,7 @@ impl WorkflowCtx {
423
433
loop_location : self . loop_location . clone ( ) ,
424
434
425
435
msg_ctx : self . msg_ctx . clone ( ) ,
436
+ stop : self . stop . clone ( ) ,
426
437
}
427
438
}
428
439
@@ -434,6 +445,21 @@ impl WorkflowCtx {
434
445
435
446
branch
436
447
}
448
+
449
+ pub ( crate ) fn check_stop ( & self ) -> WorkflowResult < ( ) > {
450
+ if self . stop . has_changed ( ) . unwrap_or ( true ) {
451
+ Err ( WorkflowError :: WorkflowStopped )
452
+ } else {
453
+ Ok ( ( ) )
454
+ }
455
+ }
456
+
457
+ pub ( crate ) async fn wait_stop ( & self ) -> WorkflowResult < ( ) > {
458
+ // We have to clone here because this function can't have a mutable reference to self. The state of
459
+ // the stop channel doesn't matter because it only ever receives one message
460
+ let _ = self . stop . clone ( ) . changed ( ) . await ;
461
+ Err ( WorkflowError :: WorkflowStopped )
462
+ }
437
463
}
438
464
439
465
impl WorkflowCtx {
@@ -459,6 +485,8 @@ impl WorkflowCtx {
459
485
I : ActivityInput ,
460
486
<I as ActivityInput >:: Activity : Activity < Input = I > ,
461
487
{
488
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
489
+
462
490
let event_id = EventId :: new ( I :: Activity :: NAME , & input) ;
463
491
464
492
let history_res = self
@@ -556,6 +584,8 @@ impl WorkflowCtx {
556
584
/// short circuit in the event of an error to make sure activity side effects are recorded.
557
585
#[ tracing:: instrument( skip_all) ]
558
586
pub async fn join < T : Executable > ( & mut self , exec : T ) -> GlobalResult < T :: Output > {
587
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
588
+
559
589
exec. execute ( self ) . await
560
590
}
561
591
@@ -571,7 +601,7 @@ impl WorkflowCtx {
571
601
Ok ( inner_err) => {
572
602
// Despite "history diverged" errors being unrecoverable, they should not have be returned
573
603
// by this function because the state of the history is already messed up and no new
574
- // workflow items can be run.
604
+ // workflow items should be run.
575
605
if !inner_err. is_recoverable ( )
576
606
&& !matches ! ( * inner_err, WorkflowError :: HistoryDiverged ( _) )
577
607
{
@@ -599,6 +629,8 @@ impl WorkflowCtx {
599
629
/// received, the workflow will be woken up and continue.
600
630
#[ tracing:: instrument( skip_all) ]
601
631
pub async fn listen < T : Listen > ( & mut self ) -> GlobalResult < T > {
632
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
633
+
602
634
let history_res = self
603
635
. cursor
604
636
. compare_signal ( self . version )
@@ -648,6 +680,7 @@ impl WorkflowCtx {
648
680
tokio:: select! {
649
681
_ = wake_sub. next( ) => { } ,
650
682
_ = interval. tick( ) => { } ,
683
+ res = self . wait_stop( ) => res. map_err( GlobalError :: raw) ?,
651
684
}
652
685
}
653
686
} ;
@@ -664,6 +697,8 @@ impl WorkflowCtx {
664
697
& mut self ,
665
698
listener : & T ,
666
699
) -> GlobalResult < <T as CustomListener >:: Output > {
700
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
701
+
667
702
let history_res = self
668
703
. cursor
669
704
. compare_signal ( self . version )
@@ -713,6 +748,7 @@ impl WorkflowCtx {
713
748
tokio:: select! {
714
749
_ = wake_sub. next( ) => { } ,
715
750
_ = interval. tick( ) => { } ,
751
+ res = self . wait_stop( ) => res. map_err( GlobalError :: raw) ?,
716
752
}
717
753
}
718
754
} ;
@@ -756,6 +792,8 @@ impl WorkflowCtx {
756
792
F : for < ' a > FnMut ( & ' a mut WorkflowCtx , & ' a mut S ) -> AsyncResult < ' a , Loop < T > > ,
757
793
T : Serialize + DeserializeOwned ,
758
794
{
795
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
796
+
759
797
let history_res = self
760
798
. cursor
761
799
. compare_loop ( self . version )
@@ -806,6 +844,8 @@ impl WorkflowCtx {
806
844
tracing:: debug!( name=%self . name, id=%self . workflow_id, "running loop" ) ;
807
845
808
846
loop {
847
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
848
+
809
849
let start_instant = Instant :: now ( ) ;
810
850
811
851
// Create a new branch for each iteration of the loop at location {...loop location, iteration idx}
@@ -928,6 +968,8 @@ impl WorkflowCtx {
928
968
929
969
#[ tracing:: instrument( skip_all) ]
930
970
pub async fn sleep_until ( & mut self , time : impl TsToMillis ) -> GlobalResult < ( ) > {
971
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
972
+
931
973
let history_res = self
932
974
. cursor
933
975
. compare_sleep ( self . version )
@@ -969,7 +1011,10 @@ impl WorkflowCtx {
969
1011
else if duration < self . db . worker_poll_interval ( ) . as_millis ( ) as i64 + 1 {
970
1012
tracing:: debug!( name=%self . name, id=%self . workflow_id, %deadline_ts, "sleeping in memory" ) ;
971
1013
972
- tokio:: time:: sleep ( Duration :: from_millis ( duration. try_into ( ) ?) ) . await ;
1014
+ tokio:: select! {
1015
+ _ = tokio:: time:: sleep( Duration :: from_millis( duration. try_into( ) ?) ) => { } ,
1016
+ res = self . wait_stop( ) => res?,
1017
+ }
973
1018
}
974
1019
// Workflow sleep
975
1020
else {
@@ -1008,6 +1053,8 @@ impl WorkflowCtx {
1008
1053
& mut self ,
1009
1054
time : impl TsToMillis ,
1010
1055
) -> GlobalResult < Option < T > > {
1056
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
1057
+
1011
1058
let history_res = self
1012
1059
. cursor
1013
1060
. compare_sleep ( self . version )
@@ -1122,6 +1169,7 @@ impl WorkflowCtx {
1122
1169
tokio:: select! {
1123
1170
_ = wake_sub. next( ) => { } ,
1124
1171
_ = interval. tick( ) => { } ,
1172
+ res = self . wait_stop( ) => res?,
1125
1173
}
1126
1174
}
1127
1175
} )
@@ -1173,6 +1221,7 @@ impl WorkflowCtx {
1173
1221
tokio:: select! {
1174
1222
_ = wake_sub. next( ) => { } ,
1175
1223
_ = interval. tick( ) => { } ,
1224
+ res = self . wait_stop( ) => res. map_err( GlobalError :: raw) ?,
1176
1225
}
1177
1226
}
1178
1227
} ;
@@ -1205,6 +1254,8 @@ impl WorkflowCtx {
1205
1254
/// Represents a removed workflow step.
1206
1255
#[ tracing:: instrument( skip_all) ]
1207
1256
pub async fn removed < T : Removed > ( & mut self ) -> GlobalResult < ( ) > {
1257
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
1258
+
1208
1259
// Existing event
1209
1260
if self
1210
1261
. cursor
@@ -1242,6 +1293,8 @@ impl WorkflowCtx {
1242
1293
/// inserts a version check event.
1243
1294
#[ tracing:: instrument( skip_all) ]
1244
1295
pub async fn check_version ( & mut self , current_version : usize ) -> GlobalResult < usize > {
1296
+ self . check_stop ( ) . map_err ( GlobalError :: raw) ?;
1297
+
1245
1298
if current_version == 0 {
1246
1299
return Err ( GlobalError :: raw ( WorkflowError :: InvalidVersion (
1247
1300
"version for `check_version` must be greater than 0" . into ( ) ,
0 commit comments