diff --git a/src/restate_context_impl.ts b/src/restate_context_impl.ts index fb9f3420..7b39f23e 100644 --- a/src/restate_context_impl.ts +++ b/src/restate_context_impl.ts @@ -82,6 +82,10 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { // For example, this is illegal: 'ctx.sideEffect(() => {await ctx.get("my-state")})' static callContext = new AsyncLocalStorage(); + // This is used to guard users against calling ctx.sideEffect without awaiting it. + // See https://github.com/restatedev/sdk-typescript/issues/197 for more details. + private executingSideEffect = false; + constructor( public readonly id: Buffer, public readonly serviceName: string, @@ -90,30 +94,35 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { public readonly rand: Rand = new RandImpl(id) ) {} - public async get(name: string): Promise { + // DON'T make this function async!!! see sideEffect comment for details. + public get(name: string): Promise { // Check if this is a valid action this.checkState("get state"); // Create the message and let the state machine process it const msg = this.stateMachine.localStateStore.get(name); - const result = await this.stateMachine.handleUserCodeMessage( - GET_STATE_ENTRY_MESSAGE_TYPE, - msg - ); - // If the GetState message did not have a value or empty, - // then we went to the runtime to get the value. - // When we get the response, we set it in the localStateStore, - // to answer subsequent requests - if (msg.value === undefined && msg.empty === undefined) { - this.stateMachine.localStateStore.add(name, result as Buffer | Empty); - } + const getState = async (): Promise => { + const result = await this.stateMachine.handleUserCodeMessage( + GET_STATE_ENTRY_MESSAGE_TYPE, + msg + ); - if (!(result instanceof Buffer)) { - return null; - } + // If the GetState message did not have a value or empty, + // then we went to the runtime to get the value. + // When we get the response, we set it in the localStateStore, + // to answer subsequent requests + if (msg.value === undefined && msg.empty === undefined) { + this.stateMachine.localStateStore.add(name, result as Buffer | Empty); + } - return jsonDeserialize(result.toString()); + if (!(result instanceof Buffer)) { + return null; + } + + return jsonDeserialize(result.toString()); + }; + return getState(); } public set(name: string, value: T): void { @@ -144,7 +153,8 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { } } - private async invoke( + // DON'T make this function async!!! see sideEffect comment for details. + private invoke( service: string, method: string, data: Uint8Array @@ -156,11 +166,9 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { methodName: method, parameter: Buffer.from(data), }); - const promise = this.stateMachine.handleUserCodeMessage( - INVOKE_ENTRY_MESSAGE_TYPE, - msg - ); - return (await promise) as Uint8Array; + return this.stateMachine + .handleUserCodeMessage(INVOKE_ENTRY_MESSAGE_TYPE, msg) + .transform((v) => v as Uint8Array); } private async invokeOneWay( @@ -184,19 +192,21 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { return new Uint8Array(); } - public async oneWayCall( + // DON'T make this function async!!! see sideEffect comment for details. + public oneWayCall( // eslint-disable-next-line @typescript-eslint/no-explicit-any call: () => Promise ): Promise { this.checkState("oneWayCall"); - await RestateGrpcContextImpl.callContext.run( + return RestateGrpcContextImpl.callContext.run( { type: CallContexType.OneWayCall }, call ); } - public async delayedCall( + // DON'T make this function async!!! see sideEffect comment for details. + public delayedCall( // eslint-disable-next-line @typescript-eslint/no-explicit-any call: () => Promise, delayMillis?: number @@ -204,13 +214,17 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { this.checkState("delayedCall"); // Delayed call is a one way call with a delay - await RestateGrpcContextImpl.callContext.run( + return RestateGrpcContextImpl.callContext.run( { type: CallContexType.OneWayCall, delay: delayMillis }, call ); } - public async sideEffect( + // DON'T make this function async!!! + // The reason is that we want the erros thrown by the initial checks to be propagated in the caller context, + // and not in the promise context. To understand the semantic difference, make this function async and run the + // UnawaitedSideEffectShouldFailSubsequentContextCall test. + public sideEffect( fn: () => Promise, retryPolicy: RetrySettings = DEFAULT_INFINITE_EXPONENTIAL_BACKOFF ): Promise { @@ -227,6 +241,8 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { { errorCode: ErrorCodes.INTERNAL } ); } + this.checkNotExecutingSideEffect(); + this.executingSideEffect = true; const executeAndLogSideEffect = async () => { // in replay mode, we directly return the value from the log @@ -301,17 +317,25 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { return sideEffectResult; }; - const sleep = (millis: number) => this.sleep(millis); - return executeWithRetries(retryPolicy, executeAndLogSideEffect, sleep); + const sleep = (millis: number) => this.sleepInternal(millis); + return executeWithRetries( + retryPolicy, + executeAndLogSideEffect, + sleep + ).finally(() => { + this.executingSideEffect = false; + }); } public sleep(millis: number): Promise { this.checkState("sleep"); + return this.sleepInternal(millis); + } - const msg = SleepEntryMessage.create({ wakeUpTime: Date.now() + millis }); + private sleepInternal(millis: number): Promise { return this.stateMachine.handleUserCodeMessage( SLEEP_ENTRY_MESSAGE_TYPE, - msg + SleepEntryMessage.create({ wakeUpTime: Date.now() + millis }) ); } @@ -385,9 +409,20 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { return context?.delay || 0; } + private checkNotExecutingSideEffect() { + if (this.executingSideEffect) { + throw new TerminalError( + `Invoked a RestateContext method while a side effect is still executing. + Make sure you await the ctx.sideEffect call before using any other RestateContext method.`, + { errorCode: ErrorCodes.INTERNAL } + ); + } + } + private checkState(callType: string): void { const context = RestateGrpcContextImpl.callContext.getStore(); if (!context) { + this.checkNotExecutingSideEffect(); return; } diff --git a/src/server/base_restate_server.ts b/src/server/base_restate_server.ts index e8632d30..578cc555 100644 --- a/src/server/base_restate_server.ts +++ b/src/server/base_restate_server.ts @@ -149,9 +149,7 @@ export abstract class BaseRestateServer { method ); // note that this log will not print all the keys. - rlog.info( - `Binding: ${url} -> ${JSON.stringify(method, null, "\t")}` - ); + rlog.info(`Binding: ${url} -> ${JSON.stringify(method, null, "\t")}`); } } @@ -264,11 +262,7 @@ export abstract class BaseRestateServer { ) as HostedGrpcServiceMethod; rlog.info( - `Binding: ${url} -> ${JSON.stringify( - registration.method, - null, - "\t" - )}` + `Binding: ${url} -> ${JSON.stringify(registration.method, null, "\t")}` ); } diff --git a/test/side_effect.test.ts b/test/side_effect.test.ts index 60a33851..388f4485 100644 --- a/test/side_effect.test.ts +++ b/test/side_effect.test.ts @@ -911,6 +911,103 @@ describe("AwaitSideEffectService", () => { }); }); +export class UnawaitedSideEffectShouldFailSubsequentContextCallService + implements TestGreeter +{ + constructor( + // eslint-disable-next-line @typescript-eslint/no-empty-function + private readonly next = (ctx: restate.RestateContext): void => {} + ) {} + + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async greet(request: TestRequest): Promise { + const ctx = restate.useContext(this); + + ctx.sideEffect(async () => { + // eslint-disable-next-line @typescript-eslint/no-empty-function + return new Promise(() => {}); + }); + this.next(ctx); + + throw new Error("code should not reach this point"); + } +} + +describe("UnawaitedSideEffectShouldFailSubsequentContextCall", () => { + const defineTestCase = ( + contextMethodCall: string, + next: (ctx: restate.RestateContext) => void + ): void => { + it( + "Not awaiting side effect should fail at next " + contextMethodCall, + async () => { + const result = await new TestDriver( + new UnawaitedSideEffectShouldFailSubsequentContextCallService(next), + [startMessage(), inputMessage(greetRequest("Till"))] + ).run(); + + checkTerminalError( + result[0], + `Invoked a RestateContext method while a side effect is still executing. + Make sure you await the ctx.sideEffect call before using any other RestateContext method.` + ); + expect(result.slice(1)).toStrictEqual([END_MESSAGE]); + } + ); + }; + + defineTestCase("side effect", (ctx) => + ctx.sideEffect(async () => { + return 1; + }) + ); + defineTestCase("get", (ctx) => ctx.get("123")); + defineTestCase("set", (ctx) => ctx.set("123", "abc")); + defineTestCase("call", (ctx) => { + const client = new TestGreeterClientImpl(ctx); + client.greet(TestRequest.create({ name: "Francesco" })); + }); + defineTestCase("one way call", (ctx) => { + const client = new TestGreeterClientImpl(ctx); + ctx.oneWayCall(() => + client.greet(TestRequest.create({ name: "Francesco" })) + ); + }); +}); + +export class UnawaitedSideEffectShouldFailSubsequentSetService + implements TestGreeter +{ + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async greet(request: TestRequest): Promise { + const ctx = restate.useContext(this); + + ctx.sideEffect(async () => { + // eslint-disable-next-line @typescript-eslint/no-empty-function + return new Promise(() => {}); + }); + ctx.set("123", "abc"); + + throw new Error("code should not reach this point"); + } +} + +describe("UnawaitedSideEffectShouldFailSubsequentSetService", () => { + it("Not awaiting side effects should fail", async () => { + const result = await new TestDriver( + new UnawaitedSideEffectShouldFailSubsequentSetService(), + [startMessage(), inputMessage(greetRequest("Till"))] + ).run(); + + checkTerminalError( + result[0], + `Invoked a RestateContext method while a side effect is still executing. + Make sure you await the ctx.sideEffect call before using any other RestateContext method.` + ); + expect(result.slice(1)).toStrictEqual([END_MESSAGE]); + }); +}); + export class TerminalErrorSideEffectService implements TestGreeter { // eslint-disable-next-line @typescript-eslint/no-unused-vars async greet(request: TestRequest): Promise {