Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable messages api #581

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion e2e/deno/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ const tokenInfo = await whoAmI({ credentials: { accessToken: token } });
console.log(tokenInfo);

const sum = await hf.summarization({
model: "facebook/bart-large-cnn",
model: "google/pegasus-xsum",
inputs:
"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.",
parameters: {
Expand Down
2 changes: 1 addition & 1 deletion e2e/svelte/src/routes/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
}

const sum = await hf.summarization({
model: "facebook/bart-large-cnn",
model: "google/pegasus-xsum",
inputs:
"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.",
parameters: {
Expand Down
2 changes: 1 addition & 1 deletion e2e/ts/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ const hf = new HfInference(hfToken);
console.log(info);

const sum = await hf.summarization({
model: "facebook/bart-large-cnn",
model: "google/pegasus-xsum",
inputs:
"The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.",
parameters: {
Expand Down
7 changes: 4 additions & 3 deletions packages/inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ import { textGeneration } from "@huggingface/inference";

await textGeneration({
accessToken: "hf_...",
model: "model_or_endpoint",
model: "model",
inputs: ...,
parameters: ...
parameters: ...,
endpointUrl: "custom endpoint url",
})
```

Expand All @@ -80,7 +81,7 @@ Summarizes longer text into shorter text. Be careful, some models have a maximum

```typescript
await hf.summarization({
model: 'facebook/bart-large-cnn',
model: 'google/pegasus-xsum',
inputs:
'The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930.',
parameters: {
Expand Down
8 changes: 4 additions & 4 deletions packages/inference/src/HfInference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ type TaskWithNoAccessToken = {
) => ReturnType<Task[key]>;
};

type TaskWithNoAccessTokenNoModel = {
type TaskWithNoAccessTokenNoEndpointUrl = {
[key in keyof Task]: (
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "model">,
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken" | "endpointUrl">,
options?: Parameters<Task[key]>[1]
) => ReturnType<Task[key]>;
};
Expand Down Expand Up @@ -57,12 +57,12 @@ export class HfInferenceEndpoint {
enumerable: false,
value: (params: RequestArgs, options: Options) =>
// eslint-disable-next-line @typescript-eslint/no-explicit-any
fn({ ...params, accessToken, model: endpointUrl } as any, { ...defaultOptions, ...options }),
fn({ ...params, accessToken, endpointUrl } as any, { ...defaultOptions, ...options }),
});
}
}
}

export interface HfInference extends TaskWithNoAccessToken {}

export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoModel {}
export interface HfInferenceEndpoint extends TaskWithNoAccessTokenNoEndpointUrl {}
19 changes: 12 additions & 7 deletions packages/inference/src/lib/makeRequestOptions.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import type { InferenceTask, Options, RequestArgs } from "../types";
import { isObjectEmpty } from "../utils/isEmpty";
import { omit } from "../utils/omit";
import { HF_HUB_URL } from "./getDefaultTask";
import { isUrl } from "./isUrl";

Expand All @@ -24,8 +26,7 @@ export async function makeRequestOptions(
taskHint?: InferenceTask;
}
): Promise<{ url: string; info: RequestInit }> {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { accessToken, model: _model, ...otherArgs } = args;
const { accessToken, endpointUrl, ...otherArgs } = args;
let { model } = args;
const {
forceTask: task,
Expand Down Expand Up @@ -78,10 +79,16 @@ export async function makeRequestOptions(
}

const url = (() => {
if (endpointUrl && isUrl(model)) {
throw new TypeError("Both model and endpointUrl cannot be URLs");
}
if (isUrl(model)) {
console.warn("Using a model URL is deprecated, please use the `endpointUrl` parameter instead");
return model;
}

if (endpointUrl) {
return endpointUrl;
}
if (task) {
return `${HF_INFERENCE_API_BASE_URL}/pipeline/${task}/${model}`;
}
Expand All @@ -98,19 +105,17 @@ export async function makeRequestOptions(
} else if (includeCredentials === true) {
credentials = "include";
}

const info: RequestInit = {
headers,
method: "POST",
body: binary
? args.data
: JSON.stringify({
...otherArgs,
options: options && otherOptions,
coyotte508 marked this conversation as resolved.
Show resolved Hide resolved
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : otherArgs),
...(otherOptions && !isObjectEmpty(otherOptions) && { options: otherOptions }),
}),
...(credentials && { credentials }),
signal: options?.signal,
};

return { url, info };
}
3 changes: 3 additions & 0 deletions packages/inference/src/tasks/custom/streamingRequest.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ export async function* streamingRequest<T>(
onChunk(value);
for (const event of events) {
if (event.data.length > 0) {
if (event.data === "[DONE]") {
return;
}
const data = JSON.parse(event.data);
if (typeof data === "object" && data !== null && "error" in data) {
throw new Error(data.error);
Expand Down
18 changes: 11 additions & 7 deletions packages/inference/src/tasks/nlp/textGeneration.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
import type { Choice } from "./textGenerationStream";

/**
* Inputs for Text Generation inference
Expand Down Expand Up @@ -102,11 +102,16 @@ export interface TextGenerationOutput {
* When enabled, details about the generation
*/
details?: TextGenerationOutputDetails;
[property: string]: unknown;

/**
* The generated text
*/
generated_text: string;
[property: string]: unknown;
generated_text?: string;
/**
* If Message API compatible
*/
choices?: Choice[];
}

/**
Expand Down Expand Up @@ -212,13 +217,12 @@ export async function textGeneration(
args: BaseArgs & TextGenerationInput,
options?: Options
): Promise<TextGenerationOutput> {
const res = await request<TextGenerationOutput[]>(args, {
const res = await request<TextGenerationOutput[] | TextGenerationOutput>(args, {
...options,
taskHint: "text-generation",
});
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x?.generated_text === "string");
if (!isValidOutput) {
throw new InferenceOutputError("Expected Array<{generated_text: string}>");
if (!Array.isArray(res)) {
return res;
}
return res?.[0];
}
24 changes: 24 additions & 0 deletions packages/inference/src/tasks/nlp/textGenerationStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@ export interface TextGenerationStreamOutput {
* Only available when the generation is finished
*/
details: TextGenerationStreamDetails | null;
/**
* If Message API compatible
*/
choices?: Choice[];
}

export interface Choice {
index: number;
delta: {
role: string;
content?: string;
tool_calls?: {
index: number;
id: string;
type: string;
function: {
name?: string;
arguments: string;
};
};
};
message?: { role: string; content: string };
logprobs?: Record<string, unknown>;
finish_reason?: string;
}

/**
Expand Down
18 changes: 16 additions & 2 deletions packages/inference/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,29 @@ export interface BaseArgs {
*/
accessToken?: string;
/**
* The model to use. Can be a full URL for a dedicated inference endpoint.
* The model to use.
*
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
*
* /!\ Legacy behavior allows this to be an URL, but this is deprecated and will be removed in the future.
* Use the `endpointUrl` parameter instead.
*/
model?: string;

/**
* The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.
*
* If specified, will use this URL instead of the default one.
*/
endpointUrl?: string;
}

export type RequestArgs = BaseArgs &
({ data: Blob | ArrayBuffer } | { inputs: unknown }) & {
(
| { data: Blob | ArrayBuffer }
| { inputs: unknown }
| { messages?: Array<{ role: "user" | "assistant"; content: string }> }
) & {
parameters?: Record<string, unknown>;
accessToken?: string;
};
8 changes: 8 additions & 0 deletions packages/inference/src/utils/isEmpty.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
export function isObjectEmpty(object: object): boolean {
for (const prop in object) {
if (Object.prototype.hasOwnProperty.call(object, prop)) {
return false;
}
}
return true;
}
Loading
Loading