Skip to content

Commit 77c7643

Browse files
authored
Use batch rows for all dataset create operations (#64)
* init * remove prints
1 parent 1048518 commit 77c7643

File tree

1 file changed

+42
-5
lines changed

1 file changed

+42
-5
lines changed

quotientai/resources/datasets.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,10 +253,26 @@ def create(
253253
dataset_response = self._client._post("/datasets", data=data)
254254
id = dataset_response["id"]
255255

256-
row_response = self._client._post(
257-
f"/datasets/{id}/dataset_rows/batch",
258-
data=rows,
259-
)
256+
# TODO: update the dataset_rows API to take in a list of rows
257+
# rather than one row at a time. This should be the expected behavior.
258+
row_responses = []
259+
self.batch_create_rows(id, rows, row_responses)
260+
dataset_rows = [
261+
DatasetRow(
262+
id=row_response["dataset_row_id"],
263+
input=row_response["input"],
264+
context=row_response["context"],
265+
expected=row_response["expected"],
266+
metadata=DatasetRowMetadata(
267+
annotation=row_response["annotation"],
268+
annotation_note=row_response["annotation_note"],
269+
),
270+
created_at=row_response["created_at"],
271+
created_by=row_response["created_by"],
272+
updated_at=row_response["updated_at"],
273+
)
274+
for row_response in row_responses
275+
]
260276

261277
dataset = Dataset(
262278
id=id,
@@ -265,7 +281,7 @@ def create(
265281
created_at=dataset_response["created_at"],
266282
updated_at=dataset_response["updated_at"],
267283
created_by=dataset_response["created_by"],
268-
rows=row_response,
284+
rows=dataset_rows,
269285
)
270286
return dataset
271287

@@ -420,3 +436,24 @@ def delete(self, dataset: Dataset, rows: Optional[List[DatasetRow]] = None) -> N
420436
)
421437

422438
return None
439+
440+
def batch_create_rows(self, dataset_id: str, rows: List[dict], row_responses: List[DatasetRow], batch_size: int = 10):
441+
"""
442+
Batch create rows for a dataset.
443+
"""
444+
# iterate over the rows in batches
445+
for i in range(0, len(rows), batch_size):
446+
batch = rows[i:i + batch_size]
447+
try:
448+
response = self._client._post(
449+
f"/datasets/{dataset_id}/dataset_rows/batch",
450+
data={"rows": batch},
451+
)
452+
row_responses.extend(response)
453+
except Exception as e:
454+
# If the batch create fails, divide batch size by two and recursively try
455+
if batch_size == 1:
456+
raise e
457+
else:
458+
self.batch_create_rows(dataset_id, batch, row_responses, batch_size // 2)
459+
return row_responses

0 commit comments

Comments
 (0)