Skip to content

Commit

Permalink
feat(client): endpoint type definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
drochetti committed Nov 12, 2024
1 parent 5a1d115 commit ac19e8e
Show file tree
Hide file tree
Showing 18 changed files with 12,930 additions and 63 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
package-lock.json linguist-generated
docs/reference/** linguist-generated
libs/client/src/types/endpoints.ts linguist-generated
4 changes: 2 additions & 2 deletions apps/demo-nextjs-app-router/app/comfy/image-to-image/page.tsx
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/* eslint-disable @next/next/no-img-element */
"use client";

import { createFalClient } from "@fal-ai/client";
import { createFalClient, Result } from "@fal-ai/client";
import { useMemo, useState } from "react";

const fal = createFalClient({
Expand Down Expand Up @@ -80,7 +80,7 @@ export default function ComfyImageToImagePage() {
setLoading(true);
const start = Date.now();
try {
const { data } = await fal.subscribe<ComfyOutput>(
const { data }: Result<ComfyOutput> = await fal.subscribe(
"comfy/fal-ai/image-to-image",
{
input: {
Expand Down
4 changes: 2 additions & 2 deletions apps/demo-nextjs-app-router/app/comfy/image-to-video/page.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { createFalClient } from "@fal-ai/client";
import { createFalClient, Result } from "@fal-ai/client";
import { useMemo, useState } from "react";

const fal = createFalClient({
Expand Down Expand Up @@ -75,7 +75,7 @@ export default function ComfyImageToVideoPage() {
setLoading(true);
const start = Date.now();
try {
const { data } = await fal.subscribe<ComfyOutput>(
const { data }: Result<ComfyOutput> = await fal.subscribe(
"comfy/fal-ai/image-to-video",
{
input: {
Expand Down
4 changes: 2 additions & 2 deletions apps/demo-nextjs-app-router/app/comfy/text-to-image/page.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"use client";

import { createFalClient } from "@fal-ai/client";
import { createFalClient, Result } from "@fal-ai/client";
import { useMemo, useState } from "react";

const fal = createFalClient({
Expand Down Expand Up @@ -78,7 +78,7 @@ export default function ComfyTextToImagePage() {
setLoading(true);
const start = Date.now();
try {
const { data } = await fal.subscribe<ComfyOutput>(
const { data }: Result<ComfyOutput> = await fal.subscribe(
"comfy/fal-ai/text-to-image",
{
input: {
Expand Down
16 changes: 4 additions & 12 deletions apps/demo-nextjs-app-router/app/page.tsx
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"use client";

import { createFalClient } from "@fal-ai/client";
import { IllusionDiffusionOutput } from "@fal-ai/client/endpoints";
import { useMemo, useState } from "react";

const fal = createFalClient({
Expand All @@ -9,16 +10,6 @@ const fal = createFalClient({
// proxyUrl: 'http://localhost:3333/api/fal/proxy', // or your own external proxy
});

type Image = {
url: string;
file_name: string;
file_size: number;
};
type Output = {
image: Image;
};
// @snippet:end

type ErrorProps = {
error: any;
};
Expand Down Expand Up @@ -48,7 +39,7 @@ export default function Home() {
// Result state
const [loading, setLoading] = useState(false);
const [error, setError] = useState<Error | null>(null);
const [result, setResult] = useState<Output | null>(null);
const [result, setResult] = useState<IllusionDiffusionOutput | null>(null);
const [logs, setLogs] = useState<string[]>([]);
const [elapsedTime, setElapsedTime] = useState<number>(0);
// @snippet:end
Expand All @@ -71,12 +62,13 @@ export default function Home() {
};

const generateImage = async () => {
if (!imageFile) return;
reset();
// @snippet:start("client.queue.subscribe")
setLoading(true);
const start = Date.now();
try {
const result = await fal.subscribe<Output>("fal-ai/illusion-diffusion", {
const result = await fal.subscribe("fal-ai/illusion-diffusion", {
input: {
prompt,
image_url: imageFile,
Expand Down
4 changes: 2 additions & 2 deletions apps/demo-nextjs-app-router/app/whisper/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,10 @@ export default function WhisperDemo() {
setLoading(true);
const start = Date.now();
try {
const result = await fal.subscribe("fal-ai/whisper", {
const result = await fal.subscribe("fal-ai/wizper", {
input: {
file_name: "recording.wav",
audio_url: audioFile,
version: "3",
},
logs: true,
onQueueUpdate(update) {
Expand Down
15 changes: 14 additions & 1 deletion libs/client/package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "@fal-ai/client",
"description": "The fal.ai client for JavaScript and TypeScript",
"version": "1.0.4",
"version": "1.1.0-alpha.0",
"license": "MIT",
"repository": {
"type": "git",
Expand All @@ -15,6 +15,19 @@
"ml",
"typescript"
],
"exports": {
".": "./src/index.js",
"./endpoints": "./src/types/endpoints.js"
},
"typesVersions": {
"*": {
"endpoints": [
"src/types/endpoints.d.ts"
]
}
},
"main": "./src/index.js",
"types": "./src/index.d.ts",
"dependencies": {
"@msgpack/msgpack": "^3.0.0-beta2",
"eventsource-parser": "^1.1.2",
Expand Down
36 changes: 17 additions & 19 deletions libs/client/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import { buildUrl, dispatchRequest } from "./request";
import { resultResponseHandler } from "./response";
import { createStorageClient, StorageClient } from "./storage";
import { createStreamingClient, StreamingClient } from "./streaming";
import { Result, RunOptions } from "./types";
import { EndpointType, InputType, OutputType } from "./types/client";
import { Result, RunOptions } from "./types/common";

/**
* The main client type, it provides access to simple API model usage,
Expand Down Expand Up @@ -44,10 +45,10 @@ export interface FalClient {
* @param endpointId the registered function revision id or alias.
* @returns the remote function output
*/
run<Output = any, Input = Record<string, any>>(
endpointId: string,
options: RunOptions<Input>,
): Promise<Result<Output>>;
run<Id extends EndpointType>(
endpointId: Id,
options: RunOptions<InputType<Id>>,
): Promise<Result<OutputType<Id>>>;

/**
* Subscribes to updates for a specific request in the queue.
Expand All @@ -56,10 +57,10 @@ export interface FalClient {
* @param options - Options to configure how the request is run and how updates are received.
* @returns A promise that resolves to the result of the request once it's completed.
*/
subscribe<Output = any, Input = Record<string, any>>(
endpointId: string,
options: RunOptions<Input> & QueueSubscribeOptions,
): Promise<Result<Output>>;
subscribe<Id extends EndpointType>(
endpointId: Id,
options: RunOptions<InputType<Id>> & QueueSubscribeOptions,
): Promise<Result<OutputType<Id>>>;

/**
* Calls a fal app that supports streaming and provides a streaming-capable
Expand Down Expand Up @@ -90,27 +91,24 @@ export function createFalClient(userConfig: Config = {}): FalClient {
storage,
streaming,
stream: streaming.stream,
async run<Output, Input>(
endpointId: string,
options: RunOptions<Input> = {},
): Promise<Result<Output>> {
async run<Id extends EndpointType>(
endpointId: Id,
options: RunOptions<InputType<Id>> = {},
): Promise<Result<OutputType<Id>>> {
const input = options.input
? await storage.transformInput(options.input)
: undefined;
return dispatchRequest<Input, Result<Output>>({
return dispatchRequest<InputType<Id>, Result<OutputType<Id>>>({
method: options.method,
targetUrl: buildUrl(endpointId, options),
input: input as Input,
input: input as InputType<Id>,
config: {
...config,
responseHandler: resultResponseHandler,
},
});
},
async subscribe<Output, Input>(
endpointId: string,
options: RunOptions<Input> & QueueSubscribeOptions = {},
): Promise<Result<Output>> {
subscribe: async (endpointId, options) => {
const { request_id: requestId } = await queue.submit(endpointId, options);
if (options.onEnqueue) {
options.onEnqueue(requestId);
Expand Down
25 changes: 16 additions & 9 deletions libs/client/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { createFalClient, type FalClient } from "./client";
import { Config } from "./config";
import { StreamOptions } from "./streaming";
import { RunOptions } from "./types";
import { EndpointType, InputType } from "./types/client";
import { RunOptions } from "./types/common";

export { createFalClient, type FalClient } from "./client";
export { withMiddleware, withProxy } from "./middleware";
Expand All @@ -12,12 +13,12 @@ export { ApiError, ValidationError } from "./response";
export type { ResponseHandler } from "./response";
export type { StorageClient } from "./storage";
export type { FalStream, StreamingClient } from "./streaming";
export * from "./types";
export * from "./types/common";
export type {
QueueStatus,
ValidationErrorInfo,
WebHookResponse,
} from "./types";
} from "./types/common";
export { parseEndpointId } from "./utils";

type SingletonFalClient = {
Expand Down Expand Up @@ -46,14 +47,20 @@ export const fal: SingletonFalClient = (function createSingletonFalClient() {
get streaming() {
return currentInstance.streaming;
},
run<Output, Input>(id: string, options: RunOptions<Input>) {
return currentInstance.run<Output, Input>(id, options);
run<Id extends EndpointType>(id: Id, options: RunOptions<InputType<Id>>) {
return currentInstance.run<Id>(id, options);
},
subscribe<Output, Input>(endpointId: string, options: RunOptions<Input>) {
return currentInstance.subscribe<Output, Input>(endpointId, options);
subscribe<Id extends EndpointType>(
endpointId: Id,
options: RunOptions<InputType<Id>>,
) {
return currentInstance.subscribe<Id>(endpointId, options);
},
stream<Output, Input>(endpointId: string, options: StreamOptions<Input>) {
return currentInstance.stream<Output, Input>(endpointId, options);
stream<Id extends EndpointType>(
endpointId: Id,
options: StreamOptions<InputType<Id>>,
) {
return currentInstance.stream<Id>(endpointId, options);
},
} satisfies SingletonFalClient;
})();
2 changes: 1 addition & 1 deletion libs/client/src/queue.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
RequestLog,
Result,
RunOptions,
} from "./types";
} from "./types/common";
import { parseEndpointId } from "./utils";

export type QueuePriority = "low" | "normal";
Expand Down
2 changes: 1 addition & 1 deletion libs/client/src/request.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { RequiredConfig } from "./config";
import { ResponseHandler } from "./response";
import { getUserAgent, isBrowser } from "./runtime";
import { RunOptions, UrlOptions } from "./types";
import { RunOptions, UrlOptions } from "./types/common";
import { ensureEndpointIdFormat, isValidUrl } from "./utils";

const isCloudflareWorkers =
Expand Down
2 changes: 1 addition & 1 deletion libs/client/src/response.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { RequiredConfig } from "./config";
import { Result, ValidationErrorInfo } from "./types";
import { Result, ValidationErrorInfo } from "./types/common";

export type ResponseHandler<Output> = (response: Response) => Promise<Output>;

Expand Down
21 changes: 11 additions & 10 deletions libs/client/src/streaming.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import { RequiredConfig } from "./config";
import { buildUrl, dispatchRequest } from "./request";
import { ApiError, defaultResponseHandler } from "./response";
import { type StorageClient } from "./storage";
import { EndpointType, InputType, OutputType } from "./types/client";

export type StreamingConnectionMode = "client" | "server";

Expand Down Expand Up @@ -117,7 +118,7 @@ export class FalStream<Input, Output> {
);
}
this.signal.addEventListener("abort", () => {
resolve(this.currentData);
resolve(this.currentData ?? ({} as Output));
});
this.on("done", (data) => {
this.streamClosed = true;
Expand Down Expand Up @@ -365,10 +366,10 @@ export interface StreamingClient {
* @param options the request options, including the input payload.
* @returns the `FalStream` instance.
*/
stream<Output = any, Input = Record<string, any>>(
endpointId: string,
options: StreamOptions<Input>,
): Promise<FalStream<Input, Output>>;
stream<Id extends EndpointType>(
endpointId: Id,
options: StreamOptions<InputType<Id>>,
): Promise<FalStream<InputType<Id>, OutputType<Id>>>;
}

type StreamingClientDependencies = {
Expand All @@ -381,16 +382,16 @@ export function createStreamingClient({
storage,
}: StreamingClientDependencies): StreamingClient {
return {
async stream<Input, Output>(
endpointId: string,
options: StreamOptions<Input>,
async stream<Id extends EndpointType>(
endpointId: Id,
options: StreamOptions<InputType<Id>>,
) {
const input = options.input
? await storage.transformInput(options.input)
: undefined;
return new FalStream<Input, Output>(endpointId, config, {
return new FalStream<InputType<Id>, OutputType<Id>>(endpointId, config, {
...options,
input: input as Input,
input: input as InputType<Id>,
});
},
};
Expand Down
14 changes: 14 additions & 0 deletions libs/client/src/types/client.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import { EndpointTypeMap } from "./endpoints";

// eslint-disable-next-line @typescript-eslint/ban-types
export type EndpointType = keyof EndpointTypeMap | (string & {});

// Get input type based on endpoint ID
export type InputType<T extends string> = T extends keyof EndpointTypeMap
? EndpointTypeMap[T]["input"]
: Record<string, any>;

// Get output type based on endpoint ID
export type OutputType<T extends string> = T extends keyof EndpointTypeMap
? EndpointTypeMap[T]["output"]
: any;
File renamed without changes.
Loading

0 comments on commit ac19e8e

Please sign in to comment.