Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 95 additions & 9 deletions docs/my-website/docs/providers/cohere.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,51 @@ os.environ["COHERE_API_KEY"] = ""

### LiteLLM Python SDK

#### Cohere v2 API (Default)

```python showLineNumbers
from litellm import completion

## set ENV variables
os.environ["COHERE_API_KEY"] = "cohere key"

# cohere call
# cohere v2 call
response = completion(
model="cohere_chat/command-a-03-2025",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```

#### Cohere v1 API

To use the Cohere v1 chat API (`/v1/chat`), prefix your model name with `cohere_chat/v1/`:

```python showLineNumbers
from litellm import completion

## set ENV variables
os.environ["COHERE_API_KEY"] = "cohere key"

# cohere v1 call
response = completion(
model="command-r",
model="cohere_chat/v1/command-a-03-2025",
messages = [{ "content": "Hello, how are you?","role": "user"}]
)
```

#### Streaming

**Cohere v2 Streaming:**

```python showLineNumbers
from litellm import completion

## set ENV variables
os.environ["COHERE_API_KEY"] = "cohere key"

# cohere call
# cohere v2 streaming
response = completion(
model="command-r",
model="cohere_chat/command-a-03-2025",
messages = [{ "content": "Hello, how are you?","role": "user"}],
stream=True
)
Expand All @@ -48,6 +69,25 @@ for chunk in response:
```


**Cohere v1 Streaming:**

```python showLineNumbers
from litellm import completion

## set ENV variables
os.environ["COHERE_API_KEY"] = "cohere key"

# cohere v1 streaming
response = completion(
model="cohere_chat/v1/command-a-03-2025",
messages = [{ "content": "Hello, how are you?","role": "user"}],
stream=True
)

for chunk in response:
print(chunk)
```


## Usage with LiteLLM Proxy

Expand All @@ -63,11 +103,21 @@ export COHERE_API_KEY="your-api-key"

Define the cohere models you want to use in the config.yaml

**For Cohere v1 models:**
```yaml showLineNumbers
model_list:
- model_name: command-a-03-2025
litellm_params:
model: command-a-03-2025
model: cohere_chat/v1/command-a-03-2025
api_key: "os.environ/COHERE_API_KEY"
```
**For Cohere v2 models:**
```yaml showLineNumbers
model_list:
- model_name: command-a-03-2025-v2
litellm_params:
model: cohere_chat/command-a-03-2025
api_key: "os.environ/COHERE_API_KEY"
```
Expand All @@ -78,9 +128,8 @@ litellm --config /path/to/config.yaml

### 3. Test it


<Tabs>
<TabItem value="Curl" label="Curl Request">
<TabItem value="v1-curl" label="Cohere v1 - Curl Request">

```shell showLineNumbers
curl --location 'http://0.0.0.0:4000/chat/completions' \
Expand All @@ -98,7 +147,25 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \
'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">
<TabItem value="v2-curl" label="Cohere v2 - Curl Request">

```shell showLineNumbers
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Content-Type: application/json' \
--header 'Authorization: Bearer <your-litellm-api-key>' \
--data ' {
"model": "command-a-03-2025-v2",
"messages": [
{
"role": "user",
"content": "what llm are you"
}
]
}
'
```
</TabItem>
<TabItem value="v1-openai" label="Cohere v1 - OpenAI SDK">

```python showLineNumbers
import openai
Expand All @@ -107,7 +174,7 @@ client = openai.OpenAI(
base_url="http://0.0.0.0:4000"
)

# request sent to model set on litellm proxy
# request sent to cohere v1 model
response = client.chat.completions.create(model="command-a-03-2025", messages = [
{
"role": "user",
Expand All @@ -116,7 +183,26 @@ response = client.chat.completions.create(model="command-a-03-2025", messages =
])

print(response)
```
</TabItem>
<TabItem value="v2-openai" label="Cohere v2 - OpenAI SDK">

```python showLineNumbers
import openai
client = openai.OpenAI(
api_key="anything",
base_url="http://0.0.0.0:4000"
)

# request sent to cohere v2 model
response = client.chat.completions.create(model="command-a-03-2025-v2", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
}
])

print(response)
```
</TabItem>
</Tabs>
Expand Down
1 change: 1 addition & 0 deletions litellm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1174,6 +1174,7 @@ def add_known_models():
AmazonTitanV2Config,
)
from .llms.cohere.chat.transformation import CohereChatConfig
from .llms.cohere.chat.v2_transformation import CohereV2ChatConfig
from .llms.bedrock.embed.cohere_transformation import BedrockCohereEmbeddingConfig
from .llms.bedrock.embed.twelvelabs_marengo_transformation import TwelveLabsMarengoEmbeddingConfig
from .llms.openai.openai import OpenAIConfig, MistralEmbeddingConfig
Expand Down
134 changes: 28 additions & 106 deletions litellm/llms/cohere/chat/v2_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
import httpx

import litellm
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.cohere import CohereV2ChatResponse
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolCallChunk
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.types.utils import ModelResponse, Usage

from ..common_utils import CohereError
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
from ..common_utils import CohereV2ModelResponseIterator
from ..common_utils import validate_environment as cohere_validate_environment

if TYPE_CHECKING:
Expand All @@ -22,7 +22,7 @@
LiteLLMLoggingObj = Any


class CohereV2ChatConfig(BaseConfig):
class CohereV2ChatConfig(OpenAIGPTConfig):
"""
Configuration class for Cohere's API interface.

Expand Down Expand Up @@ -164,32 +164,12 @@ def transform_request(
litellm_params: dict,
headers: dict,
) -> dict:
## Load Config
for k, v in litellm.CohereChatConfig.get_config().items():
if (
k not in optional_params
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v

most_recent_message, chat_history = cohere_messages_pt_v2(
messages=messages, model=model, llm_provider="cohere_chat"
)

## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
optional_params["tools"] = cohere_tools
if isinstance(most_recent_message, dict):
optional_params["tool_results"] = [most_recent_message]
elif isinstance(most_recent_message, str):
optional_params["message"] = most_recent_message

## check if chat history message is 'user' and 'tool_results' is given -> force_single_step=True, else cohere api fails
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
optional_params["force_single_step"] = True

return optional_params
"""
Cohere v2 chat api is in openai format, so we can use the openai transform request function to transform the request.
"""
data = super().transform_request(model, messages, optional_params, litellm_params, headers)

return data

def transform_response(
self,
Expand Down Expand Up @@ -263,93 +243,35 @@ def transform_response(
setattr(model_response, "usage", usage)
return model_response

def _construct_cohere_tool(
self,
tools: Optional[list] = None,
):
if tools is None:
tools = []
cohere_tools = []
for tool in tools:
cohere_tool = self._translate_openai_tool_to_cohere(tool)
cohere_tools.append(cohere_tool)
return cohere_tools

def _translate_openai_tool_to_cohere(
self,
openai_tool: dict,
):
# cohere tools look like this
"""
{
"name": "query_daily_sales_report",
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
"parameter_definitions": {
"day": {
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
"type": "str",
"required": True
}
}
}
"""

# OpenAI tools look like this
"""
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather in a given location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA",
},
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
},
"required": ["location"],
},
},
}
"""
cohere_tool = {
"name": openai_tool["function"]["name"],
"description": openai_tool["function"]["description"],
"parameter_definitions": {},
}

for param_name, param_def in openai_tool["function"]["parameters"][
"properties"
].items():
required_params = (
openai_tool.get("function", {})
.get("parameters", {})
.get("required", [])
)
cohere_param_def = {
"description": param_def.get("description", ""),
"type": param_def.get("type", ""),
"required": param_name in required_params,
}
cohere_tool["parameter_definitions"][param_name] = cohere_param_def

return cohere_tool

def get_model_response_iterator(
self,
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
sync_stream: bool,
json_mode: Optional[bool] = False,
):
return CohereModelResponseIterator(
return CohereV2ModelResponseIterator(
streaming_response=streaming_response,
sync_stream=sync_stream,
json_mode=json_mode,
)

def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for Cohere v2 chat completion.
The api_base should already include the full path.
"""
if api_base is None:
raise ValueError("api_base is required")
return api_base

def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
Expand Down
Loading
Loading