Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wait to CustomModel #263

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Changelog

## 4.17

- [#263](https://github.com/cohere-ai/cohere-python/pull/263)
- Add wait() to custom models

## 4.16

- [#262](https://github.com/cohere-ai/cohere-python/pull/262)
Expand Down
35 changes: 31 additions & 4 deletions cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,34 @@ def create_custom_model(
json["settings"]["evalFiles"].append({"path": remote_path, **dataset.file_config()})

response = self._request(f"{cohere.CUSTOM_MODEL_URL}/CreateFinetune", method="POST", json=json)
return CustomModel.from_dict(response["finetune"])
return CustomModel.from_dict(response["finetune"], self.wait_for_custom_model)

def wait_for_custom_model(
    self,
    custom_model_id: str,
    timeout: Optional[float] = None,
    interval: float = 60,
) -> CustomModel:
    """Wait for custom model training completion.

    Delegates to ``wait_for_job``, handing it a zero-argument callable that
    fetches the custom model with ``get_custom_model(custom_model_id)``.

    Args:
        custom_model_id (str): Custom model id.
        timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time.
            Defaults to None.
        interval (float, optional): Wait poll interval in seconds. Defaults to 60.

    Raises:
        TimeoutError: wait timed out

    Returns:
        CustomModel: Custom model.
    """

    return wait_for_job(
        get_job=partial(self.get_custom_model, custom_model_id),
        timeout=timeout,
        interval=interval,
    )

def _upload_dataset(
self, content: Iterable[bytes], custom_model_name: str, file_name: str, type: INTERNAL_CUSTOM_MODEL_TYPE
Expand All @@ -1058,7 +1085,7 @@ def get_custom_model(self, custom_model_id: str) -> CustomModel:
"""
json = {"finetuneID": custom_model_id}
response = self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetune", method="POST", json=json)
return CustomModel.from_dict(response["finetune"])
return CustomModel.from_dict(response["finetune"], self.wait_for_custom_model)

def get_custom_model_by_name(self, name: str) -> CustomModel:
"""Get a custom model by name.
Expand All @@ -1070,7 +1097,7 @@ def get_custom_model_by_name(self, name: str) -> CustomModel:
"""
json = {"name": name}
response = self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetuneByName", method="POST", json=json)
return CustomModel.from_dict(response["finetune"])
return CustomModel.from_dict(response["finetune"], self.wait_for_custom_model)

def list_custom_models(
self,
Expand Down Expand Up @@ -1104,4 +1131,4 @@ def list_custom_models(
}

response = self._request(f"{cohere.CUSTOM_MODEL_URL}/ListFinetunes", method="POST", json=json)
return [CustomModel.from_dict(r) for r in response["finetunes"]]
return [CustomModel.from_dict(r, self.wait_for_custom_model) for r in response["finetunes"]]
45 changes: 36 additions & 9 deletions cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
CUSTOM_MODEL_STATUS,
CUSTOM_MODEL_TYPE,
INTERNAL_CUSTOM_MODEL_TYPE,
CustomModel,
AsyncCustomModel,
HyperParametersInput,
)
from cohere.utils import async_wait_for_job, is_api_key_valid, np_json_dumps
Expand Down Expand Up @@ -697,7 +697,7 @@ async def create_custom_model(
model_type: CUSTOM_MODEL_TYPE,
dataset: CustomModelDataset,
hyperparameters: Optional[HyperParametersInput] = None,
) -> CustomModel:
) -> AsyncCustomModel:
"""Create a new custom model

Args:
Expand Down Expand Up @@ -755,7 +755,34 @@ async def create_custom_model(
json["settings"]["evalFiles"].append({"path": remote_path, **dataset.file_config()})

response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/CreateFinetune", method="POST", json=json)
return CustomModel.from_dict(response["finetune"])
return AsyncCustomModel.from_dict(response["finetune"], self.wait_for_custom_model)

async def wait_for_custom_model(
    self,
    custom_model_id: str,
    timeout: Optional[float] = None,
    interval: float = 60,
) -> AsyncCustomModel:
    """Wait for custom model training completion.

    Awaits ``async_wait_for_job``, handing it a zero-argument callable that
    fetches the custom model with ``get_custom_model(custom_model_id)``.

    Args:
        custom_model_id (str): Custom model id.
        timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time.
            Defaults to None.
        interval (float, optional): Wait poll interval in seconds. Defaults to 60.

    Raises:
        TimeoutError: wait timed out

    Returns:
        AsyncCustomModel: Custom model.
    """

    return await async_wait_for_job(
        get_job=partial(self.get_custom_model, custom_model_id),
        timeout=timeout,
        interval=interval,
    )

async def _upload_dataset(
self, content: Iterable[bytes], custom_model_name: str, file_name: str, type: INTERNAL_CUSTOM_MODEL_TYPE
Expand All @@ -773,7 +800,7 @@ async def _create_signed_url(
json = {"finetuneName": custom_model_name, "fileName": file_name, "finetuneType": type}
return await self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetuneUploadSignedURL", method="POST", json=json)

async def get_custom_model(self, custom_model_id: str) -> CustomModel:
async def get_custom_model(self, custom_model_id: str) -> AsyncCustomModel:
"""Get a custom model by id.

Args:
Expand All @@ -783,9 +810,9 @@ async def get_custom_model(self, custom_model_id: str) -> CustomModel:
"""
json = {"finetuneID": custom_model_id}
response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetune", method="POST", json=json)
return CustomModel.from_dict(response["finetune"])
return AsyncCustomModel.from_dict(response["finetune"], self.wait_for_custom_model)

async def get_custom_model_by_name(self, name: str) -> CustomModel:
async def get_custom_model_by_name(self, name: str) -> AsyncCustomModel:
"""Get a custom model by name.

Args:
Expand All @@ -795,15 +822,15 @@ async def get_custom_model_by_name(self, name: str) -> CustomModel:
"""
json = {"name": name}
response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetuneByName", method="POST", json=json)
return CustomModel.from_dict(response["finetune"])
return AsyncCustomModel.from_dict(response["finetune"], self.wait_for_custom_model)

async def list_custom_models(
self,
statuses: Optional[List[CUSTOM_MODEL_STATUS]] = None,
before: Optional[datetime] = None,
after: Optional[datetime] = None,
order_by: Optional[Literal["asc", "desc"]] = None,
) -> List[CustomModel]:
) -> List[AsyncCustomModel]:
"""List custom models of your organization.

Args:
Expand All @@ -829,7 +856,7 @@ async def list_custom_models(
}

response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/ListFinetunes", method="POST", json=json)
return [CustomModel.from_dict(r) for r in response["finetunes"]]
return [AsyncCustomModel.from_dict(r, self.wait_for_custom_model) for r in response["finetunes"]]


class AIOHTTPBackend:
Expand Down
58 changes: 56 additions & 2 deletions cohere/responses/custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from datetime import datetime
from typing import Any, Dict, Optional

from cohere.utils import JobWithStatus

try:
from typing import Literal, TypedDict
except ImportError:
Expand Down Expand Up @@ -73,9 +75,10 @@ class HyperParametersInput(TypedDict):
learning_rate: float


class CustomModel(CohereObject):
class BaseCustomModel(CohereObject, JobWithStatus):
def __init__(
self,
wait_fn,
id: str,
name: str,
status: CUSTOM_MODEL_STATUS,
Expand All @@ -94,10 +97,12 @@ def __init__(
self.completed_at = completed_at
self.model_id = model_id
self.hyperparameters = hyperparameters
self._wait_fn = wait_fn

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "CustomModel":
def from_dict(cls, data: Dict[str, Any], wait_fn) -> "BaseCustomModel":
return cls(
wait_fn=wait_fn,
id=data["id"],
name=data["name"],
status=data["status"],
Expand All @@ -110,6 +115,55 @@ def from_dict(cls, data: Dict[str, Any]) -> "CustomModel":
else None,
)

def has_terminal_status(self) -> bool:
    """Report whether this custom model's status equals ``"READY"``.

    Returns:
        bool: True when ``self.status`` is ``"READY"``, False otherwise.
    """
    # NOTE(review): "READY" is the only status treated as terminal here; if
    # failure-type statuses exist they are not considered terminal - confirm
    # against the CUSTOM_MODEL_STATUS definition.
    ready_status = "READY"
    return self.status == ready_status


class CustomModel(BaseCustomModel):
    """Custom model whose ``wait`` calls the stored wait function synchronously."""

    def wait(
        self,
        timeout: Optional[float] = None,
        interval: float = 60,
    ) -> "CustomModel":
        """Block until the wait function stored on this model returns.

        Args:
            timeout (Optional[float], optional): Wait timeout in seconds; None means no limit on the wait time.
                Defaults to None.
            interval (float, optional): Wait poll interval in seconds. Defaults to 60.

        Raises:
            TimeoutError: wait timed out

        Returns:
            CustomModel: custom model.
        """

        # The wait function was supplied at construction time; forward this
        # model's id together with the caller's timing options.
        wait_fn = self._wait_fn
        result = wait_fn(custom_model_id=self.id, timeout=timeout, interval=interval)
        return result


class AsyncCustomModel(BaseCustomModel):
    """Custom model whose ``wait`` awaits the stored (async) wait function."""

    async def wait(
        self,
        timeout: Optional[float] = None,
        interval: float = 60,
    ) -> "AsyncCustomModel":
        """Wait for custom model job completion.

        Args:
            timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time.
                Defaults to None.
            interval (float, optional): Wait poll interval in seconds. Defaults to 60.

        Raises:
            TimeoutError: wait timed out

        Returns:
            AsyncCustomModel: custom model.
        """

        # The wait function was supplied at construction time; it is awaited
        # with this model's id and the caller's timing options.
        return await self._wait_fn(custom_model_id=self.id, timeout=timeout, interval=interval)


def _parse_date(datetime_string: str) -> datetime:
    """Parse a timestamp string formatted as ``%Y-%m-%dT%H:%M:%S.%f%z``.

    Args:
        datetime_string (str): Timestamp text including fractional seconds and a UTC offset.

    Returns:
        datetime: The parsed datetime.
    """
    timestamp_format = "%Y-%m-%dT%H:%M:%S.%f%z"
    parsed = datetime.strptime(datetime_string, timestamp_format)
    return parsed
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cohere"
version = "4.16.0"
version = "4.17.0"
description = ""
authors = ["Cohere"]
readme = "README.md"
Expand Down
Loading