Skip to content

Commit e67b32a

Browse files
committed
Support passing AbortSignal into paginate()
1 parent f32e187 commit e67b32a

File tree

3 files changed

+97
-5
lines changed

3 files changed

+97
-5
lines changed

index.d.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,10 @@ declare module "replicate" {
187187
}
188188
): Promise<Response>;
189189

190-
paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[T]>;
190+
paginate<T>(
191+
endpoint: () => Promise<Page<T>>,
192+
options?: { signal?: AbortSignal }
193+
): AsyncGenerator<T[]>;
191194

192195
wait(
193196
prediction: Prediction,

index.js

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,15 +356,20 @@ class Replicate {
356356
* console.log(page);
357357
* }
358358
* @param {Function} endpoint - Function that returns a promise for the next page of results
359+
* @param {object} [options]
360+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request.
359361
* @yields {object[]} Each page of results
360362
*/
361-
async *paginate(endpoint) {
363+
async *paginate(endpoint, options = {}) {
362364
const response = await endpoint();
363365
yield response.results;
364-
if (response.next) {
366+
if (response.next && !(options.signal && options.signal.aborted)) {
365367
const nextPage = () =>
366-
this.request(response.next, { method: "GET" }).then((r) => r.json());
367-
yield* this.paginate(nextPage);
368+
this.request(response.next, {
369+
method: "GET",
370+
signal: options.signal,
371+
}).then((r) => r.json());
372+
yield* this.paginate(nextPage, options);
368373
}
369374
}
370375

index.test.ts

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,90 @@ describe("Replicate client", () => {
9999
});
100100
});
101101

102+
describe("paginate", () => {
103+
test("pages through results", async () => {
104+
nock(BASE_URL)
105+
.get("/collections")
106+
.reply(200, {
107+
results: [
108+
{
109+
name: "Super resolution",
110+
slug: "super-resolution",
111+
description:
112+
"Upscaling models that create high-quality images from low-quality images.",
113+
},
114+
],
115+
next: `${BASE_URL}/collections?page=2`,
116+
previous: null,
117+
});
118+
nock(BASE_URL)
119+
.get("/collections?page=2")
120+
.reply(200, {
121+
results: [
122+
{
123+
name: "Image classification",
124+
slug: "image-classification",
125+
description: "Models that classify images.",
126+
},
127+
],
128+
next: null,
129+
previous: null,
130+
});
131+
132+
const iterator = client.paginate(client.collections.list);
133+
134+
const firstPage = (await iterator.next()).value;
135+
expect(firstPage.length).toBe(1);
136+
137+
const secondPage = (await iterator.next()).value;
138+
expect(secondPage.length).toBe(1);
139+
});
140+
141+
test("accepts an abort signal", async () => {
142+
nock(BASE_URL)
143+
.get("/collections")
144+
.reply(200, {
145+
results: [
146+
{
147+
name: "Super resolution",
148+
slug: "super-resolution",
149+
description:
150+
"Upscaling models that create high-quality images from low-quality images.",
151+
},
152+
],
153+
next: `${BASE_URL}/collections?page=2`,
154+
previous: null,
155+
});
156+
nock(BASE_URL)
157+
.get("/collections?page=2")
158+
.reply(200, {
159+
results: [
160+
{
161+
name: "Image classification",
162+
slug: "image-classification",
163+
description: "Models that classify images.",
164+
},
165+
],
166+
next: null,
167+
previous: null,
168+
});
169+
170+
const controller = new AbortController();
171+
const iterator = client.paginate(client.collections.list, {
172+
signal: controller.signal,
173+
});
174+
175+
const firstIteration = await iterator.next();
176+
expect(firstIteration.value.length).toBe(1);
177+
178+
controller.abort();
179+
180+
const secondIteration = await iterator.next();
181+
expect(secondIteration.value).toBeUndefined();
182+
expect(secondIteration.done).toBe(true);
183+
});
184+
});
185+
102186
describe("account.get", () => {
103187
test("Calls the correct API route", async () => {
104188
nock(BASE_URL).get("/account").reply(200, {

0 commit comments

Comments
 (0)