From 56a00bf105e51b6cbd93713753942a8260b16802 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Tue, 30 Jan 2024 18:32:25 +0100 Subject: [PATCH] Introduce orTimeout(millis) API, to easily combine a CombineablePromise with a sleep to implement timeouts. (#234) --- src/restate_context.ts | 9 +++++ src/restate_context_impl.ts | 34 ++++++++++++++++-- src/types/errors.ts | 6 ++++ test/promise_combinators.test.ts | 61 ++++++++++++++++++++++++++++++++ 4 files changed, 107 insertions(+), 3 deletions(-) diff --git a/src/restate_context.ts b/src/restate_context.ts index 675a9c28..8dc11b97 100644 --- a/src/restate_context.ts +++ b/src/restate_context.ts @@ -18,6 +18,15 @@ import { RestateGrpcContextImpl } from "./restate_context_impl"; */ export type CombineablePromise = Promise & { __restate_context: RestateBaseContext; + + /** + * Creates a promise that awaits for the current promise up to the specified timeout duration. + * If the timeout is fired, this Promise will be rejected with a {@link TimeoutError}. + * + * @param millis duration of the sleep in millis. + * This is a lower-bound. + */ + orTimeout(millis: number): Promise; }; /** diff --git a/src/restate_context_impl.ts b/src/restate_context_impl.ts index 5411f96d..bc71b127 100644 --- a/src/restate_context_impl.ts +++ b/src/restate_context_impl.ts @@ -49,6 +49,7 @@ import { TerminalError, ensureError, errorToFailureWithTerminal, + TimeoutError, } from "./types/errors"; import { jsonSerialize, jsonDeserialize } from "./utils/utils"; import { Empty } from "./generated/google/protobuf/empty"; @@ -62,7 +63,10 @@ import { Client, SendClient } from "./types/router"; import { RpcRequest, RpcResponse } from "./generated/proto/dynrpc"; import { requestFromArgs } from "./utils/assumptions"; import { RandImpl } from "./utils/rand"; -import { newJournalEntryPromiseId } from "./promise_combinator_tracker"; +import { + newJournalEntryPromiseId, + PromiseId, +} from "./promise_combinator_tracker"; import { WrappedPromise } from "./utils/promises"; export enum CallContexType { @@ -345,7 +349,7 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { return this.markCombineablePromise(this.sleepInternal(millis)); } - private sleepInternal(millis: number): Promise { + private sleepInternal(millis: number): WrappedPromise { return this.stateMachine.handleUserCodeMessage( SLEEP_ENTRY_MESSAGE_TYPE, SleepEntryMessage.create({ wakeUpTime: Date.now() + millis }) @@ -494,12 +498,36 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { private markCombineablePromise( p: Promise ): InternalCombineablePromise { + const journalIndex = this.stateMachine.getUserCodeJournalIndex(); + const orTimeout = (millis: number): Promise => { + const sleepPromise: Promise = this.sleepInternal(millis).transform( + () => { + throw new TimeoutError(); + } + ); + const sleepPromiseIndex = this.stateMachine.getUserCodeJournalIndex(); + + return this.stateMachine.createCombinator(Promise.race.bind(Promise), [ + { + id: newJournalEntryPromiseId(journalIndex), + promise: p, + }, + { + id: newJournalEntryPromiseId(sleepPromiseIndex), + promise: sleepPromise, + }, + ]) as Promise; + }; + return Object.defineProperties(p, { __restate_context: { value: this, }, journalIndex: { - value: this.stateMachine.getUserCodeJournalIndex(), + value: journalIndex, + }, + orTimeout: { + value: orTimeout.bind(this), }, }) as InternalCombineablePromise; } diff --git a/src/types/errors.ts b/src/types/errors.ts index fe63b9dd..fe79decd 100644 --- a/src/types/errors.ts +++ b/src/types/errors.ts @@ -201,6 +201,12 @@ export class TerminalError extends RestateError { } } +export class TimeoutError extends TerminalError { + constructor() { + super("Timeout occurred", { errorCode: ErrorCodes.DEADLINE_EXCEEDED }); + } +} + // Leads to Restate retries export class RetryableError extends RestateError { constructor(message: string, options?: { errorCode?: number; cause?: any }) { diff --git a/test/promise_combinators.test.ts b/test/promise_combinators.test.ts index e5b683f3..b0d0953a 100644 --- a/test/promise_combinators.test.ts +++ b/test/promise_combinators.test.ts @@ -30,6 +30,7 @@ import { } from "./protoutils"; import { TestGreeter, TestResponse } from "../src/generated/proto/test"; import { SLEEP_ENTRY_MESSAGE_TYPE } from "../src/types/protocol"; +import { TimeoutError } from "../src/types/errors"; import { CombineablePromise } from "../src/restate_context"; class AwakeableSleepRaceGreeter implements TestGreeter { @@ -327,3 +328,63 @@ describe("CombineablePromiseThenSideEffect", () => { ]); }); }); + +class AwakeableOrTimeoutGreeter implements TestGreeter { + async greet(): Promise { + const ctx = restate.useContext(this); + + const { promise } = ctx.awakeable(); + try { + const result = await promise.orTimeout(100); + return TestResponse.create({ + greeting: `Hello ${result}`, + }); + } catch (e) { + if (e instanceof TimeoutError) { + return TestResponse.create({ + greeting: `Hello timed-out`, + }); + } + } + + throw new Error("Unexpected result"); + } +} + +describe("AwakeableOrTimeoutGreeter", () => { + it("handles completion of awakeable", async () => { + const result = await new TestDriver(new AwakeableOrTimeoutGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(1, JSON.stringify("Francesco")), + ackMessage(3), + ]).run(); + + expect(result.length).toStrictEqual(5); + expect(result[0]).toStrictEqual(awakeableMessage()); + expect(result[1].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result.slice(2)).toStrictEqual([ + combinatorEntryMessage(0, [1]), + outputMessage(greetResponse(`Hello Francesco`)), + END_MESSAGE, + ]); + }); + + it("handles completion of sleep", async () => { + const result = await new TestDriver(new AwakeableOrTimeoutGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(2, undefined, true), + ackMessage(3), + ]).run(); + + expect(result.length).toStrictEqual(5); + expect(result[0]).toStrictEqual(awakeableMessage()); + expect(result[1].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result.slice(2)).toStrictEqual([ + combinatorEntryMessage(0, [2]), + outputMessage(greetResponse(`Hello timed-out`)), + END_MESSAGE, + ]); + }); +});