diff --git a/proto/protocol.proto b/proto/protocol.proto index 5be52b03..85c9918c 100644 --- a/proto/protocol.proto +++ b/proto/protocol.proto @@ -96,13 +96,18 @@ message ErrorMessage { // ------ Input and output ------ -// Kind: Completable JournalEntry +// Completable: Yes +// Fallible: No // Type: 0x0400 + 0 message PollInputStreamEntryMessage { - bytes value = 14; + oneof result { + bytes value = 14; + Failure failure = 15; + } } -// Kind: Non-Completable JournalEntry +// Completable: No +// Fallible: No // Type: 0x0400 + 1 message OutputStreamEntryMessage { oneof result { @@ -113,7 +118,8 @@ message OutputStreamEntryMessage { // ------ State access ------ -// Kind: Completable JournalEntry +// Completable: Yes +// Fallible: No // Type: 0x0800 + 0 message GetStateEntryMessage { bytes key = 1; @@ -121,17 +127,20 @@ message GetStateEntryMessage { oneof result { google.protobuf.Empty empty = 13; bytes value = 14; + Failure failure = 15; }; } -// Kind: Non-Completable JournalEntry +// Completable: No +// Fallible: No // Type: 0x0800 + 1 message SetStateEntryMessage { bytes key = 1; bytes value = 3; } -// Kind: Non-Completable JournalEntry +// Completable: No +// Fallible: No // Type: 0x0800 + 2 message ClearStateEntryMessage { bytes key = 1; @@ -139,17 +148,22 @@ message ClearStateEntryMessage { // ------ Syscalls ------ -// Kind: Completable JournalEntry +// Completable: Yes +// Fallible: No // Type: 0x0C00 + 0 message SleepEntryMessage { // Wake up time. // The time is set as duration since UNIX Epoch. uint64 wake_up_time = 1; - google.protobuf.Empty result = 13; + oneof result { + google.protobuf.Empty empty = 13; + Failure failure = 15; + } } -// Kind: Completable JournalEntry +// Completable: Yes +// Fallible: Yes // Type: 0x0C00 + 1 message InvokeEntryMessage { string service_name = 1; @@ -163,7 +177,8 @@ message InvokeEntryMessage { }; } -// Kind: Non-Completable JournalEntry +// Completable: No +// Fallible: Yes // Type: 0x0C00 + 2 message BackgroundInvokeEntryMessage { string service_name = 1; @@ -178,7 +193,8 @@ message BackgroundInvokeEntryMessage { uint64 invoke_time = 4; } -// Kind: Completable JournalEntry +// Completable: Yes +// Fallible: No // Type: 0x0C00 + 3 // Awakeables are addressed by an identifier exposed to the user. See the spec for more details. message AwakeableEntryMessage { @@ -188,7 +204,8 @@ message AwakeableEntryMessage { }; } -// Kind: Non-Completable JournalEntry +// Completable: No +// Fallible: Yes // Type: 0x0C00 + 4 message CompleteAwakeableEntryMessage { // Identifier of the awakeable. See the spec for more details. diff --git a/src/invocation.ts b/src/invocation.ts index 75977e4c..b78d2ef1 100644 --- a/src/invocation.ts +++ b/src/invocation.ts @@ -14,6 +14,7 @@ import { Message } from "./types/types"; import { HostedGrpcServiceMethod } from "./types/grpc"; import { + Failure, PollInputStreamEntryMessage, StartMessage, } from "./generated/proto/protocol"; @@ -37,6 +38,10 @@ enum State { Complete = 3, } +type InvocationValue = + | { kind: "value"; value: Buffer } + | { kind: "failure"; failure: Failure }; + export class InvocationBuilder implements RestateStreamConsumer { private readonly complete = new CompletablePromise(); @@ -46,7 +51,7 @@ export class InvocationBuilder implements RestateStreamConsumer { private replayEntries = new Map(); private id?: Buffer = undefined; private debugId?: string = undefined; - private invocationValue?: Buffer = undefined; + private invocationValue?: InvocationValue = undefined; private nbEntriesToReplay?: number = undefined; private localStateStore?: LocalStateStore; @@ -67,6 +72,8 @@ export class InvocationBuilder implements RestateStreamConsumer { POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE, m ); + + this.handlePollInputStreamEntry(m); this.addReplayEntry(m); break; @@ -100,6 +107,28 @@ export class InvocationBuilder implements RestateStreamConsumer { } } + private handlePollInputStreamEntry(m: Message) { + const pollInputStreamMessage = m.message as PollInputStreamEntryMessage; + + if (pollInputStreamMessage.value !== undefined) { + this.invocationValue = { + kind: "value", + value: pollInputStreamMessage.value, + }; + } else if (pollInputStreamMessage.failure !== undefined) { + this.invocationValue = { + kind: "failure", + failure: pollInputStreamMessage.failure, + }; + } else { + throw new Error( + `PollInputStreamEntry neither contains value nor failure: ${printMessageAsJson( + m + )}` + ); + } + } + public handleStreamError(e: Error): void { this.complete.reject(e); } @@ -120,10 +149,6 @@ export class InvocationBuilder implements RestateStreamConsumer { } private addReplayEntry(m: Message): InvocationBuilder { - if (m.messageType === POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE) { - this.invocationValue = (m.message as PollInputStreamEntryMessage).value; - } - // Will be retrieved when the user code reaches this point this.replayEntries.set(this.runtimeReplayIndex, m); this.incrementRuntimeReplayIndex(); @@ -164,7 +189,7 @@ export class Invocation { public readonly debugId: string, public readonly nbEntriesToReplay: number, public readonly replayEntries: Map, - public readonly invocationValue: Buffer, + public readonly invocationValue: InvocationValue, public readonly localStateStore: LocalStateStore ) { this.logPrefix = `[${makeFqServiceName( diff --git a/src/journal.ts b/src/journal.ts index 1ca79a1d..6a960caa 100644 --- a/src/journal.ts +++ b/src/journal.ts @@ -260,7 +260,8 @@ export class Journal { this.resolveResult( journalIndex, journalEntry, - getStateMsg.value || getStateMsg.empty + getStateMsg.value || getStateMsg.empty, + getStateMsg.failure ); break; } @@ -276,7 +277,12 @@ export class Journal { } case SLEEP_ENTRY_MESSAGE_TYPE: { const sleepMsg = replayMessage.message as SleepEntryMessage; - this.resolveResult(journalIndex, journalEntry, sleepMsg.result); + this.resolveResult( + journalIndex, + journalEntry, + sleepMsg.empty, + sleepMsg.failure + ); break; } case AWAKEABLE_ENTRY_MESSAGE_TYPE: { diff --git a/src/state_machine.ts b/src/state_machine.ts index 85736666..72b6cb1d 100644 --- a/src/state_machine.ts +++ b/src/state_machine.ts @@ -33,6 +33,7 @@ import { TerminalError, RetryableError, errorToErrorMessage, + failureToTerminalError, } from "./types/errors"; import { LocalStateStore } from "./local_state_store"; @@ -211,10 +212,21 @@ export class StateMachine implements RestateStreamConsumer { rlog.debugInvokeMessage(this.invocation.logPrefix, "Invoking function."); } - const resultBytes: Promise = this.invocation.method.invoke( - this.restateContext, - this.invocation.invocationValue - ); + let resultBytes: Promise; + + switch (this.invocation.invocationValue.kind) { + case "value": + resultBytes = this.invocation.method.invoke( + this.restateContext, + this.invocation.invocationValue.value + ); + break; + case "failure": + resultBytes = Promise.reject( + failureToTerminalError(this.invocation.invocationValue.failure) + ); + break; + } resultBytes .then((bytes) => { diff --git a/src/types/errors.ts b/src/types/errors.ts index 6badcac1..3fcbeb62 100644 --- a/src/types/errors.ts +++ b/src/types/errors.ts @@ -260,6 +260,10 @@ export function errorToFailureWithTerminal(err: Error): FailureWithTerminal { }); } +export function failureToTerminalError(failure: Failure): TerminalError { + return failureToError(failure, true) as TerminalError; +} + export function failureToError( failure: Failure, terminalError: boolean diff --git a/test/get_state.test.ts b/test/get_state.test.ts index 8b9219de..560b304f 100644 --- a/test/get_state.test.ts +++ b/test/get_state.test.ts @@ -14,7 +14,9 @@ import * as restate from "../src/public_api"; import { TestDriver } from "./testdriver"; import { checkJournalMismatchError, + checkTerminalError, completionMessage, + failure, getStateMessage, greetRequest, greetResponse, @@ -106,6 +108,18 @@ describe("GetStringStateGreeter", () => { ]); }); + it("handles completion with failure", async () => { + const result = await new TestDriver(new GetStringStateGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(1, undefined, undefined, failure("Canceled")), + ]).run(); + + expect(result.length).toStrictEqual(2); + expect(result[0]).toStrictEqual(getStateMessage("STATE")); + checkTerminalError(result[1], "Canceled"); + }); + it("handles replay with value", async () => { const result = await new TestDriver(new GetStringStateGreeter(), [ startMessage(), @@ -129,6 +143,17 @@ describe("GetStringStateGreeter", () => { outputMessage(greetResponse("Hello nobody")), ]); }); + + it("handles replay with failure", async () => { + const result = await new TestDriver(new GetStringStateGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + getStateMessage("STATE", undefined, undefined, failure("Canceled")), + ]).run(); + + expect(result.length).toStrictEqual(1); + checkTerminalError(result[0], "Canceled"); + }); }); class GetNumberStateGreeter implements TestGreeter { diff --git a/test/protoutils.ts b/test/protoutils.ts index ebace05b..39709ba5 100644 --- a/test/protoutils.ts +++ b/test/protoutils.ts @@ -90,13 +90,24 @@ export function toStateEntries(entries: Buffer[][]) { ); } -export function inputMessage(value: Uint8Array): Message { - return new Message( - POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE, - PollInputStreamEntryMessage.create({ - value: Buffer.from(value), - }) - ); +export function inputMessage(value?: Uint8Array, failure?: Failure): Message { + if (failure !== undefined) { + return new Message( + POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE, + PollInputStreamEntryMessage.create({ + failure: failure, + }) + ); + } else if (value !== undefined) { + return new Message( + POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE, + PollInputStreamEntryMessage.create({ + value: Buffer.from(value), + }) + ); + } else { + throw new Error("Input message needs either a value or a failure set."); + } } export function outputMessage(value?: Uint8Array, failure?: Failure): Message { @@ -130,7 +141,8 @@ export function outputMessage(value?: Uint8Array, failure?: Failure): Message { export function getStateMessage( key: string, value?: T, - empty?: boolean + empty?: boolean, + failure?: Failure ): Message { if (empty === true) { return new Message( @@ -148,6 +160,14 @@ export function getStateMessage( value: Buffer.from(jsonSerialize(value)), }) ); + } else if (failure !== undefined) { + return new Message( + GET_STATE_ENTRY_MESSAGE_TYPE, + GetStateEntryMessage.create({ + key: Buffer.from(key), + failure: failure, + }) + ); } else { return new Message( GET_STATE_ENTRY_MESSAGE_TYPE, @@ -177,13 +197,25 @@ export function clearStateMessage(key: string): Message { ); } -export function sleepMessage(wakeupTime: number, result?: Empty): Message { - if (result !== undefined) { +export function sleepMessage( + wakeupTime: number, + empty?: Empty, + failure?: Failure +): Message { + if (empty !== undefined) { return new Message( SLEEP_ENTRY_MESSAGE_TYPE, SleepEntryMessage.create({ wakeUpTime: wakeupTime, - result: result, + empty: empty, + }) + ); + } else if (failure !== undefined) { + return new Message( + SLEEP_ENTRY_MESSAGE_TYPE, + SleepEntryMessage.create({ + wakeUpTime: wakeupTime, + failure: failure, }) ); } else { diff --git a/test/sleep.test.ts b/test/sleep.test.ts index e5accb93..6e9881ac 100644 --- a/test/sleep.test.ts +++ b/test/sleep.test.ts @@ -15,7 +15,9 @@ import { TestDriver } from "./testdriver"; import { awakeableMessage, checkJournalMismatchError, + checkTerminalError, completionMessage, + failure, greetRequest, greetResponse, inputMessage, @@ -94,6 +96,18 @@ describe("SleepGreeter", () => { expect(result[1]).toStrictEqual(outputMessage(greetResponse("Hello"))); }); + it("handles completion with failure", async () => { + const result = await new TestDriver(new SleepGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(1, undefined, undefined, failure("Canceled")), + ]).run(); + + expect(result.length).toStrictEqual(2); + expect(result[0].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + checkTerminalError(result[1], "Canceled"); + }); + it("handles replay with no empty", async () => { const result = await new TestDriver(new SleepGreeter(), [ startMessage(), @@ -116,6 +130,17 @@ describe("SleepGreeter", () => { expect(result[0]).toStrictEqual(outputMessage(greetResponse("Hello"))); }); + it("handles replay with failure", async () => { + const result = await new TestDriver(new SleepGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + sleepMessage(wakeupTime, undefined, failure("Canceled")), + ]).run(); + + expect(result.length).toStrictEqual(1); + checkTerminalError(result[0], "Canceled"); + }); + it("fails on journal mismatch. Completed with Awakeable.", async () => { const result = await new TestDriver(new SleepGreeter(), [ startMessage(), diff --git a/test/state_machine.test.ts b/test/state_machine.test.ts index 472fff24..057d273c 100644 --- a/test/state_machine.test.ts +++ b/test/state_machine.test.ts @@ -14,6 +14,8 @@ import * as restate from "../src/public_api"; import { describe, expect } from "@jest/globals"; import { TestDriver } from "./testdriver"; import { + checkTerminalError, + failure, greetRequest, greetResponse, inputMessage, @@ -48,4 +50,14 @@ describe("Greeter", () => { expect(result).toStrictEqual([]); }); + + it("fails invocation if input is failed", async () => { + const result = await new TestDriver(new Greeter(), [ + startMessage(1), + inputMessage(undefined, failure("Canceled")), + ]).run(); + + expect(result.length).toStrictEqual(1); + checkTerminalError(result[0], "Canceled"); + }); });