Skip to content

Commit

Permalink
checkRepoAccess added (#947)
Browse files Browse the repository at this point in the history
More helpers for
#945

Introduces `checkRepoAccess` to check if user has read access to repo
  • Loading branch information
coyotte508 authored Oct 5, 2024
1 parent 3f99361 commit 186ab73
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 26 deletions.
23 changes: 12 additions & 11 deletions packages/hub/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,23 @@ For some of the calls, you need to create an account and generate an [access tok
Learn how to find free models using the hub package in this [interactive tutorial](https://scrimba.com/scrim/c7BbVPcd?pl=pkVnrP7uP).

```ts
import { createRepo, uploadFiles, uploadFilesWithProgress, deleteFile, deleteRepo, listFiles, whoAmI, modelInfo, listModels } from "@huggingface/hub";
import * as hub from "@huggingface/hub";
import type { RepoDesignation } from "@huggingface/hub";

const repo: RepoDesignation = { type: "model", name: "myname/some-model" };

const {name: username} = await whoAmI({accessToken: "hf_..."});
const {name: username} = await hub.whoAmI({accessToken: "hf_..."});

for await (const model of listModels({search: {owner: username}, accessToken: "hf_..."})) {
for await (const model of hub.listModels({search: {owner: username}, accessToken: "hf_..."})) {
console.log("My model:", model);
}

const specificModel = await modelInfo({name: "openai-community/gpt2"});
const specificModel = await hub.modelInfo({name: "openai-community/gpt2"});
await hub.checkRepoAccess({repo, accessToken: "hf_..."});

await createRepo({ repo, accessToken: "hf_...", license: "mit" });
await hub.createRepo({ repo, accessToken: "hf_...", license: "mit" });

await uploadFiles({
await hub.uploadFiles({
repo,
accessToken: "hf_...",
files: [
Expand All @@ -69,7 +70,7 @@ await uploadFiles({

// or

for await (const progressEvent of await uploadFilesWithProgress({
for await (const progressEvent of await hub.uploadFilesWithProgress({
repo,
accessToken: "hf_...",
files: [
Expand All @@ -79,15 +80,15 @@ for await (const progressEvent of await uploadFilesWithProgress({
console.log(progressEvent);
}

await deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"});
await hub.deleteFile({repo, accessToken: "hf_...", path: "myfile.bin"});

await (await downloadFile({ repo, path: "README.md" })).text();
await (await hub.downloadFile({ repo, path: "README.md" })).text();

for await (const fileInfo of listFiles({repo})) {
for await (const fileInfo of hub.listFiles({repo})) {
console.log(fileInfo);
}

await deleteRepo({ repo, accessToken: "hf_..." });
await hub.deleteRepo({ repo, accessToken: "hf_..." });
```

## OAuth Login
Expand Down
34 changes: 34 additions & 0 deletions packages/hub/src/lib/check-repo-access.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import { assert, describe, expect, it } from "vitest";
import { checkRepoAccess } from "./check-repo-access";
import { HubApiError } from "../error";
import { TEST_ACCESS_TOKEN, TEST_HUB_URL } from "../test/consts";

describe("checkRepoAccess", () => {
it("should throw 401 when accessing unexisting repo unauthenticated", async () => {
try {
await checkRepoAccess({ repo: { name: "i--d/dont", type: "model" } });
assert(false, "should have thrown");
} catch (err) {
expect(err).toBeInstanceOf(HubApiError);
expect((err as HubApiError).statusCode).toBe(401);
}
});

it("should throw 404 when accessing unexisting repo authenticated", async () => {
try {
await checkRepoAccess({
repo: { name: "i--d/dont", type: "model" },
hubUrl: TEST_HUB_URL,
accessToken: TEST_ACCESS_TOKEN,
});
assert(false, "should have thrown");
} catch (err) {
expect(err).toBeInstanceOf(HubApiError);
expect((err as HubApiError).statusCode).toBe(404);
}
});

it("should not throw when accessing public repo", async () => {
await checkRepoAccess({ repo: { name: "openai-community/gpt2", type: "model" } });
});
});
32 changes: 32 additions & 0 deletions packages/hub/src/lib/check-repo-access.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import { HUB_URL } from "../consts";
// eslint-disable-next-line @typescript-eslint/no-unused-vars
import { createApiError, type HubApiError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

/**
* Check if we have read access to a repository.
*
* Throw a {@link HubApiError} error if we don't have access. HubApiError.statusCode will be 401, 403 or 404.
*/
export async function checkRepoAccess(
params: {
repo: RepoDesignation;
hubUrl?: string;
fetch?: typeof fetch;
} & Partial<CredentialsParams>
): Promise<void> {
const accessToken = params && checkCredentials(params);
const repoId = toRepoId(params.repo);

const response = await (params.fetch || fetch)(`${params?.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}`, {
headers: {
...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
},
});

if (!response.ok) {
throw await createApiError(response);
}
}
6 changes: 1 addition & 5 deletions packages/hub/src/lib/dataset-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ export async function datasetInfo<
name: string;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
Expand All @@ -41,7 +37,7 @@ export async function datasetInfo<
);

if (!response.ok) {
createApiError(response);
throw await createApiError(response);
}

const data = await response.json();
Expand Down
1 change: 1 addition & 0 deletions packages/hub/src/lib/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
export * from "./cache-management";
export * from "./check-repo-access";
export * from "./commit";
export * from "./count-commits";
export * from "./create-repo";
Expand Down
6 changes: 1 addition & 5 deletions packages/hub/src/lib/model-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,6 @@ export async function modelInfo<
name: string;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
Expand All @@ -41,7 +37,7 @@ export async function modelInfo<
);

if (!response.ok) {
createApiError(response);
throw await createApiError(response);
}

const data = await response.json();
Expand Down
6 changes: 1 addition & 5 deletions packages/hub/src/lib/space-info.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ export async function spaceInfo<
name: string;
hubUrl?: string;
additionalFields?: T[];
/**
* Set to limit the number of models returned.
*/
limit?: number;
/**
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
*/
Expand All @@ -42,7 +38,7 @@ export async function spaceInfo<
);

if (!response.ok) {
createApiError(response);
throw await createApiError(response);
}

const data = await response.json();
Expand Down

0 comments on commit 186ab73

Please sign in to comment.