From 6a269e1b12bedde0de27f6484616e657133fb41b Mon Sep 17 00:00:00 2001 From: Raphael Cristal Date: Thu, 20 Jul 2023 14:16:47 +0200 Subject: [PATCH 1/2] add `wait` to CustomModel --- cohere/client.py | 35 ++++++++++++++++--- cohere/client_async.py | 45 ++++++++++++++++++++----- cohere/responses/custom_model.py | 58 ++++++++++++++++++++++++++++++-- 3 files changed, 123 insertions(+), 15 deletions(-) diff --git a/cohere/client.py b/cohere/client.py index a8c78fd35..c7e8174c0 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -1031,7 +1031,34 @@ def create_custom_model( json["settings"]["evalFiles"].append({"path": remote_path, **dataset.file_config()}) response = self._request(f"{cohere.CUSTOM_MODEL_URL}/CreateFinetune", method="POST", json=json) - return CustomModel.from_dict(response["finetune"]) + return CustomModel.from_dict(response["finetune"], self.wait_for_custom_model) + + def wait_for_custom_model( + self, + custom_model_id: str, + timeout: Optional[float] = None, + interval: float = 60, + ) -> CustomModel: + """Wait for custom model training completion. + + Args: + custom_model_id (str): Custom model id. + timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time. + Defaults to None. + interval (float, optional): Wait poll interval in seconds. Defaults to 10. + + Raises: + TimeoutError: wait timed out + + Returns: + BulkEmbedJob: Custom model. + """ + + return wait_for_job( + get_job=partial(self.get_custom_model, custom_model_id), + timeout=timeout, + interval=interval, + ) def _upload_dataset( self, content: Iterable[bytes], custom_model_name: str, file_name: str, type: INTERNAL_CUSTOM_MODEL_TYPE @@ -1058,7 +1085,7 @@ def get_custom_model(self, custom_model_id: str) -> CustomModel: """ json = {"finetuneID": custom_model_id} response = self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetune", method="POST", json=json) - return CustomModel.from_dict(response["finetune"]) + return CustomModel.from_dict(response["finetune"], self.wait_for_custom_model) def get_custom_model_by_name(self, name: str) -> CustomModel: """Get a custom model by name. @@ -1070,7 +1097,7 @@ def get_custom_model_by_name(self, name: str) -> CustomModel: """ json = {"name": name} response = self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetuneByName", method="POST", json=json) - return CustomModel.from_dict(response["finetune"]) + return CustomModel.from_dict(response["finetune"], self.wait_for_custom_model) def list_custom_models( self, @@ -1104,4 +1131,4 @@ def list_custom_models( } response = self._request(f"{cohere.CUSTOM_MODEL_URL}/ListFinetunes", method="POST", json=json) - return [CustomModel.from_dict(r) for r in response["finetunes"]] + return [CustomModel.from_dict(r, self.wait_for_custom_model) for r in response["finetunes"]] diff --git a/cohere/client_async.py b/cohere/client_async.py index 16e99259c..6f52aff1d 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -50,7 +50,7 @@ CUSTOM_MODEL_STATUS, CUSTOM_MODEL_TYPE, INTERNAL_CUSTOM_MODEL_TYPE, - CustomModel, + AsyncCustomModel, HyperParametersInput, ) from cohere.utils import async_wait_for_job, is_api_key_valid, np_json_dumps @@ -697,7 +697,7 @@ async def create_custom_model( model_type: CUSTOM_MODEL_TYPE, dataset: CustomModelDataset, hyperparameters: Optional[HyperParametersInput] = None, - ) -> CustomModel: + ) -> AsyncCustomModel: """Create a new custom model Args: @@ -755,7 +755,34 @@ async def create_custom_model( json["settings"]["evalFiles"].append({"path": remote_path, **dataset.file_config()}) response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/CreateFinetune", method="POST", json=json) - return CustomModel.from_dict(response["finetune"]) + return AsyncCustomModel.from_dict(response["finetune"], self.wait_for_custom_model) + + async def wait_for_custom_model( + self, + custom_model_id: str, + timeout: Optional[float] = None, + interval: float = 60, + ) -> AsyncCustomModel: + """Wait for custom model training completion. + + Args: + custom_model_id (str): Custom model id. + timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time. + Defaults to None. + interval (float, optional): Wait poll interval in seconds. Defaults to 10. + + Raises: + TimeoutError: wait timed out + + Returns: + BulkEmbedJob: Custom model. + """ + + return await async_wait_for_job( + get_job=partial(self.get_custom_model, custom_model_id), + timeout=timeout, + interval=interval, + ) async def _upload_dataset( self, content: Iterable[bytes], custom_model_name: str, file_name: str, type: INTERNAL_CUSTOM_MODEL_TYPE @@ -773,7 +800,7 @@ async def _create_signed_url( json = {"finetuneName": custom_model_name, "fileName": file_name, "finetuneType": type} return await self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetuneUploadSignedURL", method="POST", json=json) - async def get_custom_model(self, custom_model_id: str) -> CustomModel: + async def get_custom_model(self, custom_model_id: str) -> AsyncCustomModel: """Get a custom model by id. Args: @@ -783,9 +810,9 @@ async def get_custom_model(self, custom_model_id: str) -> CustomModel: """ json = {"finetuneID": custom_model_id} response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetune", method="POST", json=json) - return CustomModel.from_dict(response["finetune"]) + return AsyncCustomModel.from_dict(response["finetune"], self.wait_for_custom_model) - async def get_custom_model_by_name(self, name: str) -> CustomModel: + async def get_custom_model_by_name(self, name: str) -> AsyncCustomModel: """Get a custom model by name. Args: @@ -795,7 +822,7 @@ async def get_custom_model_by_name(self, name: str) -> CustomModel: """ json = {"name": name} response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/GetFinetuneByName", method="POST", json=json) - return CustomModel.from_dict(response["finetune"]) + return AsyncCustomModel.from_dict(response["finetune"], self.wait_for_custom_model) async def list_custom_models( self, @@ -803,7 +830,7 @@ async def list_custom_models( before: Optional[datetime] = None, after: Optional[datetime] = None, order_by: Optional[Literal["asc", "desc"]] = None, - ) -> List[CustomModel]: + ) -> List[AsyncCustomModel]: """List custom models of your organization. Args: @@ -829,7 +856,7 @@ async def list_custom_models( } response = await self._request(f"{cohere.CUSTOM_MODEL_URL}/ListFinetunes", method="POST", json=json) - return [CustomModel.from_dict(r) for r in response["finetunes"]] + return [AsyncCustomModel.from_dict(r, self.wait_for_custom_model) for r in response["finetunes"]] class AIOHTTPBackend: diff --git a/cohere/responses/custom_model.py b/cohere/responses/custom_model.py index c04fb9220..6d8387b43 100644 --- a/cohere/responses/custom_model.py +++ b/cohere/responses/custom_model.py @@ -2,6 +2,8 @@ from datetime import datetime from typing import Any, Dict, Optional +from cohere.utils import JobWithStatus + try: from typing import Literal, TypedDict except ImportError: @@ -73,9 +75,10 @@ class HyperParametersInput(TypedDict): learning_rate: float -class CustomModel(CohereObject): +class BaseCustomModel(CohereObject, JobWithStatus): def __init__( self, + wait_fn, id: str, name: str, status: CUSTOM_MODEL_STATUS, @@ -94,10 +97,12 @@ def __init__( self.completed_at = completed_at self.model_id = model_id self.hyperparameters = hyperparameters + self._wait_fn = wait_fn @classmethod - def from_dict(cls, data: Dict[str, Any]) -> "CustomModel": + def from_dict(cls, data: Dict[str, Any], wait_fn) -> "BaseCustomModel": return cls( + wait_fn=wait_fn, id=data["id"], name=data["name"], status=data["status"], @@ -110,6 +115,55 @@ def from_dict(cls, data: Dict[str, Any]) -> "CustomModel": else None, ) + def has_terminal_status(self) -> bool: + return self.status == "READY" + + +class CustomModel(BaseCustomModel): + def wait( + self, + timeout: Optional[float] = None, + interval: float = 60, + ) -> "CustomModel": + """Wait for custom model job completion. + + Args: + timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time. + Defaults to None. + interval (float, optional): Wait poll interval in seconds. Defaults to 60. + + Raises: + TimeoutError: wait timed out + + Returns: + CustomModel: custom model. + """ + + return self._wait_fn(custom_model_id=self.id, timeout=timeout, interval=interval) + + +class AsyncCustomModel(BaseCustomModel): + async def wait( + self, + timeout: Optional[float] = None, + interval: float = 60, + ) -> "CustomModel": + """Wait for custom model job completion. + + Args: + timeout (Optional[float], optional): Wait timeout in seconds, if None - there is no limit to the wait time. + Defaults to None. + interval (float, optional): Wait poll interval in seconds. Defaults to 60. + + Raises: + TimeoutError: wait timed out + + Returns: + CustomModel: custom model. + """ + + return await self._wait_fn(custom_model_id=self.id, timeout=timeout, interval=interval) + def _parse_date(datetime_string: str) -> datetime: return datetime.strptime(datetime_string, "%Y-%m-%dT%H:%M:%S.%f%z") From 5b8322cec4dd6eab2aec057101ec977384fbb1ab Mon Sep 17 00:00:00 2001 From: Raphael Cristal Date: Thu, 20 Jul 2023 14:24:40 +0200 Subject: [PATCH 2/2] bump version --- CHANGELOG.md | 5 +++++ pyproject.toml | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d1cf7205..3d49fad53 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Changelog +## 4.17 + +- [#263](https://github.com/cohere-ai/cohere-python/pull/263) + - Add wait() to custom models + ## 4.16 - [#262](https://github.com/cohere-ai/cohere-python/pull/262) diff --git a/pyproject.toml b/pyproject.toml index 217b85063..2805bf409 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cohere" -version = "4.16.0" +version = "4.17.0" description = "" authors = ["Cohere"] readme = "README.md"