Skip to content

Commit

Permalink
Fix batching logic
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Sep 13, 2024
1 parent d290fa1 commit 798a43f
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,8 @@ def __exit__(self, exc_type, exc_value, traceback):
def embed(
self,
*,
texts: typing.Sequence[str],
texts: typing.Optional[typing.Sequence[str]] = OMIT,
images: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
Expand All @@ -190,28 +191,36 @@ def embed(
return BaseCohere.embed(
self,
texts=texts,
images=images,
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
)

texts = texts or []
texts_batches = [texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]

images = images or []
images_batches = [images[i : i + embed_batch_size] for i in range(0, len(images), embed_batch_size)]

zipped = zip(texts_batches, images_batches)

responses = [
response
for response in self._executor.map(
lambda text_batch: BaseCohere.embed(
lambda batch: BaseCohere.embed(
self,
texts=text_batch,
texts=batch[0],
images=batch[1],
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
),
texts_batches,
zipped,
)
]

Expand Down Expand Up @@ -366,7 +375,8 @@ async def __aexit__(self, exc_type, exc_value, traceback):
async def embed(
self,
*,
texts: typing.Sequence[str],
texts: typing.Optional[typing.Sequence[str]] = OMIT,
images: typing.Optional[typing.Sequence[str]] = OMIT,
model: typing.Optional[str] = OMIT,
input_type: typing.Optional[EmbedInputType] = OMIT,
embedding_types: typing.Optional[typing.Sequence[EmbeddingType]] = OMIT,
Expand All @@ -385,22 +395,30 @@ async def embed(
request_options=request_options,
)


texts = texts or []
texts_batches = [texts[i : i + embed_batch_size] for i in range(0, len(texts), embed_batch_size)]

images = images or []
images_batches = [images[i : i + embed_batch_size] for i in range(0, len(images), embed_batch_size)]

zipped = zip(texts_batches, images_batches)

responses = typing.cast(
typing.List[EmbedResponse],
await asyncio.gather(
*[
AsyncBaseCohere.embed(
self,
texts=text_batch,
texts=batch[0],
images=batch[1],
model=model,
input_type=input_type,
embedding_types=embedding_types,
truncate=truncate,
request_options=request_options,
)
for text_batch in texts_batches
for batch in zipped
]
),
)
Expand Down

0 comments on commit 798a43f

Please sign in to comment.