From 481f9fa4964deadc81ae3fdf6b66633f3c8a8dac Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Thu, 19 Jun 2025 17:18:07 +0000 Subject: [PATCH] chore: add generic types for input & auth data --- packages/core/src/client/worker-common.ts | 13 +- packages/core/src/registry/config.ts | 2 +- packages/core/src/worker/action.ts | 10 +- packages/core/src/worker/config.ts | 114 +++++++++++------- packages/core/src/worker/connection.ts | 15 ++- packages/core/src/worker/context.ts | 8 +- packages/core/src/worker/definition.ts | 61 +++++++--- packages/core/src/worker/instance.ts | 57 ++++++--- packages/core/src/worker/mod.ts | 8 +- .../core/src/worker/protocol/message/mod.ts | 18 +-- packages/core/tests/worker-types.test.ts | 28 ++++- turbo.json | 4 +- 12 files changed, 226 insertions(+), 112 deletions(-) diff --git a/packages/core/src/client/worker-common.ts b/packages/core/src/client/worker-common.ts index b2ef87dc0..cabd449f5 100644 --- a/packages/core/src/client/worker-common.ts +++ b/packages/core/src/client/worker-common.ts @@ -1,11 +1,17 @@ -import type { AnyWorkerDefinition, WorkerDefinition } from "@/worker/definition"; +import type { + AnyWorkerDefinition, + WorkerDefinition, +} from "@/worker/definition"; import type * as protoHttpResolve from "@/worker/protocol/http/resolve"; import type { Encoding } from "@/worker/protocol/serde"; import type { WorkerQuery } from "@/manager/protocol/query"; import { logger } from "./log"; import * as errors from "./errors"; import { sendHttpRequest } from "./utils"; -import { HEADER_WORKER_QUERY, HEADER_ENCODING } from "@/worker/router-endpoints"; +import { + HEADER_WORKER_QUERY, + HEADER_ENCODING, +} from "@/worker/router-endpoints"; /** * Action function returned by Worker connections and handles. @@ -27,11 +33,10 @@ export type WorkerActionFunction< * Maps action methods from worker definition to typed function signatures. */ export type WorkerDefinitionActions = - AD extends WorkerDefinition + AD extends WorkerDefinition ? { [K in keyof R]: R[K] extends (...args: infer Args) => infer Return ? WorkerActionFunction : never; } : never; - diff --git a/packages/core/src/registry/config.ts b/packages/core/src/registry/config.ts index d265ecb0f..a8b569a53 100644 --- a/packages/core/src/registry/config.ts +++ b/packages/core/src/registry/config.ts @@ -43,7 +43,7 @@ export type WorkerPeerConfig = z.infer; export const WorkersSchema = z.record( z.string(), - z.custom>(), + z.custom>(), ); export type Workers = z.infer; diff --git a/packages/core/src/worker/action.ts b/packages/core/src/worker/action.ts index b302421f8..bc03dd4da 100644 --- a/packages/core/src/worker/action.ts +++ b/packages/core/src/worker/action.ts @@ -13,8 +13,8 @@ import { WorkerContext } from "./context"; * * @typeParam A Worker this action belongs to */ -export class ActionContext { - #workerContext: WorkerContext; +export class ActionContext { + #workerContext: WorkerContext; /** * Should not be called directly. @@ -23,8 +23,8 @@ export class ActionContext { * @param conn - The connection associated with the action */ constructor( - workerContext: WorkerContext, - public readonly conn: Conn, + workerContext: WorkerContext, + public readonly conn: Conn, ) { this.#workerContext = workerContext; } @@ -95,7 +95,7 @@ export class ActionContext { /** * Gets the map of connections. */ - get conns(): Map> { + get conns(): Map> { return this.#workerContext.conns; } diff --git a/packages/core/src/worker/config.ts b/packages/core/src/worker/config.ts index cb3684867..a660f131c 100644 --- a/packages/core/src/worker/config.ts +++ b/packages/core/src/worker/config.ts @@ -75,12 +75,12 @@ export const WorkerConfigSchema = z }, ); -export interface OnCreateOptions { - input?: unknown; +export interface OnCreateOptions { + input?: I; } -export interface CreateStateOptions { - input?: unknown; +export interface CreateStateOptions { + input?: I; } export interface OnConnectOptions { @@ -98,12 +98,12 @@ export interface OnConnectOptions { // This must have only one or the other or else S will not be able to be inferred // // Data returned from this handler will be available on `c.state`. -type CreateState = +type CreateState = | { state: S } | { createState: ( - c: WorkerContext, - opts: CreateStateOptions, + c: WorkerContext, + opts: CreateStateOptions, ) => S | Promise; } | Record; @@ -113,11 +113,11 @@ type CreateState = // This must have only one or the other or else S will not be able to be inferred // // Data returned from this handler will be available on `c.conn.state`. -type CreateConnState = +type CreateConnState = | { connState: CS } | { createConnState: ( - c: WorkerContext, + c: WorkerContext, opts: OnConnectOptions, ) => CS | Promise; } @@ -129,7 +129,7 @@ type CreateConnState = /** * @experimental */ -type CreateVars = +type CreateVars = | { /** * @experimental @@ -141,20 +141,23 @@ type CreateVars = * @experimental */ createVars: ( - c: WorkerContext, + c: WorkerContext, driverCtx: unknown, ) => V | Promise; } | Record; -export interface Actions { - [Action: string]: (c: ActionContext, ...args: any[]) => any; +export interface Actions { + [Action: string]: ( + c: ActionContext, + ...args: any[] + ) => any; } -//export type WorkerConfig = BaseWorkerConfig & -// WorkerConfigLifecycle & -// CreateState & -// CreateConnState; +//export type WorkerConfig = BaseWorkerConfig & +// WorkerConfigLifecycle & +// CreateState & +// CreateConnState; /** * @experimental @@ -170,7 +173,15 @@ interface OnAuthOptions { params: CP; } -interface BaseWorkerConfig> { +interface BaseWorkerConfig< + S, + CP, + CS, + V, + I, + AD, + R extends Actions, +> { /** * Called on the HTTP server before clients can interact with the worker. * @@ -196,7 +207,7 @@ interface BaseWorkerConfig> { * @returns Authentication data to attach to connections (must be serializable) * @throws Throw an error to deny access to the worker */ - onAuth?: (opts: OnAuthOptions) => unknown | Promise; + onAuth?: (opts: OnAuthOptions) => AD | Promise; /** * Called when the worker is first initialized. @@ -205,8 +216,8 @@ interface BaseWorkerConfig> { * This is called before any other lifecycle hooks. */ onCreate?: ( - c: WorkerContext, - opts: OnCreateOptions, + c: WorkerContext, + opts: OnCreateOptions, ) => void | Promise; /** @@ -217,7 +228,7 @@ interface BaseWorkerConfig> { * * @returns Void or a Promise that resolves when startup is complete */ - onStart?: (c: WorkerContext) => void | Promise; + onStart?: (c: WorkerContext) => void | Promise; /** * Called when the worker's state changes. @@ -227,7 +238,7 @@ interface BaseWorkerConfig> { * * @param newState The updated state */ - onStateChange?: (c: WorkerContext, newState: S) => void; + onStateChange?: (c: WorkerContext, newState: S) => void; /** * Called before a client connects to the worker. @@ -250,7 +261,7 @@ interface BaseWorkerConfig> { * @throws Throw an error to reject the connection */ onBeforeConnect?: ( - c: WorkerContext, + c: WorkerContext, opts: OnConnectOptions, ) => void | Promise; @@ -264,8 +275,8 @@ interface BaseWorkerConfig> { * @returns Void or a Promise that resolves when connection handling is complete */ onConnect?: ( - c: WorkerContext, - conn: Conn, + c: WorkerContext, + conn: Conn, ) => void | Promise; /** @@ -278,8 +289,8 @@ interface BaseWorkerConfig> { * @returns Void or a Promise that resolves when disconnect handling is complete */ onDisconnect?: ( - c: WorkerContext, - conn: Conn, + c: WorkerContext, + conn: Conn, ) => void | Promise; /** @@ -295,7 +306,7 @@ interface BaseWorkerConfig> { * @returns The modified output to send to the client */ onBeforeActionResponse?: ( - c: WorkerContext, + c: WorkerContext, name: string, args: unknown[], output: Out, @@ -307,7 +318,7 @@ interface BaseWorkerConfig> { // 1. Infer schema // 2. Omit keys that we'll manually define (because of generics) // 3. Define our own types that have generic constraints -export type WorkerConfig = Omit< +export type WorkerConfig = Omit< z.infer, | "actions" | "onAuth" @@ -325,10 +336,10 @@ export type WorkerConfig = Omit< | "vars" | "createVars" > & - BaseWorkerConfig> & - CreateState & - CreateConnState & - CreateVars; + BaseWorkerConfig> & + CreateState & + CreateConnState & + CreateVars; // See description on `WorkerConfig` export type WorkerConfigInput< @@ -336,7 +347,9 @@ export type WorkerConfigInput< CP, CS, V, - R extends Actions, + I, + AD, + R extends Actions, > = Omit< z.input, | "actions" @@ -355,16 +368,31 @@ export type WorkerConfigInput< | "vars" | "createVars" > & - BaseWorkerConfig & - CreateState & - CreateConnState & - CreateVars; + BaseWorkerConfig & + CreateState & + CreateConnState & + CreateVars; // For testing type definitions: -export function test>( - input: WorkerConfigInput, -): WorkerConfig { - const config = WorkerConfigSchema.parse(input) as WorkerConfig; +export function test< + S, + CP, + CS, + V, + I, + AD, + R extends Actions, +>( + input: WorkerConfigInput, +): WorkerConfig { + const config = WorkerConfigSchema.parse(input) as WorkerConfig< + S, + CP, + CS, + V, + I, + AD + >; return config; } diff --git a/packages/core/src/worker/connection.ts b/packages/core/src/worker/connection.ts index 8639c6d4e..8f0645f35 100644 --- a/packages/core/src/worker/connection.ts +++ b/packages/core/src/worker/connection.ts @@ -17,7 +17,7 @@ export function generateConnToken(): string { export type ConnId = string; -export type AnyConn = Conn; +export type AnyConn = Conn; /** * Represents a client connection to a worker. @@ -26,13 +26,13 @@ export type AnyConn = Conn; * * @see {@link https://rivet.gg/docs/connections|Connection Documentation} */ -export class Conn { +export class Conn { subscriptions: Set = new Set(); #stateEnabled: boolean; // TODO: Remove this cyclical reference - #worker: WorkerInstance; + #worker: WorkerInstance; /** * The proxied state that notifies of changes automatically. @@ -103,7 +103,7 @@ export class Conn { * @protected */ public constructor( - worker: WorkerInstance, + worker: WorkerInstance, persist: PersistedConn, driver: ConnDriver, stateEnabled: boolean, @@ -157,6 +157,11 @@ export class Conn { * @param reason - The reason for disconnection. */ public async disconnect(reason?: string) { - await this.#driver.disconnect(this.#worker, this, this.__persist.ds, reason); + await this.#driver.disconnect( + this.#worker, + this, + this.__persist.ds, + reason, + ); } } diff --git a/packages/core/src/worker/context.ts b/packages/core/src/worker/context.ts index 3cda79159..f7fd07064 100644 --- a/packages/core/src/worker/context.ts +++ b/packages/core/src/worker/context.ts @@ -8,10 +8,10 @@ import { Schedule } from "./schedule"; /** * WorkerContext class that provides access to worker methods and state */ -export class WorkerContext { - #worker: WorkerInstance; +export class WorkerContext { + #worker: WorkerInstance; - constructor(worker: WorkerInstance) { + constructor(worker: WorkerInstance) { this.#worker = worker; } @@ -84,7 +84,7 @@ export class WorkerContext { /** * Gets the map of connections. */ - get conns(): Map> { + get conns(): Map> { return this.#worker.conns; } diff --git a/packages/core/src/worker/definition.ts b/packages/core/src/worker/definition.ts index 6c5f4947a..617040a38 100644 --- a/packages/core/src/worker/definition.ts +++ b/packages/core/src/worker/definition.ts @@ -1,41 +1,70 @@ -import { - type WorkerConfig, - type Actions, -} from "./config"; +import { type WorkerConfig, type Actions } from "./config"; import { WorkerInstance } from "./instance"; import { WorkerContext } from "./context"; import type { ActionContext } from "./action"; -export type AnyWorkerDefinition = WorkerDefinition; +export type AnyWorkerDefinition = WorkerDefinition< + any, + any, + any, + any, + any, + any, + any +>; /** * Extracts the context type from an WorkerDefinition */ -export type WorkerContextOf = - AD extends WorkerDefinition - ? WorkerContext +export type WorkerContextOf = + AD extends WorkerDefinition< + infer S, + infer CP, + infer CS, + infer V, + infer I, + infer AD, + any + > + ? WorkerContext : never; /** * Extracts the context type from an WorkerDefinition */ -export type ActionContextOf = - AD extends WorkerDefinition - ? ActionContext +export type ActionContextOf = + AD extends WorkerDefinition< + infer S, + infer CP, + infer CS, + infer V, + infer I, + infer AD, + any + > + ? ActionContext : never; -export class WorkerDefinition> { - #config: WorkerConfig; +export class WorkerDefinition< + S, + CP, + CS, + V, + I, + AD, + R extends Actions, +> { + #config: WorkerConfig; - constructor(config: WorkerConfig) { + constructor(config: WorkerConfig) { this.#config = config; } - get config(): WorkerConfig { + get config(): WorkerConfig { return this.#config; } - instantiate(): WorkerInstance { + instantiate(): WorkerInstance { return new WorkerInstance(this.#config); } } diff --git a/packages/core/src/worker/instance.ts b/packages/core/src/worker/instance.ts index b55f38fd3..bc3c64ef8 100644 --- a/packages/core/src/worker/instance.ts +++ b/packages/core/src/worker/instance.ts @@ -39,7 +39,7 @@ export interface SaveStateOptions { /** Worker type alias with all `any` types. Used for `extends` in classes referencing this worker. */ // biome-ignore lint/suspicious/noExplicitAny: Needs to be used in `extends` -export type AnyWorkerInstance = WorkerInstance; +export type AnyWorkerInstance = WorkerInstance; export type ExtractWorkerState = A extends WorkerInstance< @@ -49,6 +49,10 @@ export type ExtractWorkerState = // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any, // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` + any, + // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` + any, + // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any > ? State @@ -62,6 +66,10 @@ export type ExtractWorkerConnParams = // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any, // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` + any, + // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` + any, + // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any > ? ConnParams @@ -75,14 +83,18 @@ export type ExtractWorkerConnState = any, infer ConnState, // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` + any, + // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` + any, + // biome-ignore lint/suspicious/noExplicitAny: Must be used for `extends` any > ? ConnState : never; -export class WorkerInstance { +export class WorkerInstance { // Shared worker context for this instance - workerContext: WorkerContext; + workerContext: WorkerContext; isStopping = false; #persistChanged = false; @@ -105,7 +117,7 @@ export class WorkerInstance { #vars?: V; #backgroundPromises: Promise[] = []; - #config: WorkerConfig; + #config: WorkerConfig; #connectionDrivers!: ConnDrivers; #workerDriver!: WorkerDriver; #workerId!: string; @@ -114,8 +126,8 @@ export class WorkerInstance { #region!: string; #ready = false; - #connections = new Map>(); - #subscriptionIndex = new Map>>(); + #connections = new Map>(); + #subscriptionIndex = new Map>>(); #schedule!: Schedule; @@ -136,7 +148,7 @@ export class WorkerInstance { * * @private */ - constructor(config: WorkerConfig) { + constructor(config: WorkerConfig) { this.#config = config; this.workerContext = new WorkerContext(this); } @@ -169,6 +181,8 @@ export class WorkerInstance { if ("createVars" in this.#config) { const dataOrPromise = this.#config.createVars( this.workerContext as unknown as WorkerContext< + undefined, + undefined, undefined, undefined, undefined, @@ -482,7 +496,7 @@ export class WorkerInstance { for (const connPersist of this.#persist.c) { // Create connections const driver = this.__getConnDriver(connPersist.d); - const conn = new Conn( + const conn = new Conn( this, connPersist, driver, @@ -498,7 +512,7 @@ export class WorkerInstance { } else { logger().info("worker creating"); - const input = await this.#workerDriver.readInput(this.#workerId); + const input = (await this.#workerDriver.readInput(this.#workerId)) as I; // Initialize worker state let stateData: unknown = undefined; @@ -511,6 +525,8 @@ export class WorkerInstance { // Convert state to undefined since state is not defined yet here stateData = await this.#config.createState( this.workerContext as unknown as WorkerContext< + undefined, + undefined, undefined, undefined, undefined, @@ -546,14 +562,14 @@ export class WorkerInstance { } } - __getConnForId(id: string): Conn | undefined { + __getConnForId(id: string): Conn | undefined { return this.#connections.get(id); } /** * Removes a connection and cleans up its resources. */ - __removeConn(conn: Conn | undefined) { + __removeConn(conn: Conn | undefined) { if (!conn) { logger().warn("`conn` does not exist"); return; @@ -622,6 +638,8 @@ export class WorkerInstance { if ("createConnState" in this.#config) { const dataOrPromise = this.#config.createConnState( this.workerContext as unknown as WorkerContext< + undefined, + undefined, undefined, undefined, undefined, @@ -667,7 +685,7 @@ export class WorkerInstance { driverId: string, driverState: unknown, authData: unknown, - ): Promise> { + ): Promise> { if (this.#connections.has(connectionId)) { throw new Error(`Connection already exists: ${connectionId}`); } @@ -684,7 +702,7 @@ export class WorkerInstance { a: authData, su: [], }; - const conn = new Conn( + const conn = new Conn( this, persist, driver, @@ -738,7 +756,10 @@ export class WorkerInstance { } // MARK: Messages - async processMessage(message: wsToServer.ToServer, conn: Conn) { + async processMessage( + message: wsToServer.ToServer, + conn: Conn, + ) { await processMessage(message, this, conn, { onExecuteAction: async (ctx, name, args) => { return await this.executeAction(ctx, name, args); @@ -755,7 +776,7 @@ export class WorkerInstance { // MARK: Events #addSubscription( eventName: string, - connection: Conn, + connection: Conn, fromPersist: boolean, ) { if (connection.subscriptions.has(eventName)) { @@ -785,7 +806,7 @@ export class WorkerInstance { #removeSubscription( eventName: string, - connection: Conn, + connection: Conn, fromRemoveConn: boolean, ) { if (!connection.subscriptions.has(eventName)) { @@ -844,7 +865,7 @@ export class WorkerInstance { * @internal */ async executeAction( - ctx: ActionContext, + ctx: ActionContext, actionName: string, args: unknown[], ): Promise { @@ -988,7 +1009,7 @@ export class WorkerInstance { /** * Gets the map of connections. */ - get conns(): Map> { + get conns(): Map> { return this.#connections; } diff --git a/packages/core/src/worker/mod.ts b/packages/core/src/worker/mod.ts index ac3b0a2be..08c5af305 100644 --- a/packages/core/src/worker/mod.ts +++ b/packages/core/src/worker/mod.ts @@ -20,9 +20,9 @@ export type { ActionContextOf, } from "./definition"; -export function worker>( - input: WorkerConfigInput, -): WorkerDefinition { - const config = WorkerConfigSchema.parse(input) as WorkerConfig; +export function worker>( + input: WorkerConfigInput, +): WorkerDefinition { + const config = WorkerConfigSchema.parse(input) as WorkerConfig; return new WorkerDefinition(config); } diff --git a/packages/core/src/worker/protocol/message/mod.ts b/packages/core/src/worker/protocol/message/mod.ts index fa8e330d3..9dd20e36c 100644 --- a/packages/core/src/worker/protocol/message/mod.ts +++ b/packages/core/src/worker/protocol/message/mod.ts @@ -69,24 +69,24 @@ export async function parseMessage( return message; } -export interface ProcessMessageHandler { +export interface ProcessMessageHandler { onExecuteAction?: ( - ctx: ActionContext, + ctx: ActionContext, name: string, args: unknown[], ) => Promise; - onSubscribe?: (eventName: string, conn: Conn) => Promise; + onSubscribe?: (eventName: string, conn: Conn) => Promise; onUnsubscribe?: ( eventName: string, - conn: Conn, + conn: Conn, ) => Promise; } -export async function processMessage( +export async function processMessage( message: wsToServer.ToServer, - worker: WorkerInstance, - conn: Conn, - handler: ProcessMessageHandler, + worker: WorkerInstance, + conn: Conn, + handler: ProcessMessageHandler, ) { let actionId: number | undefined; let actionName: string | undefined; @@ -110,7 +110,7 @@ export async function processMessage( argsCount: args.length, }); - const ctx = new ActionContext(worker.workerContext, conn); + const ctx = new ActionContext(worker.workerContext, conn); // Process the action request and wait for the result // This will wait for async actions to complete diff --git a/packages/core/tests/worker-types.test.ts b/packages/core/tests/worker-types.test.ts index e90a9851c..3fc1a5223 100644 --- a/packages/core/tests/worker-types.test.ts +++ b/packages/core/tests/worker-types.test.ts @@ -22,6 +22,14 @@ describe("WorkerDefinition", () => { foo: string; } + interface TestInput { + bar: string; + } + + interface TestAuthData { + baz: string; + } + // For testing type utilities, we don't need a real worker instance // We just need a properly typed WorkerDefinition to check against type TestActions = Record; @@ -30,12 +38,21 @@ describe("WorkerDefinition", () => { TestConnParams, TestConnState, TestVars, + TestInput, + TestAuthData, TestActions >; // Use expectTypeOf to verify our type utility works correctly expectTypeOf>().toEqualTypeOf< - WorkerContext + WorkerContext< + TestState, + TestConnParams, + TestConnState, + TestVars, + TestInput, + TestAuthData + > >(); // Make sure that different types are not compatible @@ -44,7 +61,14 @@ describe("WorkerDefinition", () => { } expectTypeOf>().not.toEqualTypeOf< - WorkerContext + WorkerContext< + DifferentState, + TestConnParams, + TestConnState, + TestVars, + TestInput, + TestAuthData + > >(); }); }); diff --git a/turbo.json b/turbo.json index 719e5beda..ed780730f 100644 --- a/turbo.json +++ b/turbo.json @@ -21,7 +21,9 @@ }, "dev": { // Both builds & checks types for this repo and all dependencies - "dependsOn": ["build", "^check-types", "check-types"] + // + // Build after checking types since check types will return errors faster + "dependsOn": ["^check-types", "check-types", "build"] }, "test": { "dependsOn": ["^build", "check-types"],