Skip to content

Commit ea92b9d

Browse files
committed
feat: graceful worker shutdown (#2274)
<!-- Please make sure there is an issue that this PR is correlated to. --> Fixes RVT-4594 ## Changes <!-- If there are frontend changes, please include screenshots. -->
1 parent 15cabdb commit ea92b9d

File tree

20 files changed

+257
-89
lines changed

20 files changed

+257
-89
lines changed

examples/system-test/tests/client.ts

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ const client = new RivetClient({
2828
async function run() {
2929
let actorId: string | undefined;
3030
try {
31+
console.time(`create actor`);
3132
console.log("Creating actor", { region });
3233
const { actor } = await client.actor.create({
3334
project: RIVET_PROJECT,
@@ -54,16 +55,18 @@ async function run() {
5455
},
5556
...(BUILD_NAME === "ws-container"
5657
? {
57-
resources: {
58-
cpu: 100,
59-
memory: 100,
60-
},
61-
}
58+
resources: {
59+
cpu: 100,
60+
memory: 100,
61+
},
62+
}
6263
: {}),
6364
},
6465
});
6566
actorId = actor.id;
6667

68+
console.timeEnd(`create actor`);
69+
6770
const port = actor.network.ports.http;
6871

6972
const actorOrigin = `${port.protocol}://${port.hostname}:${port.port}${port.path ?? ""}`;
@@ -84,13 +87,13 @@ async function run() {
8487
//}
8588

8689
// Retry loop for HTTP health check
87-
console.time(`ready-${actorId}`);
90+
console.time(`ready ${actorId}`);
8891
while (true) {
8992
try {
9093
const response = await fetch(`${actorOrigin}/health`);
9194
if (response.ok) {
9295
console.log("Health check passed");
93-
console.timeEnd(`ready-${actorId}`);
96+
console.timeEnd(`ready ${actorId}`);
9497
break;
9598
} else {
9699
console.error(

packages/common/chirp-workflow/core/src/builder/workflow/message.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ impl<'a, M: Message> MessageBuilder<'a, M> {
7373

7474
#[tracing::instrument(skip_all)]
7575
pub async fn send(self) -> GlobalResult<()> {
76+
self.ctx.check_stop().map_err(GlobalError::raw)?;
77+
7678
if let Some(err) = self.error {
7779
return Err(err.into());
7880
}

packages/common/chirp-workflow/core/src/builder/workflow/signal.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ impl<'a, T: Signal + Serialize> SignalBuilder<'a, T> {
8686

8787
#[tracing::instrument(skip_all)]
8888
pub async fn send(self) -> GlobalResult<Uuid> {
89+
self.ctx.check_stop().map_err(GlobalError::raw)?;
90+
8991
if let Some(err) = self.error {
9092
return Err(err.into());
9193
}

packages/common/chirp-workflow/core/src/builder/workflow/sub_workflow.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ where
8888

8989
#[tracing::instrument(skip_all)]
9090
pub async fn dispatch(self) -> GlobalResult<Uuid> {
91+
self.ctx.check_stop().map_err(GlobalError::raw)?;
92+
9193
if let Some(err) = self.error {
9294
return Err(err.into());
9395
}
@@ -237,6 +239,8 @@ where
237239
pub async fn output(
238240
self,
239241
) -> GlobalResult<<<I as WorkflowInput>::Workflow as Workflow>::Output> {
242+
self.ctx.check_stop().map_err(GlobalError::raw)?;
243+
240244
if let Some(err) = self.error {
241245
return Err(err.into());
242246
}
@@ -288,6 +292,8 @@ where
288292
&self,
289293
sub_workflow_id: Uuid,
290294
) -> GlobalResult<<<I as WorkflowInput>::Workflow as Workflow>::Output> {
295+
self.ctx.check_stop().map_err(GlobalError::raw)?;
296+
291297
tracing::debug!(name=%self.ctx.name(), id=%self.ctx.workflow_id(), sub_workflow_name=%I::Workflow::NAME, ?sub_workflow_id, "waiting for sub workflow");
292298

293299
let mut wake_sub = self.ctx.db().wake_sub().await?;

packages/common/chirp-workflow/core/src/ctx/listen.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{
1010

1111
/// Indirection struct to prevent invalid implementations of listen traits.
1212
pub struct ListenCtx<'a> {
13-
ctx: &'a mut WorkflowCtx,
13+
ctx: &'a WorkflowCtx,
1414
location: &'a Location,
1515
// Used by certain db drivers to know when to update internal indexes for signal wake conditions
1616
last_try: bool,
@@ -19,7 +19,7 @@ pub struct ListenCtx<'a> {
1919
}
2020

2121
impl<'a> ListenCtx<'a> {
22-
pub(crate) fn new(ctx: &'a mut WorkflowCtx, location: &'a Location) -> Self {
22+
pub(crate) fn new(ctx: &'a WorkflowCtx, location: &'a Location) -> Self {
2323
ListenCtx {
2424
ctx,
2525
location,

packages/common/chirp-workflow/core/src/ctx/workflow.rs

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66
use futures_util::StreamExt;
77
use global_error::{GlobalError, GlobalResult};
88
use serde::{de::DeserializeOwned, Serialize};
9+
use tokio::sync::watch;
910
use tracing::Instrument as _;
1011
use uuid::Uuid;
1112

@@ -69,6 +70,8 @@ pub struct WorkflowCtx {
6970
loop_location: Option<Location>,
7071

7172
msg_ctx: MessageCtx,
73+
/// Used to stop workflow execution by the worker.
74+
stop: watch::Receiver<()>,
7275
}
7376

7477
impl WorkflowCtx {
@@ -79,6 +82,7 @@ impl WorkflowCtx {
7982
config: rivet_config::Config,
8083
conn: rivet_connection::Connection,
8184
data: PulledWorkflowData,
85+
stop: watch::Receiver<()>,
8286
) -> GlobalResult<Self> {
8387
let msg_ctx = MessageCtx::new(&conn, data.ray_id).await?;
8488
let event_history = Arc::new(data.events);
@@ -105,6 +109,7 @@ impl WorkflowCtx {
105109
loop_location: None,
106110

107111
msg_ctx,
112+
stop,
108113
})
109114
}
110115

@@ -135,6 +140,9 @@ impl WorkflowCtx {
135140
pub(crate) async fn run(mut self) -> WorkflowResult<()> {
136141
tracing::debug!(name=%self.name, id=%self.workflow_id, "running workflow");
137142

143+
// Check for stop before running
144+
self.check_stop()?;
145+
138146
// Lookup workflow
139147
let workflow = self.registry.get_workflow(&self.name)?;
140148

@@ -176,8 +184,10 @@ impl WorkflowCtx {
176184
}
177185
}
178186
Err(err) => {
187+
let wake_immediate = err.wake_immediate();
188+
179189
// Retry the workflow if it's recoverable
180-
let deadline_ts = if let Some(deadline_ts) = err.deadline_ts() {
190+
let wake_deadline_ts = if let Some(deadline_ts) = err.deadline_ts() {
181191
Some(deadline_ts)
182192
} else {
183193
None
@@ -217,8 +227,8 @@ impl WorkflowCtx {
217227
.commit_workflow(
218228
self.workflow_id,
219229
&self.name,
220-
false,
221-
deadline_ts,
230+
wake_immediate,
231+
wake_deadline_ts,
222232
wake_signals,
223233
wake_sub_workflow,
224234
&err_str,
@@ -423,6 +433,7 @@ impl WorkflowCtx {
423433
loop_location: self.loop_location.clone(),
424434

425435
msg_ctx: self.msg_ctx.clone(),
436+
stop: self.stop.clone(),
426437
}
427438
}
428439

@@ -434,6 +445,21 @@ impl WorkflowCtx {
434445

435446
branch
436447
}
448+
449+
pub(crate) fn check_stop(&self) -> WorkflowResult<()> {
450+
if self.stop.has_changed().unwrap_or(true) {
451+
Err(WorkflowError::WorkflowStopped)
452+
} else {
453+
Ok(())
454+
}
455+
}
456+
457+
pub(crate) async fn wait_stop(&self) -> WorkflowResult<()> {
458+
// We have to clone here because this function can't have a mutable reference to self. The state of
459+
// the stop channel doesn't matter because it only ever receives one message
460+
let _ = self.stop.clone().changed().await;
461+
Err(WorkflowError::WorkflowStopped)
462+
}
437463
}
438464

439465
impl WorkflowCtx {
@@ -459,6 +485,8 @@ impl WorkflowCtx {
459485
I: ActivityInput,
460486
<I as ActivityInput>::Activity: Activity<Input = I>,
461487
{
488+
self.check_stop().map_err(GlobalError::raw)?;
489+
462490
let event_id = EventId::new(I::Activity::NAME, &input);
463491

464492
let history_res = self
@@ -556,6 +584,8 @@ impl WorkflowCtx {
556584
/// short circuit in the event of an error to make sure activity side effects are recorded.
557585
#[tracing::instrument(skip_all)]
558586
pub async fn join<T: Executable>(&mut self, exec: T) -> GlobalResult<T::Output> {
587+
self.check_stop().map_err(GlobalError::raw)?;
588+
559589
exec.execute(self).await
560590
}
561591

@@ -571,7 +601,7 @@ impl WorkflowCtx {
571601
Ok(inner_err) => {
572602
// Despite "history diverged" errors being unrecoverable, they should not be returned
573603
// by this function because the state of the history is already messed up and no new
574-
// workflow items can be run.
604+
// workflow items should be run.
575605
if !inner_err.is_recoverable()
576606
&& !matches!(*inner_err, WorkflowError::HistoryDiverged(_))
577607
{
@@ -599,6 +629,8 @@ impl WorkflowCtx {
599629
/// received, the workflow will be woken up and continue.
600630
#[tracing::instrument(skip_all)]
601631
pub async fn listen<T: Listen>(&mut self) -> GlobalResult<T> {
632+
self.check_stop().map_err(GlobalError::raw)?;
633+
602634
let history_res = self
603635
.cursor
604636
.compare_signal(self.version)
@@ -648,6 +680,7 @@ impl WorkflowCtx {
648680
tokio::select! {
649681
_ = wake_sub.next() => {},
650682
_ = interval.tick() => {},
683+
res = self.wait_stop() => res.map_err(GlobalError::raw)?,
651684
}
652685
}
653686
};
@@ -664,6 +697,8 @@ impl WorkflowCtx {
664697
&mut self,
665698
listener: &T,
666699
) -> GlobalResult<<T as CustomListener>::Output> {
700+
self.check_stop().map_err(GlobalError::raw)?;
701+
667702
let history_res = self
668703
.cursor
669704
.compare_signal(self.version)
@@ -713,6 +748,7 @@ impl WorkflowCtx {
713748
tokio::select! {
714749
_ = wake_sub.next() => {},
715750
_ = interval.tick() => {},
751+
res = self.wait_stop() => res.map_err(GlobalError::raw)?,
716752
}
717753
}
718754
};
@@ -756,6 +792,8 @@ impl WorkflowCtx {
756792
F: for<'a> FnMut(&'a mut WorkflowCtx, &'a mut S) -> AsyncResult<'a, Loop<T>>,
757793
T: Serialize + DeserializeOwned,
758794
{
795+
self.check_stop().map_err(GlobalError::raw)?;
796+
759797
let history_res = self
760798
.cursor
761799
.compare_loop(self.version)
@@ -806,6 +844,8 @@ impl WorkflowCtx {
806844
tracing::debug!(name=%self.name, id=%self.workflow_id, "running loop");
807845

808846
loop {
847+
self.check_stop().map_err(GlobalError::raw)?;
848+
809849
let start_instant = Instant::now();
810850

811851
// Create a new branch for each iteration of the loop at location {...loop location, iteration idx}
@@ -928,6 +968,8 @@ impl WorkflowCtx {
928968

929969
#[tracing::instrument(skip_all)]
930970
pub async fn sleep_until(&mut self, time: impl TsToMillis) -> GlobalResult<()> {
971+
self.check_stop().map_err(GlobalError::raw)?;
972+
931973
let history_res = self
932974
.cursor
933975
.compare_sleep(self.version)
@@ -969,7 +1011,10 @@ impl WorkflowCtx {
9691011
else if duration < self.db.worker_poll_interval().as_millis() as i64 + 1 {
9701012
tracing::debug!(name=%self.name, id=%self.workflow_id, %deadline_ts, "sleeping in memory");
9711013

972-
tokio::time::sleep(Duration::from_millis(duration.try_into()?)).await;
1014+
tokio::select! {
1015+
_ = tokio::time::sleep(Duration::from_millis(duration.try_into()?)) => {},
1016+
res = self.wait_stop() => res?,
1017+
}
9731018
}
9741019
// Workflow sleep
9751020
else {
@@ -1008,6 +1053,8 @@ impl WorkflowCtx {
10081053
&mut self,
10091054
time: impl TsToMillis,
10101055
) -> GlobalResult<Option<T>> {
1056+
self.check_stop().map_err(GlobalError::raw)?;
1057+
10111058
let history_res = self
10121059
.cursor
10131060
.compare_sleep(self.version)
@@ -1122,6 +1169,7 @@ impl WorkflowCtx {
11221169
tokio::select! {
11231170
_ = wake_sub.next() => {},
11241171
_ = interval.tick() => {},
1172+
res = self.wait_stop() => res?,
11251173
}
11261174
}
11271175
})
@@ -1173,6 +1221,7 @@ impl WorkflowCtx {
11731221
tokio::select! {
11741222
_ = wake_sub.next() => {},
11751223
_ = interval.tick() => {},
1224+
res = self.wait_stop() => res.map_err(GlobalError::raw)?,
11761225
}
11771226
}
11781227
};
@@ -1205,6 +1254,8 @@ impl WorkflowCtx {
12051254
/// Represents a removed workflow step.
12061255
#[tracing::instrument(skip_all)]
12071256
pub async fn removed<T: Removed>(&mut self) -> GlobalResult<()> {
1257+
self.check_stop().map_err(GlobalError::raw)?;
1258+
12081259
// Existing event
12091260
if self
12101261
.cursor
@@ -1242,6 +1293,8 @@ impl WorkflowCtx {
12421293
/// inserts a version check event.
12431294
#[tracing::instrument(skip_all)]
12441295
pub async fn check_version(&mut self, current_version: usize) -> GlobalResult<usize> {
1296+
self.check_stop().map_err(GlobalError::raw)?;
1297+
12451298
if current_version == 0 {
12461299
return Err(GlobalError::raw(WorkflowError::InvalidVersion(
12471300
"version for `check_version` must be greater than 0".into(),

packages/common/chirp-workflow/core/src/db/fdb_sqlite_nats/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ impl Database for DatabaseFdbSqliteNats {
214214
);
215215

216216
// Shut down entire runtime
217-
rivet_runtime::shutdown();
217+
rivet_runtime::shutdown().await;
218218
}
219219
});
220220

packages/common/chirp-workflow/core/src/error.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ pub enum WorkflowError {
3434
#[error("workflow not found")]
3535
WorkflowNotFound,
3636

37+
#[error("workflow stopped")]
38+
WorkflowStopped,
39+
3740
#[error("history diverged: {0}")]
3841
HistoryDiverged(String),
3942

@@ -195,6 +198,10 @@ pub enum WorkflowError {
195198
}
196199

197200
impl WorkflowError {
201+
pub(crate) fn wake_immediate(&self) -> bool {
202+
matches!(self, WorkflowError::WorkflowStopped)
203+
}
204+
198205
/// Returns the next deadline for a workflow to be woken up again based on the error.
199206
pub(crate) fn deadline_ts(&self) -> Option<i64> {
200207
match self {
@@ -231,12 +238,13 @@ impl WorkflowError {
231238
| WorkflowError::NoSignalFound(_)
232239
| WorkflowError::NoSignalFoundAndSleep(_, _)
233240
| WorkflowError::SubWorkflowIncomplete(_)
234-
| WorkflowError::Sleep(_) => true,
241+
| WorkflowError::Sleep(_)
242+
| WorkflowError::WorkflowStopped => true,
235243
_ => false,
236244
}
237245
}
238246

239-
/// Any error that the workflow can try again on. Only used for printing.
247+
/// Any error that the workflow can try again on a fixed number of times. Only used for printing.
240248
pub(crate) fn is_retryable(&self) -> bool {
241249
match self {
242250
WorkflowError::ActivityFailure(_, _)

0 commit comments

Comments
 (0)