Skip to content

Commit

Permalink
add base model option to custom model (#276)
Browse files Browse the repository at this point in the history
* add base model option
* expose basemodel
* update async client
* move defaults to backend

---------

Co-authored-by: Sander Land <[email protected]>
  • Loading branch information
alex-matton and sanderland authored Aug 14, 2023
1 parent 0bacac9 commit f60bd70
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
9 changes: 8 additions & 1 deletion cohere/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1082,6 +1082,7 @@ def create_custom_model(
name: str,
model_type: CUSTOM_MODEL_TYPE,
dataset: CustomModelDataset,
base_model: Optional[str] = None,
hyperparameters: Optional[HyperParametersInput] = None,
) -> CustomModel:
"""Create a new custom model
Expand All @@ -1090,6 +1091,11 @@ def create_custom_model(
name (str): name of your custom model, has to be unique across your organization
model_type (GENERATIVE, CLASSIFY, RERANK): type of custom model
dataset (InMemoryDataset, CsvDataset, JsonlDataset, TextDataset): A dataset for your training. Consists of a train and optional eval file.
base_model (str): base model to use for your custom model.
For generative and classify models, `base_model` has to be None (no option available for now)
For rerank models, you can choose between `english` and `multilingual`. Defaults to `english` if not specified.
The English model is better for English, while the multilingual model should be picked if a non-negligible part of queries/documents
will be in other languages
hyperparameters (HyperParametersInput): adjust hyperparameters for your custom model. Only for generative custom models.
Returns:
str: the id of the custom model that was created
Expand All @@ -1112,12 +1118,13 @@ def create_custom_model(
"""
internal_custom_model_type = CUSTOM_MODEL_PRODUCT_MAPPING[model_type]

json = {
"name": name,
"settings": {
"trainFiles": [],
"evalFiles": [],
"baseModel": "medium",
"baseModel": base_model,
"finetuneType": internal_custom_model_type,
},
}
Expand Down
9 changes: 8 additions & 1 deletion cohere/client_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -826,6 +826,7 @@ async def create_custom_model(
name: str,
model_type: CUSTOM_MODEL_TYPE,
dataset: CustomModelDataset,
base_model: Optional[str] = None,
hyperparameters: Optional[HyperParametersInput] = None,
) -> AsyncCustomModel:
"""Create a new custom model
Expand All @@ -834,6 +835,11 @@ async def create_custom_model(
name (str): name of your custom model, has to be unique across your organization
model_type (GENERATIVE, CLASSIFY, RERANK): type of custom model
dataset (InMemoryDataset, CsvDataset, JsonlDataset, TextDataset): A dataset for your training. Consists of a train and optional eval file.
base_model (str): base model to use for your custom model.
For generative and classify models, `base_model` has to be None (no option available for now)
For rerank models, you can choose between `english` and `multilingual`. Defaults to `english` if not specified.
The English model is better for English, while the multilingual model should be picked if a non-negligible part of queries/documents
will be in other languages
hyperparameters (HyperParametersInput): adjust hyperparameters for your custom model. Only for generative custom models.
Returns:
str: the id of the custom model that was created
Expand All @@ -856,12 +862,13 @@ async def create_custom_model(
"""
internal_custom_model_type = CUSTOM_MODEL_PRODUCT_MAPPING[model_type]

json = {
"name": name,
"settings": {
"trainFiles": [],
"evalFiles": [],
"baseModel": "medium",
"baseModel": base_model,
"finetuneType": internal_custom_model_type,
},
}
Expand Down
3 changes: 3 additions & 0 deletions cohere/responses/custom_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def __init__(
model_type: CUSTOM_MODEL_TYPE,
created_at: datetime,
completed_at: Optional[datetime],
base_model: Optional[str] = None,
model_id: Optional[str] = None,
hyperparameters: Optional[HyperParameters] = None,
) -> None:
Expand All @@ -95,6 +96,7 @@ def __init__(
self.model_type = model_type
self.created_at = created_at
self.completed_at = completed_at
self.base_model = base_model
self.model_id = model_id
self.hyperparameters = hyperparameters
self._wait_fn = wait_fn
Expand All @@ -109,6 +111,7 @@ def from_dict(cls, data: Dict[str, Any], wait_fn) -> "BaseCustomModel":
model_type=REVERSE_CUSTOM_MODEL_PRODUCT_MAPPING[data["settings"]["finetuneType"]],
created_at=_parse_date(data["created_at"]),
completed_at=_parse_date(data["completed_at"]) if "completed_at" in data else None,
base_model=data["settings"]["baseModel"],
model_id=data["model"]["route"] if "model" in data else None,
hyperparameters=HyperParameters.from_response(data["settings"]["hyperparameters"])
if data["settings"]["hyperparameters"]
Expand Down

0 comments on commit f60bd70

Please sign in to comment.