diff --git a/cohere/client.py b/cohere/client.py index ac8452ad4..07d0ae319 100644 --- a/cohere/client.py +++ b/cohere/client.py @@ -1009,25 +1009,12 @@ def create_custom_model( """ internal_custom_model_type = CUSTOM_MODEL_PRODUCT_MAPPING[model_type] - # Figuring out base model - if internal_custom_model_type in ["GENERATIVE", "CLASSIFICATION"]: - assert base_model is None, "base_model has to be None for generative and classification models" - internal_base_model = "medium" - elif internal_custom_model_type == "RERANK": - internal_base_model = base_model or "english" - assert internal_base_model in [ - "english", - "multilingual", - ], "base_model has to be `english` or `multilingual`" - else: - raise ValueError(f"Unsupported model_type: {internal_custom_model_type}") - json = { "name": name, "settings": { "trainFiles": [], "evalFiles": [], - "baseModel": internal_base_model, + "baseModel": base_model, "finetuneType": internal_custom_model_type, }, } diff --git a/cohere/client_async.py b/cohere/client_async.py index 4704ba199..f6ca42827 100644 --- a/cohere/client_async.py +++ b/cohere/client_async.py @@ -733,25 +733,12 @@ async def create_custom_model( """ internal_custom_model_type = CUSTOM_MODEL_PRODUCT_MAPPING[model_type] - # Figuring out base model - if internal_custom_model_type in ["GENERATIVE", "CLASSIFICATION"]: - assert base_model is None, "base_model has to be None for generative and classification models" - internal_base_model = "medium" - elif internal_custom_model_type == "RERANK": - internal_base_model = base_model or "english" - assert internal_base_model in [ - "english", - "multilingual", - ], "base_model has to be `english` or `multilingual`" - else: - raise ValueError(f"Unsupported model_type: {internal_custom_model_type}") - json = { "name": name, "settings": { "trainFiles": [], "evalFiles": [], - "baseModel": internal_base_model, + "baseModel": base_model, "finetuneType": internal_custom_model_type, }, }