diff --git a/proto/javascript.proto b/proto/javascript.proto index b11a5f25..3d9008c0 100644 --- a/proto/javascript.proto +++ b/proto/javascript.proto @@ -27,4 +27,11 @@ message SideEffectEntryMessage { bytes value = 14; FailureWithTerminal failure = 15; }; +} + +// Type: 0xFC00 + 2 +message CombinatorEntryMessage { + int32 combinator_id = 1; + + repeated int32 journal_entries_order = 2; } \ No newline at end of file diff --git a/src/invocation.ts b/src/invocation.ts index 5edcbe64..6f98ae1f 100644 --- a/src/invocation.ts +++ b/src/invocation.ts @@ -18,7 +18,7 @@ import { PollInputStreamEntryMessage, StartMessage, } from "./generated/proto/protocol"; -import { CompletablePromise, formatMessageAsJson } from "./utils/utils"; +import { formatMessageAsJson } from "./utils/utils"; import { POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE, START_MESSAGE_TYPE, @@ -27,6 +27,7 @@ import { RestateStreamConsumer } from "./connection/connection"; import { LocalStateStore } from "./local_state_store"; import { ensureError } from "./types/errors"; import { LoggerContext } from "./logger"; +import { CompletablePromise } from "./utils/promises"; enum State { ExpectingStart = 0, diff --git a/src/journal.ts b/src/journal.ts index da4f958d..1a5f06b7 100644 --- a/src/journal.ts +++ b/src/journal.ts @@ -16,6 +16,7 @@ import { AwakeableEntryMessage, BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE, CLEAR_STATE_ENTRY_MESSAGE_TYPE, + COMBINATOR_ENTRY_MESSAGE, COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE, CompletionMessage, EntryAckMessage, @@ -39,6 +40,7 @@ import { Message } from "./types/types"; import { SideEffectEntryMessage } from "./generated/proto/javascript"; import { Invocation } from "./invocation"; import { failureToError, RetryableError } from "./types/errors"; +import { CompletablePromise } from "./utils/promises"; const RESOLVED = Promise.resolve(undefined); @@ -99,7 +101,7 @@ export class Journal { const journalEntry = new JournalEntry(messageType, message); this.handleReplay(this.userCodeJournalIndex, replayEntry, journalEntry); - return journalEntry.promise; + return journalEntry.completablePromise.promise; } case NewExecutionState.PROCESSING: { switch (messageType) { @@ -109,7 +111,7 @@ export class Journal { messageType, message as p.SuspensionMessage | p.OutputStreamEntryMessage ); - return Promise.resolve(undefined); + return RESOLVED; } case p.SET_STATE_ENTRY_MESSAGE_TYPE: case p.CLEAR_STATE_ENTRY_MESSAGE_TYPE: @@ -128,22 +130,11 @@ export class Journal { return Promise.resolve(getStateMsg.value || getStateMsg.empty); } else { // Need to retrieve state by going to the runtime. - const journalEntry = new JournalEntry(messageType, message); - this.pendingJournalEntries.set( - this.userCodeJournalIndex, - journalEntry - ); - return journalEntry.promise; + return this.appendJournalEntry(messageType, message); } } default: { - // Need completion - const journalEntry = new JournalEntry(messageType, message); - this.pendingJournalEntries.set( - this.userCodeJournalIndex, - journalEntry - ); - return journalEntry.promise; + return this.appendJournalEntry(messageType, message); } } } @@ -153,7 +144,7 @@ export class Journal { // So no more user messages can come in... // - Print warning log and continue... //TODO received user-side message but state machine is closed - return Promise.resolve(undefined); + return RESOLVED; } default: { throw RetryableError.protocolViolation( @@ -178,7 +169,7 @@ export class Journal { } if (m.value !== undefined) { - journalEntry.resolve(m.value); + journalEntry.completablePromise.resolve(m.value); this.pendingJournalEntries.delete(m.entryIndex); } else if (m.failure !== undefined) { // we do all completions with Terminal Errors, because failures triggered by those exceptions @@ -186,10 +177,10 @@ export class Journal { // thus an infinite loop that keeps replay-ing but never makes progress // these failures here consequently need to cause terminal failures, unless caught and handled // by the handler code - journalEntry.reject(failureToError(m.failure, true)); + journalEntry.completablePromise.reject(failureToError(m.failure, true)); this.pendingJournalEntries.delete(m.entryIndex); } else if (m.empty !== undefined) { - journalEntry.resolve(m.empty); + journalEntry.completablePromise.resolve(m.empty); this.pendingJournalEntries.delete(m.entryIndex); } else { //TODO completion message without a value/failure/empty @@ -205,7 +196,7 @@ export class Journal { } // Just needs an ack - journalEntry.resolve(undefined); + journalEntry.completablePromise.resolve(undefined); this.pendingJournalEntries.delete(m.entryIndex); } @@ -314,7 +305,7 @@ export class Journal { } else { // A side effect can have a void return type // If it was replayed, then it is acked, so we should resolve it. - journalEntry.resolve(undefined); + journalEntry.completablePromise.resolve(undefined); this.pendingJournalEntries.delete(journalIndex); } break; @@ -322,9 +313,10 @@ export class Journal { case SET_STATE_ENTRY_MESSAGE_TYPE: case CLEAR_STATE_ENTRY_MESSAGE_TYPE: case COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE: - case BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: { + case BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: + case COMBINATOR_ENTRY_MESSAGE: { // Do not need a completion. So if the match has passed then the entry can be deleted. - journalEntry.resolve(undefined); + journalEntry.completablePromise.resolve(undefined); this.pendingJournalEntries.delete(journalIndex); break; } @@ -342,11 +334,11 @@ export class Journal { failureWouldBeTerminal?: boolean ) { if (value !== undefined) { - journalEntry.resolve(value); + journalEntry.completablePromise.resolve(value); this.pendingJournalEntries.delete(journalIndex); } else if (failure !== undefined) { const error = failureToError(failure, failureWouldBeTerminal ?? true); - journalEntry.reject(error); + journalEntry.completablePromise.reject(error); this.pendingJournalEntries.delete(journalIndex); } else { this.pendingJournalEntries.set(journalIndex, journalEntry); @@ -369,7 +361,9 @@ export class Journal { } this.pendingJournalEntries.delete(0); - rootJournalEntry.resolve(new Message(messageType, message)); + rootJournalEntry.completablePromise.resolve( + new Message(messageType, message) + ); } private checkJournalMatch( @@ -412,7 +406,7 @@ export class Journal { } } - private incrementUserCodeIndex() { + incrementUserCodeIndex() { this.userCodeJournalIndex++; if ( this.userCodeJournalIndex === this.invocation.nbEntriesToReplay && @@ -422,6 +416,26 @@ export class Journal { } } + /** + * Read the next replay entry + */ + public readNextReplayEntry() { + this.incrementUserCodeIndex(); + return this.invocation.replayEntries.get(this.userCodeJournalIndex); + } + + /** + * Append journal entry. This won't increment the journal index. + */ + public appendJournalEntry( + messageType: bigint, + message: p.ProtocolMessage | Uint8Array + ): Promise { + const journalEntry = new JournalEntry(messageType, message); + this.pendingJournalEntries.set(this.userCodeJournalIndex, journalEntry); + return journalEntry.completablePromise.promise; + } + public isClosed(): boolean { return this.state === NewExecutionState.CLOSED; } @@ -464,22 +478,14 @@ export class Journal { } export class JournalEntry { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - public promise: Promise; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - public resolve!: (value: any) => void; - // eslint-disable-next-line @typescript-eslint/no-explicit-any - public reject!: (reason?: any) => void; + public completablePromise: CompletablePromise; constructor( readonly messageType: bigint, readonly message: p.ProtocolMessage | Uint8Array ) { // eslint-disable-next-line @typescript-eslint/no-explicit-any - this.promise = new Promise((res, rej) => { - this.resolve = res; - this.reject = rej; - }); + this.completablePromise = new CompletablePromise(); } } diff --git a/src/promise_combinator_tracker.ts b/src/promise_combinator_tracker.ts new file mode 100644 index 00000000..76bc77c6 --- /dev/null +++ b/src/promise_combinator_tracker.ts @@ -0,0 +1,191 @@ +import { + CompletablePromise, + wrapDeeply, + WrappedPromise, +} from "./utils/promises"; + +export enum PromiseType { + JournalEntry, + // Combinator?, + // SideEffect? +} + +export interface PromiseId { + type: PromiseType; + id: number; +} + +export function newJournalEntryPromiseId(entryIndex: number): PromiseId { + return { + type: PromiseType.JournalEntry, + id: entryIndex, + }; +} + +/** + * Prepare a Promise combinator + * + * @param combinatorIndex the index of this combinator + * @param combinatorConstructor the function that creates the combinator promise, e.g. Promise.all/any/race/allSettled + * @param promises the promises given by the user, and the respective ids + * @param readReplayOrder the function to read the replay order + * @param onNewCompleted callback when a child entry is resolved + * @param onCombinatorResolved callback when the combinator is resolved + * @param onCombinatorReplayed callback when the combinator is replayed + */ +function preparePromiseCombinator( + combinatorIndex: number, + combinatorConstructor: (promises: PromiseLike[]) => Promise, + promises: Array<{ id: PromiseId; promise: Promise }>, + readReplayOrder: (combinatorIndex: number) => PromiseId[] | undefined, + onNewCompleted: (combinatorIndex: number, promiseId: PromiseId) => void, + onCombinatorResolved: (combinatorIndex: number) => Promise, + onCombinatorReplayed: (combinatorIndex: number) => void +): WrappedPromise { + // Create the proxy promises and index them + const promisesWithProxyPromise = promises.map((v) => ({ + id: v.id, + originalPromise: v.promise, + proxyPromise: new CompletablePromise(), + })); + const promisesMap = new Map( + promisesWithProxyPromise.map((v) => [ + // We need to define a key format for this map... + v.id.type.toString() + "-" + v.id.id.toString(), + { originalPromise: v.originalPromise, proxyPromise: v.proxyPromise }, + ]) + ); + + // Create the combinator using the proxy promises + const combinator = combinatorConstructor( + promisesWithProxyPromise.map((v) => v.proxyPromise.promise) + ).finally( + async () => + // Once the combinator is resolved, notify back. + await onCombinatorResolved(combinatorIndex) + ); + + return wrapDeeply(combinator, () => { + const replayOrder = readReplayOrder(combinatorIndex); + + if (replayOrder === undefined) { + // We're in processing mode! We need to wire up original promises with proxy promises + for (const { + originalPromise, + proxyPromise, + id, + } of promisesWithProxyPromise) { + originalPromise + // This code works deterministically because the javascript runtime will enqueue + // the listeners of the proxy promise (which are mounted in Promise.all/any) in a single FIFO queue, + // so a subsequent resolve on another proxy promise can't overtake this one. + // + // Some resources: + // * https://stackoverflow.com/questions/38059284/why-does-javascript-promise-then-handler-run-after-other-code + // * https://262.ecma-international.org/6.0/#sec-jobs-and-job-queues + // * https://tr.javascript.info/microtask-queue + .then( + (v) => { + onNewCompleted(combinatorIndex, id); + proxyPromise.resolve(v); + }, + (e) => { + onNewCompleted(combinatorIndex, id); + proxyPromise.reject(e); + } + ); + } + return; + } + + // We're in replay mode, Now follow the replayIndexes order. + onCombinatorReplayed(combinatorIndex); + for (const promiseId of replayOrder) { + // These are already completed, so once we set the then callback they will be immediately resolved. + // eslint-disable-next-line @typescript-eslint/no-non-null-assertion + const { originalPromise, proxyPromise } = promisesMap.get( + promiseId.type.toString() + "-" + promiseId.id.toString() + )!; + + // Because this promise is already completed, promise.then will immediately enqueue in the promise microtask queue + // the handlers to execute. + // See the comment below for more details. + originalPromise.then( + (v) => proxyPromise.resolve(v), + (e) => proxyPromise.reject(e) + ); + } + }); +} + +/** + * This class takes care of creating and managing deterministic promise combinators. + * + * It should be wired up to the journal/state machine methods to read and write entries. + */ +export class PromiseCombinatorTracker { + private nextCombinatorIndex = 0; + private pendingCombinators: Map = new Map(); + + constructor( + private readonly readReplayOrder: ( + combinatorIndex: number + ) => PromiseId[] | undefined, + private readonly onWriteCombinatorOrder: ( + combinatorIndex: number, + order: PromiseId[] + ) => Promise + ) {} + + public createCombinator( + combinatorConstructor: ( + promises: PromiseLike[] + ) => Promise, + promises: Array<{ id: PromiseId; promise: Promise }> + ): WrappedPromise { + const combinatorIndex = this.nextCombinatorIndex; + this.nextCombinatorIndex++; + + // Prepare combinator order + this.pendingCombinators.set(combinatorIndex, []); + + return preparePromiseCombinator( + combinatorIndex, + combinatorConstructor, + promises, + this.readReplayOrder, + this.appendOrder.bind(this), + this.onCombinatorResolved.bind(this), + this.onCombinatorReplayed.bind(this) + ); + } + + private appendOrder(idx: number, promiseId: PromiseId) { + const order = this.pendingCombinators.get(idx); + if (order === undefined) { + // The order was already published, nothing to do here. + return; + } + + order.push(promiseId); + } + + private onCombinatorReplayed(idx: number) { + // This avoids republishing the order + this.pendingCombinators.delete(idx); + } + + private async onCombinatorResolved(idx: number) { + const order = this.pendingCombinators.get(idx); + if (order === undefined) { + // It was already published + return; + } + + // We don't need this list anymore. + this.pendingCombinators.delete(idx); + + // Publish the combinator order + await this.onWriteCombinatorOrder(idx, order); + } +} diff --git a/src/restate_context.ts b/src/restate_context.ts index 1cd644e8..675a9c28 100644 --- a/src/restate_context.ts +++ b/src/restate_context.ts @@ -11,6 +11,14 @@ import { RetrySettings } from "./utils/public_utils"; import { Client, SendClient } from "./types/router"; +import { RestateGrpcContextImpl } from "./restate_context_impl"; + +/** + * A promise that can be combined using Promise combinators in RestateContext. + */ +export type CombineablePromise = Promise & { + __restate_context: RestateBaseContext; +}; /** * Base Restate context, which contains all operations that are the same in the gRPC-based API @@ -148,7 +156,7 @@ export interface RestateBaseContext { * // Wait for the external service to wake this service back up * const result = await awakeable.promise; */ - awakeable(): { id: string; promise: Promise }; + awakeable(): { id: string; promise: CombineablePromise }; /** * Resolve an awakeable of another service. @@ -187,7 +195,7 @@ export interface RestateBaseContext { * const ctx = restate.useContext(this); * await ctx.sleep(1000); */ - sleep(millis: number): Promise; + sleep(millis: number): CombineablePromise; } export interface Rand { @@ -204,6 +212,103 @@ export interface Rand { uuidv4(): string; } +export const CombineablePromise = { + /** + * Creates a Promise that is resolved with an array of results when all of the provided Promises + * resolve, or rejected when any Promise is rejected. + * + * See {@link Promise.all} for more details. + * + * @param values An iterable of Promises. + * @returns A new Promise. + */ + all[] | []>( + values: T + ): Promise<{ -readonly [P in keyof T]: Awaited }> { + if (values.length == 0) { + return Promise.all(values); + } + + return ( + values[0].__restate_context as RestateGrpcContextImpl + ).createCombinator(Promise.all.bind(Promise), values) as Promise<{ + -readonly [P in keyof T]: Awaited; + }>; + }, + + /** + * Creates a Promise that is resolved or rejected when any of the provided Promises are resolved + * or rejected. + * + * See {@link Promise.race} for more details. + * + * @param values An iterable of Promises. + * @returns A new Promise. + */ + race[] | []>( + values: T + ): Promise> { + if (values.length == 0) { + return Promise.race(values); + } + + return ( + values[0].__restate_context as RestateGrpcContextImpl + ).createCombinator(Promise.race.bind(Promise), values) as Promise< + Awaited + >; + }, + + /** + * Creates a promise that fulfills when any of the input's promises fulfills, with this first fulfillment value. + * It rejects when all the input's promises reject (including when an empty iterable is passed), + * with an AggregateError containing an array of rejection reasons. + * + * See {@link Promise.any} for more details. + * + * @param values An iterable of Promises. + * @returns A new Promise. + */ + any[] | []>( + values: T + ): Promise> { + if (values.length == 0) { + return Promise.any(values); + } + + return ( + values[0].__restate_context as RestateGrpcContextImpl + ).createCombinator(Promise.any.bind(Promise), values) as Promise< + Awaited + >; + }, + + /** + * Creates a promise that fulfills when all the input's promises settle (including when an empty iterable is passed), + * with an array of objects that describe the outcome of each promise. + * + * See {@link Promise.allSettled} for more details. + * + * @param values An iterable of Promises. + * @returns A new Promise. + */ + allSettled[] | []>( + values: T + ): Promise<{ + -readonly [P in keyof T]: PromiseSettledResult>; + }> { + if (values.length == 0) { + return Promise.allSettled(values); + } + + return ( + values[0].__restate_context as RestateGrpcContextImpl + ).createCombinator(Promise.allSettled.bind(Promise), values) as Promise<{ + -readonly [P in keyof T]: PromiseSettledResult>; + }>; + }, +}; + // ---------------------------------------------------------------------------- // types and functions for the gRPC-based API // ---------------------------------------------------------------------------- diff --git a/src/restate_context_impl.ts b/src/restate_context_impl.ts index 114048f4..5411f96d 100644 --- a/src/restate_context_impl.ts +++ b/src/restate_context_impl.ts @@ -10,6 +10,7 @@ */ import { + CombineablePromise, Rand, RestateGrpcChannel, RestateGrpcContext, @@ -61,6 +62,8 @@ import { Client, SendClient } from "./types/router"; import { RpcRequest, RpcResponse } from "./generated/proto/dynrpc"; import { requestFromArgs } from "./utils/assumptions"; import { RandImpl } from "./utils/rand"; +import { newJournalEntryPromiseId } from "./promise_combinator_tracker"; +import { WrappedPromise } from "./utils/promises"; export enum CallContexType { None, @@ -73,6 +76,10 @@ export interface CallContext { delay?: number; } +export type InternalCombineablePromise = CombineablePromise & { + journalIndex: number; +}; + export class RestateGrpcContextImpl implements RestateGrpcContext { // here, we capture the context information for actions on the Restate context that // are executed within other actions, such as @@ -221,6 +228,10 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); } + rpcGateway(): RpcGateway { + return new RpcContextImpl(this); + } + // DON'T make this function async!!! // The reason is that we want the erros thrown by the initial checks to be propagated in the caller context, // and not in the promise context. To understand the semantic difference, make this function async and run the @@ -329,9 +340,9 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { }); } - public sleep(millis: number): Promise { + public sleep(millis: number): CombineablePromise { this.checkState("sleep"); - return this.sleepInternal(millis); + return this.markCombineablePromise(this.sleepInternal(millis)); } private sleepInternal(millis: number): Promise { @@ -341,7 +352,9 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); } - public awakeable(): { id: string; promise: Promise } { + // -- Awakeables + + public awakeable(): { id: string; promise: CombineablePromise } { this.checkState("awakeable"); const msg = AwakeableEntryMessage.create(); @@ -366,8 +379,10 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); return { - id: AWAKEABLE_IDENTIFIER_PREFIX + Buffer.concat([this.id, encodedEntryIndex]).toString("base64url"), - promise: promise, + id: + AWAKEABLE_IDENTIFIER_PREFIX + + Buffer.concat([this.id, encodedEntryIndex]).toString("base64url"), + promise: this.markCombineablePromise(promise), }; } @@ -396,6 +411,37 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { ); } + // Used by static methods of CombineablePromise + public createCombinator[]>( + combinatorConstructor: ( + promises: PromiseLike[] + ) => Promise, + promises: T + ): WrappedPromise { + const outPromises = []; + + for (const promise of promises) { + if (promise.__restate_context !== this) { + throw RetryableError.internal( + "You're mixing up CombineablePromises from different RestateContext. This is not supported." + ); + } + const index = (promise as InternalCombineablePromise) + .journalIndex; + outPromises.push({ + id: newJournalEntryPromiseId(index), + promise: promise, + }); + } + + return this.stateMachine.createCombinator( + combinatorConstructor, + outPromises + ); + } + + // -- Various private methods + private isInSideEffect(): boolean { const context = RestateGrpcContextImpl.callContext.getStore(); return context?.type === CallContexType.SideEffect; @@ -445,8 +491,17 @@ export class RestateGrpcContextImpl implements RestateGrpcContext { } } - rpcGateway(): RpcGateway { - return new RpcContextImpl(this); + private markCombineablePromise( + p: Promise + ): InternalCombineablePromise { + return Object.defineProperties(p, { + __restate_context: { + value: this, + }, + journalIndex: { + value: this.stateMachine.getUserCodeJournalIndex(), + }, + }) as InternalCombineablePromise; } } @@ -602,7 +657,7 @@ export class RpcContextImpl implements RpcContext { ): Promise { return this.ctx.sideEffect(fn, retryPolicy); } - public awakeable(): { id: string; promise: Promise } { + public awakeable(): { id: string; promise: CombineablePromise } { return this.ctx.awakeable(); } public resolveAwakeable(id: string, payload: T): void { @@ -611,7 +666,7 @@ export class RpcContextImpl implements RpcContext { public rejectAwakeable(id: string, reason: string): void { this.ctx.rejectAwakeable(id, reason); } - public sleep(millis: number): Promise { + public sleep(millis: number): CombineablePromise { return this.ctx.sleep(millis); } diff --git a/src/state_machine.ts b/src/state_machine.ts index 97c411e5..9fe3b4da 100644 --- a/src/state_machine.ts +++ b/src/state_machine.ts @@ -14,13 +14,13 @@ import { RestateGrpcContextImpl } from "./restate_context_impl"; import { Connection, RestateStreamConsumer } from "./connection/connection"; import { ProtocolMode } from "./generated/proto/discovery"; import { Message } from "./types/types"; -import { CompletablePromise, makeFqServiceName } from "./utils/utils"; import { createStateMachineConsole, StateMachineConsole, } from "./utils/message_logger"; import { clearTimeout } from "timers"; import { + COMBINATOR_ENTRY_MESSAGE, COMPLETION_MESSAGE_TYPE, END_MESSAGE_TYPE, EndMessage, @@ -42,6 +42,18 @@ import { } from "./types/errors"; import { LocalStateStore } from "./local_state_store"; import { createRestateConsole, LoggerContext } from "./logger"; +import { + CompletablePromise, + wrapDeeply, + WRAPPED_PROMISE_PENDING, + WrappedPromise, +} from "./utils/promises"; +import { + PromiseCombinatorTracker, + PromiseId, + PromiseType, +} from "./promise_combinator_tracker"; +import { CombinatorEntryMessage } from "./generated/proto/javascript"; export class StateMachine implements RestateStreamConsumer { private journal: Journal; @@ -66,6 +78,8 @@ export class StateMachine implements RestateStreamConsumer { // Suspension timeout that gets set and cleared based on completion messages; private suspensionTimeout?: NodeJS.Timeout; + private promiseCombinatorTracker: PromiseCombinatorTracker; + console: StateMachineConsole; constructor( @@ -75,7 +89,6 @@ export class StateMachine implements RestateStreamConsumer { loggerContext: LoggerContext, private readonly suspensionMillis: number = 30_000 ) { - this.journal = new Journal(this.invocation); this.localStateStore = invocation.localStateStore; this.console = createStateMachineConsole(loggerContext); @@ -86,6 +99,11 @@ export class StateMachine implements RestateStreamConsumer { createRestateConsole(loggerContext, () => !this.journal.isReplaying()), this ); + this.journal = new Journal(this.invocation); + this.promiseCombinatorTracker = new PromiseCombinatorTracker( + this.readCombinatorOrderEntry.bind(this), + this.writeCombinatorOrderEntry.bind(this) + ); } public handleMessage(m: Message): boolean { @@ -138,7 +156,7 @@ export class StateMachine implements RestateStreamConsumer { // if the state machine is already closed, return a promise that never // completes, so that the user code does not resume if (this.stateMachineClosed) { - return wrapDeeply(new CompletablePromise().promise); + return WRAPPED_PROMISE_PENDING as WrappedPromise; } const promise = this.journal.handleUserSideMessage(messageType, message); @@ -179,6 +197,103 @@ export class StateMachine implements RestateStreamConsumer { }); } + // -- Methods related to combinators to wire up promise combinator API with PromiseCombinatorTracker + + public createCombinator( + combinatorConstructor: ( + promises: PromiseLike[] + ) => Promise, + promises: Array<{ id: PromiseId; promise: Promise }> + ) { + if (this.stateMachineClosed) { + return WRAPPED_PROMISE_PENDING as WrappedPromise; + } + + // We don't need the promise wrapping here to schedule a suspension, + // because the combined promises will already have that, so once we call then() on them, + // if we have to suspend we will suspend. + return this.promiseCombinatorTracker.createCombinator( + combinatorConstructor, + promises + ); + } + + readCombinatorOrderEntry(combinatorId: number): PromiseId[] | undefined { + const wannabeCombinatorEntry = this.journal.readNextReplayEntry(); + if (wannabeCombinatorEntry === undefined) { + // We're in processing mode + return undefined; + } + if (wannabeCombinatorEntry.messageType !== COMBINATOR_ENTRY_MESSAGE) { + throw RetryableError.journalMismatch( + this.journal.getUserCodeJournalIndex(), + wannabeCombinatorEntry, + { + messageType: COMBINATOR_ENTRY_MESSAGE, + message: { + combinatorId, + } as CombinatorEntryMessage, + } + ); + } + + const combinatorMessage = + wannabeCombinatorEntry.message as CombinatorEntryMessage; + if (combinatorMessage.combinatorId != combinatorId) { + throw RetryableError.journalMismatch( + this.journal.getUserCodeJournalIndex(), + wannabeCombinatorEntry, + { + messageType: COMBINATOR_ENTRY_MESSAGE, + message: { + combinatorId, + } as CombinatorEntryMessage, + } + ); + } + + this.console.debugJournalMessage( + "Matched and replayed message from journal", + COMBINATOR_ENTRY_MESSAGE, + combinatorMessage + ); + + return combinatorMessage.journalEntriesOrder.map((id) => ({ + id, + type: PromiseType.JournalEntry, + })); + } + + async writeCombinatorOrderEntry(combinatorId: number, order: PromiseId[]) { + if (this.journal.isProcessing()) { + const combinatorMessage: CombinatorEntryMessage = { + combinatorId, + journalEntriesOrder: order.map((pid) => pid.id), + }; + this.console.debugJournalMessage( + "Adding message to journal and sending to Restate", + COMBINATOR_ENTRY_MESSAGE, + combinatorMessage + ); + + const ackPromise = this.journal.appendJournalEntry( + COMBINATOR_ENTRY_MESSAGE, + combinatorMessage + ); + this.send( + new Message( + COMBINATOR_ENTRY_MESSAGE, + combinatorMessage, + undefined, + undefined, + true + ) + ); + + await ackPromise; + } + } + /** * Invokes the RPC function and returns a promise that completes when the state machine * stops processing the invocation, meaning when: @@ -394,13 +509,15 @@ export class StateMachine implements RestateStreamConsumer { this.clearSuspensionTimeout(); } + /** + * This method is invoked when we hit a suspension point. + * + * A suspension point is everytime the user "await"s a Promise returned by RestateContext that might be completed at a later point in time by a CompletionMessage. + */ private scheduleSuspension() { // If there was already a timeout set, we want to reset the time to postpone suspension as long as we make progress. // So we first clear the old timeout, and then we set a new one. - if (this.suspensionTimeout !== undefined) { - clearTimeout(this.suspensionTimeout); - this.suspensionTimeout = undefined; - } + this.clearSuspensionTimeout(); const delay = this.getSuspensionMillis(); this.console.debugJournalMessage( @@ -465,10 +582,6 @@ export class StateMachine implements RestateStreamConsumer { await this.finish(); } - public async notifyHandlerExecutionError(e: RetryableError | TerminalError) { - await this.sendErrorAndFinish(e); - } - /** * WARNING: make sure you use this at the right point in the code * After the index has been incremented... @@ -478,13 +591,6 @@ export class StateMachine implements RestateStreamConsumer { return this.journal.getUserCodeJournalIndex(); } - public getFullServiceName(): string { - return makeFqServiceName( - this.invocation.method.pkg, - this.invocation.method.service - ); - } - public handleInputClosed(): void { if ( this.journal.isClosed() || @@ -524,72 +630,3 @@ export class StateMachine implements RestateStreamConsumer { } } } -/** - * Returns a promise that wraps the original promise and calls cb() at the first time - * this promise or any nested promise that is chained to it is awaited. (then-ed) - */ - -/* eslint-disable @typescript-eslint/no-explicit-any */ -export type WrappedPromise = Promise & { - transform: ( - onfulfilled?: - | ((value: T) => TResult1 | PromiseLike) - | null - | undefined, - onrejected?: - | ((reason: any) => TResult2 | PromiseLike) - | null - | undefined - ) => Promise; -}; - -const wrapDeeply = ( - promise: Promise, - cb?: () => void -): WrappedPromise => { - /* eslint-disable @typescript-eslint/no-explicit-any */ - return { - transform: function ( - onfulfilled?: - | ((value: T) => TResult1 | PromiseLike) - | null - | undefined, - onrejected?: - | ((reason: any) => TResult2 | PromiseLike) - | null - | undefined - ): Promise { - return wrapDeeply(promise.then(onfulfilled, onrejected), cb); - }, - - then: function ( - onfulfilled?: - | ((value: T) => TResult1 | PromiseLike) - | null - | undefined, - onrejected?: - | ((reason: any) => TResult2 | PromiseLike) - | null - | undefined - ): Promise { - if (cb !== undefined) { - cb(); - } - return promise.then(onfulfilled, onrejected); - }, - catch: function ( - onrejected?: - | ((reason: any) => TResult | PromiseLike) - | null - | undefined - ): Promise { - return wrapDeeply(promise.catch(onrejected), cb); - }, - finally: function ( - onfinally?: (() => void) | null | undefined - ): Promise { - return wrapDeeply(promise.finally(onfinally), cb); - }, - [Symbol.toStringTag]: "", - }; -}; diff --git a/src/types/errors.ts b/src/types/errors.ts index 7d0e7e58..fe63b9dd 100644 --- a/src/types/errors.ts +++ b/src/types/errors.ts @@ -13,9 +13,8 @@ import { ErrorMessage, Failure } from "../generated/proto/protocol"; import { formatMessageAsJson } from "../utils/utils"; -import { Message } from "./types"; -import { JournalEntry } from "../journal"; import { FailureWithTerminal } from "../generated/proto/javascript"; +import * as p from "./protocol"; export enum ErrorCodes { /** @@ -214,17 +213,23 @@ export class RetryableError extends RestateError { public static journalMismatch( journalIndex: number, - replayMessage: Message, - journalEntry: JournalEntry + actualEntry: { + messageType: bigint; + message: p.ProtocolMessage | Uint8Array; + }, + expectedEntry: { + messageType: bigint; + message: p.ProtocolMessage | Uint8Array; + } ) { const msg = `Journal mismatch: Replayed journal entries did not correspond to the user code. The user code has to be deterministic! The journal entry at position ${journalIndex} was: - In the user code: type: ${ - journalEntry.messageType - }, message:${formatMessageAsJson(journalEntry.message)} + expectedEntry.messageType + }, message:${formatMessageAsJson(expectedEntry.message)} - In the replayed messages: type: ${ - replayMessage.messageType - }, message: ${formatMessageAsJson(replayMessage.message)}`; + actualEntry.messageType + }, message: ${formatMessageAsJson(actualEntry.message)}`; return new RetryableError(msg, { errorCode: RestateErrorCodes.JOURNAL_MISMATCH, }); diff --git a/src/types/protocol.ts b/src/types/protocol.ts index 31788e57..c3259eb9 100644 --- a/src/types/protocol.ts +++ b/src/types/protocol.ts @@ -9,7 +9,10 @@ * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE */ -import { SideEffectEntryMessage } from "../generated/proto/javascript"; +import { + SideEffectEntryMessage, + CombinatorEntryMessage, +} from "../generated/proto/javascript"; import { AwakeableEntryMessage, BackgroundInvokeEntryMessage, @@ -72,6 +75,7 @@ export const AWAKEABLE_IDENTIFIER_PREFIX = "prom_1"; // Export the custom message types // Side effects are custom messages because the runtime does not need to inspect them export const SIDE_EFFECT_ENTRY_MESSAGE_TYPE = 0xfc01n; +export const COMBINATOR_ENTRY_MESSAGE = 0xfc02n; // Restate DuplexStream @@ -95,6 +99,7 @@ export const KNOWN_MESSAGE_TYPES = new Set([ AWAKEABLE_ENTRY_MESSAGE_TYPE, COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE, SIDE_EFFECT_ENTRY_MESSAGE_TYPE, + COMBINATOR_ENTRY_MESSAGE, ]); const PROTOBUF_MESSAGE_NAME_BY_TYPE = new Map([ @@ -115,6 +120,7 @@ const PROTOBUF_MESSAGE_NAME_BY_TYPE = new Map([ [AWAKEABLE_ENTRY_MESSAGE_TYPE, "AwakeableEntryMessage"], [COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE, "CompleteAwakeableEntryMessage"], [SIDE_EFFECT_ENTRY_MESSAGE_TYPE, "SideEffectEntryMessage"], + [COMBINATOR_ENTRY_MESSAGE, "CombinatorEntryMessage"], ]); export function formatMessageType(messageType: bigint) { @@ -142,6 +148,7 @@ const PROTOBUF_MESSAGES: Array<[bigint, any]> = [ [AWAKEABLE_ENTRY_MESSAGE_TYPE, AwakeableEntryMessage], [COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE, CompleteAwakeableEntryMessage], [SIDE_EFFECT_ENTRY_MESSAGE_TYPE, SideEffectEntryMessage], + [COMBINATOR_ENTRY_MESSAGE, CombinatorEntryMessage], ]; export const PROTOBUF_MESSAGE_BY_TYPE = new Map(PROTOBUF_MESSAGES); @@ -163,14 +170,17 @@ export type ProtocolMessage = | BackgroundInvokeEntryMessage | AwakeableEntryMessage | CompleteAwakeableEntryMessage - | SideEffectEntryMessage; + | SideEffectEntryMessage + | CombinatorEntryMessage; // These message types will trigger sending a suspension message from the runtime // for each of the protocol modes export const SUSPENSION_TRIGGERS: bigint[] = [ INVOKE_ENTRY_MESSAGE_TYPE, GET_STATE_ENTRY_MESSAGE_TYPE, - SIDE_EFFECT_ENTRY_MESSAGE_TYPE, AWAKEABLE_ENTRY_MESSAGE_TYPE, SLEEP_ENTRY_MESSAGE_TYPE, + COMBINATOR_ENTRY_MESSAGE, + // We need it because of the ack + SIDE_EFFECT_ENTRY_MESSAGE_TYPE, ]; diff --git a/src/utils/promises.ts b/src/utils/promises.ts new file mode 100644 index 00000000..4a5670d4 --- /dev/null +++ b/src/utils/promises.ts @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2023 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +// -- Wrapped promise + +/* eslint-disable @typescript-eslint/no-explicit-any */ +export type WrappedPromise = Promise & { + // The reason for this transform is that we want to retain the wrapping. + // When working with WrappedPromise you MUST use this method instead of then for mapping the promise results. + transform: ( + onfulfilled?: + | ((value: T) => TResult1 | PromiseLike) + | null + | undefined, + onrejected?: + | ((reason: any) => TResult2 | PromiseLike) + | null + | undefined + ) => Promise; +}; + +export function wrapDeeply( + promise: Promise, + onThen?: () => void +): WrappedPromise { + // We need this to support nesting of WrappedPromise + let transform: ( + onfulfilled?: + | ((value: T) => TResult1 | PromiseLike) + | null + | undefined, + onrejected?: + | ((reason: any) => TResult2 | PromiseLike) + | null + | undefined + ) => Promise; + if (Object.hasOwn(promise, "transform")) { + const wrappedPromise = promise as WrappedPromise; + transform = (onfulfilled, onrejected) => + wrapDeeply(wrappedPromise.transform(onfulfilled, onrejected), onThen); + } else { + transform = (onfulfilled, onrejected) => + wrapDeeply(promise.then(onfulfilled, onrejected), onThen); + } + + /* eslint-disable @typescript-eslint/no-explicit-any */ + return { + transform, + + then: function ( + onfulfilled?: + | ((value: T) => TResult1 | PromiseLike) + | null + | undefined, + onrejected?: + | ((reason: any) => TResult2 | PromiseLike) + | null + | undefined + ): Promise { + if (onThen !== undefined) { + onThen(); + } + return promise.then(onfulfilled, onrejected); + }, + catch: function ( + onrejected?: + | ((reason: any) => TResult | PromiseLike) + | null + | undefined + ): Promise { + return wrapDeeply(promise.catch(onrejected), onThen); + }, + finally: function ( + onfinally?: (() => void) | null | undefined + ): Promise { + return wrapDeeply(promise.finally(onfinally), onThen); + }, + [Symbol.toStringTag]: "", + }; +} + +// Like https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Promise/withResolvers +// (not yet available in node) +export class CompletablePromise { + private success!: (value: T | PromiseLike) => void; + private failure!: (reason?: any) => void; + + public readonly promise: Promise; + + constructor() { + this.promise = new Promise((resolve, reject) => { + this.success = resolve; + this.failure = reject; + }); + } + + public resolve(value: T) { + this.success(value); + } + + public reject(reason?: any) { + this.failure(reason); + } +} + +// A promise that is never completed +// eslint-disable-next-line @typescript-eslint/no-empty-function +export const PROMISE_PENDING: Promise = new Promise(() => {}); +export const WRAPPED_PROMISE_PENDING: Promise = + wrapDeeply(PROMISE_PENDING); diff --git a/src/utils/utils.ts b/src/utils/utils.ts index 6ccc9622..1cc04287 100644 --- a/src/utils/utils.ts +++ b/src/utils/utils.ts @@ -33,28 +33,6 @@ import { SLEEP_ENTRY_MESSAGE_TYPE, } from "../types/protocol"; -export class CompletablePromise { - private success!: (value: T | PromiseLike) => void; - private failure!: (reason?: any) => void; - - public readonly promise: Promise; - - constructor() { - this.promise = new Promise((resolve, reject) => { - this.success = resolve; - this.failure = reject; - }); - } - - public resolve(value: T) { - this.success(value); - } - - public reject(reason?: any) { - this.failure(reason); - } -} - export function jsonSerialize(obj: any): string { return JSON.stringify(obj, (_, v) => typeof v === "bigint" ? "BIGINT::" + v.toString() : v @@ -106,10 +84,6 @@ export function formatMessageAsJson(obj: any): string { ); } -export function makeFqServiceName(pckg: string, name: string): string { - return pckg ? `${pckg}.${name}` : name; -} - /** * Equality functions * @param msg1 the current message from user code diff --git a/test/promise_combinator_tracker.test.ts b/test/promise_combinator_tracker.test.ts new file mode 100644 index 00000000..e6c21945 --- /dev/null +++ b/test/promise_combinator_tracker.test.ts @@ -0,0 +1,238 @@ +/* + * Copyright (c) 2023 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +import { describe, expect } from "@jest/globals"; +import { CompletablePromise } from "../src/utils/promises"; +import { + newJournalEntryPromiseId, + PromiseCombinatorTracker, + PromiseId, +} from "../src/promise_combinator_tracker"; + +describe("PromiseCombinatorTracker with Promise.any", () => { + it("should provide order in processing mode", async () => { + const { completers, promises } = generateTestPromises(3); + + const testResultPromise = testCombinatorInProcessingMode( + Promise.any.bind(Promise), + promises + ); + + setImmediate(() => { + // Any doesn't return on first reject + completers[0].reject("bla"); + completers[2].resolve("my value"); + }); + + const { order, result } = await testResultPromise; + expect(result).toStrictEqual("my value"); + expect(order).toStrictEqual(createOrder(0, 2)); + }); + + it("should provide order in processing mode, with partially resolved promises", async () => { + const { completers, promises } = generateTestPromises(3); + // Any doesn't return on first reject + completers[0].reject("bla"); + + const testResultPromise = testCombinatorInProcessingMode( + Promise.any.bind(Promise), + promises + ); + + setImmediate(() => { + completers[2].resolve("my value"); + }); + + const { order, result } = await testResultPromise; + expect(result).toStrictEqual("my value"); + expect(order).toStrictEqual(createOrder(0, 2)); + }); + + it("should provide order in processing mode, with all promises already resolved", async () => { + const { completers, promises } = generateTestPromises(3); + // Any doesn't return on first reject + completers[0].reject("bla"); + completers[2].resolve("my value"); + + const testResultPromise = testCombinatorInProcessingMode( + Promise.any.bind(Promise), + promises + ); + + const { order, result } = await testResultPromise; + expect(result).toStrictEqual("my value"); + expect(order).toStrictEqual(createOrder(0, 2)); + }); + + it("should replay correctly", async () => { + const { completers, promises } = generateTestPromises(3); + // This should not influence the result + completers[1].resolve("another value"); + completers[2].resolve("my value"); + completers[0].reject("bla"); + + const result = await testCombinatorInReplayMode( + Promise.any.bind(Promise), + promises, + createOrder(0, 2) + ); + + expect(result).toStrictEqual("my value"); + }); +}); + +describe("PromiseCombinatorTracker with Promise.all", () => { + it("should provide order in processing mode, with failing child", async () => { + const { completers, promises } = generateTestPromises(3); + + const testResultPromise = testCombinatorInProcessingMode( + Promise.all.bind(Promise), + promises + ); + + setImmediate(() => { + completers[2].resolve("my value"); + completers[0].reject("my error"); + }); + + const { order, result } = await testResultPromise; + expect(result).toStrictEqual("my error"); + expect(order).toStrictEqual(createOrder(2, 0)); + }); + + it("should provide order in processing mode, with all success children", async () => { + const { completers, promises } = generateTestPromises(3); + + const testResultPromise = testCombinatorInProcessingMode( + Promise.all.bind(Promise), + promises + ); + + setImmediate(() => { + completers[2].resolve("my value 2"); + completers[0].resolve("my value 0"); + completers[1].resolve("my value 1"); + }); + + const { order, result } = await testResultPromise; + expect(result).toStrictEqual(["my value 0", "my value 1", "my value 2"]); + expect(order).toStrictEqual(createOrder(2, 0, 1)); + }); + + it("should replay correctly with failing child", async () => { + const { completers, promises } = generateTestPromises(3); + // This should not influence the result + completers[1].resolve("should be irrelevant"); + completers[2].resolve("my value"); + completers[0].reject("my error"); + + const result = await testCombinatorInReplayMode( + Promise.all.bind(Promise), + promises, + createOrder(2, 0) + ); + + expect(result).toStrictEqual("my error"); + }); + + it("should replay correctly with all success children", async () => { + const { completers, promises } = generateTestPromises(3); + completers[2].resolve("my value 2"); + completers[0].resolve("my value 0"); + completers[1].resolve("my value 1"); + + const result = await testCombinatorInReplayMode( + Promise.all.bind(Promise), + promises, + createOrder(2, 0, 1) + ); + + expect(result).toStrictEqual(["my value 0", "my value 1", "my value 2"]); + }); +}); + +// -- Some utility methods for these tests + +function generateTestPromises(n: number): { + completers: Array>; + promises: Array<{ id: PromiseId; promise: Promise }>; +} { + const completers = []; + const promises = []; + + for (let i = 0; i < n; i++) { + const completablePromise = new CompletablePromise(); + completers.push(completablePromise); + promises.push({ + id: newJournalEntryPromiseId(i), + promise: completablePromise.promise, + }); + } + + return { completers, promises }; +} + +function createOrder(...numbers: number[]) { + return numbers.map(newJournalEntryPromiseId); +} + +async function testCombinatorInProcessingMode( + combinatorConstructor: (promises: PromiseLike[]) => Promise, + promises: Array<{ id: PromiseId; promise: Promise }> +) { + const resultMap = new Map(); + const tracker = new PromiseCombinatorTracker( + () => { + return undefined; + }, + (combinatorIndex, order) => { + resultMap.set(combinatorIndex, order); + return new Promise((r) => r()); + } + ); + + return tracker.createCombinator(combinatorConstructor, promises).transform( + (result) => ({ + order: resultMap.get(0), + result, + }), + (result) => ({ + order: resultMap.get(0), + result, + }) + ); +} + +async function testCombinatorInReplayMode( + combinatorConstructor: (promises: PromiseLike[]) => Promise, + promises: Array<{ id: PromiseId; promise: Promise }>, + order: PromiseId[] +) { + const tracker = new PromiseCombinatorTracker( + (idx) => { + expect(idx).toStrictEqual(0); + return order; + }, + () => { + throw new Error("Unexpected call"); + } + ); + + return ( + tracker + .createCombinator(combinatorConstructor, promises) + // To make sure it behaves like testCombinatorInProcessingMode and always succeeds + .transform( + (v) => v, + (e) => e + ) + ); +} diff --git a/test/promise_combinators.test.ts b/test/promise_combinators.test.ts new file mode 100644 index 00000000..e5b683f3 --- /dev/null +++ b/test/promise_combinators.test.ts @@ -0,0 +1,329 @@ +/* + * Copyright (c) 2023 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +import { describe, expect } from "@jest/globals"; +import * as restate from "../src/public_api"; +import { TestDriver } from "./testdriver"; +import { + awakeableMessage, + completionMessage, + getAwakeableId, + greetRequest, + greetResponse, + inputMessage, + outputMessage, + startMessage, + suspensionMessage, + END_MESSAGE, + combinatorEntryMessage, + sleepMessage, + sideEffectMessage, + ackMessage, +} from "./protoutils"; +import { TestGreeter, TestResponse } from "../src/generated/proto/test"; +import { SLEEP_ENTRY_MESSAGE_TYPE } from "../src/types/protocol"; +import { CombineablePromise } from "../src/restate_context"; + +class AwakeableSleepRaceGreeter implements TestGreeter { + async greet(): Promise { + const ctx = restate.useContext(this); + + const awakeable = ctx.awakeable(); + const sleep = ctx.sleep(1); + + const result = await CombineablePromise.race([awakeable.promise, sleep]); + + if (typeof result === "string") { + return TestResponse.create({ + greeting: `Hello ${result} for ${awakeable.id}`, + }); + } + + return TestResponse.create({ + greeting: `Hello timed-out`, + }); + } +} + +describe("AwakeableSleepRaceGreeter", () => { + it("should suspend without completions", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + ]).run(); + + expect(result.length).toStrictEqual(3); + expect(result[0]).toStrictEqual(awakeableMessage()); + expect(result[1].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result[2]).toStrictEqual(suspensionMessage([1, 2])); + }); + + it("handles completion of awakeable", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(1, JSON.stringify("Francesco")), + ackMessage(3), + ]).run(); + + expect(result.length).toStrictEqual(5); + expect(result[0]).toStrictEqual(awakeableMessage()); + expect(result[1].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result.slice(2)).toStrictEqual([ + combinatorEntryMessage(0, [1]), + outputMessage(greetResponse(`Hello Francesco for ${getAwakeableId(1)}`)), + END_MESSAGE, + ]); + }); + + it("handles completion of sleep", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(2, undefined, true), + ackMessage(3), + ]).run(); + + expect(result.length).toStrictEqual(5); + expect(result[0]).toStrictEqual(awakeableMessage()); + expect(result[1].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result.slice(2)).toStrictEqual([ + combinatorEntryMessage(0, [2]), + outputMessage(greetResponse(`Hello timed-out`)), + END_MESSAGE, + ]); + }); + + it("handles replay of the awakeable", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + ackMessage(3), + ]).run(); + + expect(result.length).toStrictEqual(4); + expect(result[0].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result.slice(1)).toStrictEqual([ + combinatorEntryMessage(0, [1]), + outputMessage(greetResponse(`Hello Francesco for ${getAwakeableId(1)}`)), + END_MESSAGE, + ]); + }); + + it("handles replay of the awakeable and sleep", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + sleepMessage(1), + ackMessage(3), + ]).run(); + + expect(result).toStrictEqual([ + // The awakeable will be chosen because Promise.race will pick the first promise, in case both are resolved + combinatorEntryMessage(0, [1]), + outputMessage(greetResponse(`Hello Francesco for ${getAwakeableId(1)}`)), + END_MESSAGE, + ]); + }); + + it("handles replay of the combinator with awakeable completed", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + sleepMessage(1), + combinatorEntryMessage(0, [1]), + ]).run(); + + expect(result).toStrictEqual([ + outputMessage(greetResponse(`Hello Francesco for ${getAwakeableId(1)}`)), + END_MESSAGE, + ]); + }); + + it("handles replay of the combinator with sleep completed", async () => { + const result = await new TestDriver(new AwakeableSleepRaceGreeter(), [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage(), + sleepMessage(1, {}), + combinatorEntryMessage(0, [2]), + ]).run(); + + expect(result).toStrictEqual([ + outputMessage(greetResponse(`Hello timed-out`)), + END_MESSAGE, + ]); + }); +}); + +class AwakeableSleepRaceInterleavedWithSideEffectGreeter + implements TestGreeter +{ + async greet(): Promise { + const ctx = restate.useContext(this); + + const awakeable = ctx.awakeable(); + const sleep = ctx.sleep(1); + const combinatorPromise = CombineablePromise.race([ + awakeable.promise, + sleep, + ]); + + await ctx.sideEffect(async () => "sideEffect"); + + // Because the combinatorPromise generates the message when awaited, the entries order here should be: + // * AwakeableEntry + // * SleepEntry + // * SideEffectEntry + // * CombinatorOrderEntry + const result = await combinatorPromise; + + if (typeof result === "string") { + return TestResponse.create({ + greeting: `Hello ${result} for ${awakeable.id}`, + }); + } + + return TestResponse.create({ + greeting: `Hello timed-out`, + }); + } +} + +describe("AwakeableSleepRaceInterleavedWithSideEffectGreeter", () => { + it("generates the combinator entry after the side effect, when processing first time", async () => { + const result = await new TestDriver( + new AwakeableSleepRaceInterleavedWithSideEffectGreeter(), + [ + startMessage(), + inputMessage(greetRequest("Till")), + completionMessage(1, JSON.stringify("Francesco")), + ackMessage(3), + ackMessage(4), + ] + ).run(); + + expect(result.length).toStrictEqual(6); + expect(result[0]).toStrictEqual(awakeableMessage()); + expect(result[1].messageType).toStrictEqual(SLEEP_ENTRY_MESSAGE_TYPE); + expect(result.slice(2)).toStrictEqual([ + sideEffectMessage("sideEffect"), + combinatorEntryMessage(0, [1]), + outputMessage(greetResponse(`Hello Francesco for ${getAwakeableId(1)}`)), + END_MESSAGE, + ]); + }); + + it("generates the combinator entry after the side effect, when replaying up to sleep", async () => { + const result = await new TestDriver( + new AwakeableSleepRaceInterleavedWithSideEffectGreeter(), + [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + sleepMessage(1), + ackMessage(3), + ackMessage(4), + ] + ).run(); + + expect(result).toStrictEqual([ + sideEffectMessage("sideEffect"), + combinatorEntryMessage(0, [1]), + outputMessage(greetResponse(`Hello Francesco for ${getAwakeableId(1)}`)), + END_MESSAGE, + ]); + }); +}); + +class CombineablePromiseThenSideEffect implements TestGreeter { + async greet(): Promise { + const ctx = restate.useContext(this); + + const a1 = ctx.awakeable(); + const a2 = ctx.awakeable(); + const combinatorResult = await CombineablePromise.race([ + a1.promise, + a2.promise, + ]); + + const sideEffectResult = await ctx.sideEffect( + async () => "sideEffect" + ); + + return TestResponse.create({ + greeting: combinatorResult + "-" + sideEffectResult, + }); + } +} + +describe("CombineablePromiseThenSideEffect", () => { + it("after the combinator entry, suspends waiting for ack", async () => { + const result = await new TestDriver( + new CombineablePromiseThenSideEffect(), + [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + awakeableMessage(), + ] + ).run(); + + expect(result).toStrictEqual([ + combinatorEntryMessage(0, [1]), + suspensionMessage([2, 3]), + ]); + }); + + it("after the combinator entry and the ack, completes", async () => { + const result = await new TestDriver( + new CombineablePromiseThenSideEffect(), + [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + awakeableMessage(), + ackMessage(3), + ackMessage(4), + ] + ).run(); + + expect(result).toStrictEqual([ + combinatorEntryMessage(0, [1]), + sideEffectMessage("sideEffect"), + outputMessage(greetResponse(`Francesco-sideEffect`)), + END_MESSAGE, + ]); + }); + + it("no need to wait for ack when replaing the combinator entry", async () => { + const result = await new TestDriver( + new CombineablePromiseThenSideEffect(), + [ + startMessage(), + inputMessage(greetRequest("Till")), + awakeableMessage("Francesco"), + awakeableMessage(), + combinatorEntryMessage(0, [1]), + ackMessage(4), + ] + ).run(); + + expect(result).toStrictEqual([ + sideEffectMessage("sideEffect"), + outputMessage(greetResponse(`Francesco-sideEffect`)), + END_MESSAGE, + ]); + }); +}); diff --git a/test/promises.test.ts b/test/promises.test.ts new file mode 100644 index 00000000..485f4963 --- /dev/null +++ b/test/promises.test.ts @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2023 - Restate Software, Inc., Restate GmbH + * + * This file is part of the Restate SDK for Node.js/TypeScript, + * which is released under the MIT license. + * + * You can find a copy of the license in file LICENSE in the root + * directory of this repository or package, or at + * https://github.com/restatedev/sdk-typescript/blob/main/LICENSE + */ + +import { describe, expect } from "@jest/globals"; +import { + CompletablePromise, + wrapDeeply, + WrappedPromise, +} from "../src/utils/promises"; + +describe("promises.wrapDeeply", () => { + it("should support nested wrapping", async () => { + const callbackInvokeOrder: number[] = []; + const completablePromise = new CompletablePromise(); + + let p = completablePromise.promise; + p = wrapDeeply(p, () => { + callbackInvokeOrder.push(2); + }); + p = wrapDeeply(p, () => { + callbackInvokeOrder.push(1); + }); + p = (p as WrappedPromise).transform((v) => v + " transformed"); + + completablePromise.resolve("my value"); + + expect(await p).toStrictEqual("my value transformed"); + expect(callbackInvokeOrder).toStrictEqual([1, 2]); + }); +}); diff --git a/test/protocol_stream.test.ts b/test/protocol_stream.test.ts index b10e3b8f..e2a8e51f 100644 --- a/test/protocol_stream.test.ts +++ b/test/protocol_stream.test.ts @@ -21,8 +21,8 @@ import { import { RestateHttp2Connection } from "../src/connection/http_connection"; import { Header, Message } from "../src/types/types"; import stream from "stream"; -import { CompletablePromise } from "../src/utils/utils"; import { setTimeout } from "timers/promises"; +import { CompletablePromise } from "../src/utils/promises"; // The following test suite is taken from headers.rs describe("Header", () => { diff --git a/test/protoutils.ts b/test/protoutils.ts index 4d9b1bad..3391e885 100644 --- a/test/protoutils.ts +++ b/test/protoutils.ts @@ -46,10 +46,12 @@ import { END_MESSAGE_TYPE, EndMessage, AWAKEABLE_IDENTIFIER_PREFIX, + COMBINATOR_ENTRY_MESSAGE, } from "../src/types/protocol"; import { Message } from "../src/types/types"; import { TestRequest, TestResponse } from "../src/generated/proto/test"; import { + CombinatorEntryMessage, FailureWithTerminal, SideEffectEntryMessage, } from "../src/generated/proto/javascript"; @@ -431,6 +433,22 @@ export function suspensionMessage(entryIndices: number[]): Message { ); } +export function combinatorEntryMessage( + combinatorId: number, + journalEntriesOrder: number[] +): Message { + return new Message( + COMBINATOR_ENTRY_MESSAGE, + CombinatorEntryMessage.create({ + combinatorId, + journalEntriesOrder, + }), + undefined, + undefined, + true + ); +} + export function failure( msg: string, code: number = ErrorCodes.INTERNAL @@ -488,10 +506,13 @@ export function getAwakeableId(entryIndex: number): string { const encodedEntryIndex = Buffer.alloc(4 /* Size of u32 */); encodedEntryIndex.writeUInt32BE(entryIndex); - return AWAKEABLE_IDENTIFIER_PREFIX + Buffer.concat([ - Buffer.from("f311f1fdcb9863f0018bd3400ecd7d69b547204e776218b2", "hex"), - encodedEntryIndex, - ]).toString("base64url"); + return ( + AWAKEABLE_IDENTIFIER_PREFIX + + Buffer.concat([ + Buffer.from("f311f1fdcb9863f0018bd3400ecd7d69b547204e776218b2", "hex"), + encodedEntryIndex, + ]).toString("base64url") + ); } export function keyVal(key: string, value: any): Buffer[] { diff --git a/test/side_effect.test.ts b/test/side_effect.test.ts index 388f4485..213cee01 100644 --- a/test/side_effect.test.ts +++ b/test/side_effect.test.ts @@ -13,22 +13,22 @@ import { describe, expect } from "@jest/globals"; import * as restate from "../src/public_api"; import { TestDriver } from "./testdriver"; import { + ackMessage, + backgroundInvokeMessage, + checkJournalMismatchError, + checkTerminalError, completionMessage, + END_MESSAGE, + failureWithTerminal, + getAwakeableId, + greetRequest, + greetResponse, inputMessage, + invokeMessage, outputMessage, sideEffectMessage, startMessage, - greetRequest, - greetResponse, - invokeMessage, - getAwakeableId, - backgroundInvokeMessage, suspensionMessage, - checkTerminalError, - checkJournalMismatchError, - failureWithTerminal, - ackMessage, - END_MESSAGE, } from "./protoutils"; import { TestGreeter, @@ -37,6 +37,7 @@ import { TestResponse, } from "../src/generated/proto/test"; import { ErrorCodes, TerminalError } from "../src/types/errors"; +import { ProtocolMode } from "../src/generated/proto/discovery"; class SideEffectGreeter implements TestGreeter { constructor(readonly sideEffectOutput: unknown) {} @@ -1042,3 +1043,42 @@ describe("TerminalErrorSideEffectService", () => { checkTerminalError(result[1], "Something bad happened."); }); }); + +class SideEffectWithMutableVariable implements TestGreeter { + constructor(readonly externalSideEffect: { effectExecuted: boolean }) {} + + async greet(): Promise { + const ctx = restate.useContext(this); + + await ctx.sideEffect(() => { + // I'm trying to simulate a case where the side effect resolution + // happens on the next event loop tick. + return new Promise((resolve) => { + setTimeout(() => { + this.externalSideEffect.effectExecuted = true; + resolve(undefined); + }, 100); + }); + }); + + throw new Error("It should not reach this point"); + } +} + +describe("SideEffectWithMutableVariable", () => { + it("should suspend after mutating the variable, and not before!", async () => { + const externalSideEffect = { effectExecuted: false }; + + const result = await new TestDriver( + new SideEffectWithMutableVariable(externalSideEffect), + [startMessage(), inputMessage(greetRequest("Till"))], + ProtocolMode.REQUEST_RESPONSE + ).run(); + + expect(result).toStrictEqual([ + sideEffectMessage(undefined), + suspensionMessage([1]), + ]); + expect(externalSideEffect.effectExecuted).toBeTruthy(); + }); +});