Skip to content

Commit

Permalink
Merge pull request #35 from dataforgoodfr/feat/cohere
Browse files Browse the repository at this point in the history
Add Cohere provider
  • Loading branch information
samuelrince authored Apr 15, 2024
2 parents 7ab0a2a + dfcfa80 commit 6376f59
Show file tree
Hide file tree
Showing 12 changed files with 1,024 additions and 18 deletions.
6 changes: 6 additions & 0 deletions ecologits/data/models.csv
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,9 @@ anthropic,claude-2.1,130,130,model_architecture_not_released,https://docs.google
anthropic,claude-2.0,130,130,model_architecture_not_released,https://docs.google.com/spreadsheets/d/1O5KVQW1Hx5ZAkcg8AIRjbQLQzx2wVaLl0SqUu-ir9Fs/edit?usp=sharing
anthropic,claude-instant-1.2,20;70,20;70,model_architecture_not_released,
huggingface_hub,HuggingFaceH4/zephyr-7b-beta,7.24,7.24,model_architecture_not_released,
cohere,command-light,6,6,model_architecture_not_released,https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
cohere,command-light-nightly,6,6,model_architecture_not_released,
cohere,command,52,52,model_architecture_not_released,https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
cohere,command-nightly,52,52,model_architecture_not_released,
cohere,command-r,35,35,model_architecture_not_released,https://huggingface.co/CohereForAI/c4ai-command-r-v01
cohere,command-r-plus,104,104,model_architecture_not_released,https://huggingface.co/CohereForAI/c4ai-command-r-plus-4bit
10 changes: 10 additions & 0 deletions ecologits/ecologits.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def init_instruments() -> None:
init_anthropic_instrumentor()
init_mistralai_instrumentor()
init_huggingface_instrumentor()
init_cohere_instrumentor()


def init_openai_instrumentor() -> None:
Expand All @@ -44,9 +45,18 @@ def init_mistralai_instrumentor() -> None:
instrumentor = MistralAIInstrumentor()
instrumentor.instrument()


def init_huggingface_instrumentor() -> None:
if importlib.util.find_spec("huggingface_hub") is not None:
from ecologits.tracers.huggingface_tracer import HuggingfaceInstrumentor

instrumentor = HuggingfaceInstrumentor()
instrumentor.instrument()


def init_cohere_instrumentor() -> None:
    """Instrument the Cohere client when the package is available.

    No-op if `cohere` is not installed, so EcoLogits can be used without
    the optional dependency.
    """
    if importlib.util.find_spec("cohere") is None:
        return
    from ecologits.tracers.cohere_tracer import CohereInstrumentor

    CohereInstrumentor().instrument()
1 change: 1 addition & 0 deletions ecologits/model_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ class Providers(Enum):
mistralai = "mistralai"
openai = "openai"
huggingface_hub = "huggingface_hub"
cohere = "cohere"


class Warnings(Enum):
Expand Down
145 changes: 145 additions & 0 deletions ecologits/tracers/cohere_tracer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import time
from typing import Any, AsyncIterator, Callable, Iterator

from wrapt import wrap_function_wrapper

from ecologits.impacts import Impacts
from ecologits.tracers.utils import compute_llm_impacts

try:
from cohere import AsyncClient, Client
from cohere.types.non_streamed_chat_response import NonStreamedChatResponse as _NonStreamedChatResponse
from cohere.types.streamed_chat_response import StreamedChatResponse
from cohere.types.streamed_chat_response import StreamedChatResponse_StreamEnd as _StreamedChatResponse_StreamEnd
except ImportError:
Client = object()
AsyncClient = object()
_NonStreamedChatResponse = object()
StreamedChatResponse = object()
_StreamedChatResponse_StreamEnd = object()


PROVIDER = "cohere"


class NonStreamedChatResponse(_NonStreamedChatResponse):
    # Drop-in replacement for cohere's NonStreamedChatResponse that carries
    # the computed environmental impacts alongside the original payload.
    impacts: Impacts

    class Config:
        # Allow non-pydantic field types — presumably Impacts is not a
        # pydantic model; verify against ecologits.impacts.
        arbitrary_types_allowed = True


class StreamedChatResponse_StreamEnd(_StreamedChatResponse_StreamEnd):  # noqa: N801
    # Drop-in replacement for cohere's stream-end event that carries the
    # computed environmental impacts alongside the original payload.
    impacts: Impacts

    class Config:
        # Allow non-pydantic field types — presumably Impacts is not a
        # pydantic model; verify against ecologits.impacts.
        arbitrary_types_allowed = True


def cohere_chat_wrapper(
    wrapped: Callable, instance: Client, args: Any, kwargs: Any  # noqa: ARG001
) -> NonStreamedChatResponse:
    """Wrap a synchronous Cohere `chat` call and attach impact estimates.

    Times the underlying request, computes environmental impacts from the
    model name and the number of generated tokens, and returns a response
    object enriched with an `impacts` field.
    """
    start = time.perf_counter()
    original = wrapped(*args, **kwargs)
    latency = time.perf_counter() - start
    # "command-r" mirrors the default used by the other wrappers when the
    # caller does not pass a model explicitly.
    model = kwargs.get("model", "command-r")
    impacts = compute_llm_impacts(
        provider=PROVIDER,
        model_name=model,
        output_token_count=original.meta.tokens.output_tokens,
        request_latency=latency,
    )
    return NonStreamedChatResponse(impacts=impacts, **original.dict())


async def cohere_async_chat_wrapper(
    wrapped: Callable, instance: AsyncClient, args: Any, kwargs: Any  # noqa: ARG001
) -> NonStreamedChatResponse:
    """Async counterpart of `cohere_chat_wrapper`.

    Awaits the underlying `chat` call, computes environmental impacts from
    the model name and the number of generated tokens, and returns a
    response object enriched with an `impacts` field.
    """
    start = time.perf_counter()
    original = await wrapped(*args, **kwargs)
    latency = time.perf_counter() - start
    model = kwargs.get("model", "command-r")
    impacts = compute_llm_impacts(
        provider=PROVIDER,
        model_name=model,
        output_token_count=original.meta.tokens.output_tokens,
        request_latency=latency,
    )
    return NonStreamedChatResponse(impacts=impacts, **original.dict())


def cohere_stream_chat_wrapper(
    wrapped: Callable, instance: Client, args: Any, kwargs: Any  # noqa: ARG001
) -> Iterator[StreamedChatResponse]:
    """Wrap a synchronous Cohere `chat_stream` call.

    Passes every streamed event through unchanged, except the final
    "stream-end" event, which is replaced by a copy enriched with the
    computed environmental impacts.
    """
    model = kwargs.get("model", "command-r")
    start = time.perf_counter()
    for event in wrapped(*args, **kwargs):
        if event.event_type != "stream-end":
            yield event
            continue
        # Latency is measured up to the terminal event of the stream.
        latency = time.perf_counter() - start
        impacts = compute_llm_impacts(
            provider=PROVIDER,
            model_name=model,
            output_token_count=event.response.meta.tokens.output_tokens,
            request_latency=latency,
        )
        yield StreamedChatResponse_StreamEnd(impacts=impacts, **event.dict())


async def cohere_async_stream_chat_wrapper(
    wrapped: Callable, instance: AsyncClient, args: Any, kwargs: Any  # noqa: ARG001
) -> AsyncIterator[StreamedChatResponse]:
    """Async counterpart of `cohere_stream_chat_wrapper`.

    Passes every streamed event through unchanged, except the final
    "stream-end" event, which is replaced by a copy enriched with the
    computed environmental impacts.
    """
    model = kwargs.get("model", "command-r")
    start = time.perf_counter()
    async for event in wrapped(*args, **kwargs):
        if event.event_type != "stream-end":
            yield event
            continue
        # Latency is measured up to the terminal event of the stream.
        latency = time.perf_counter() - start
        impacts = compute_llm_impacts(
            provider=PROVIDER,
            model_name=model,
            output_token_count=event.response.meta.tokens.output_tokens,
            request_latency=latency,
        )
        yield StreamedChatResponse_StreamEnd(impacts=impacts, **event.dict())


class CohereInstrumentor:
    """Patches the Cohere SDK chat methods to attach impact estimates."""

    def __init__(self) -> None:
        # (method path, wrapper) pairs; every target lives in cohere.base_client.
        targets = [
            ("BaseCohere.chat", cohere_chat_wrapper),
            ("AsyncBaseCohere.chat", cohere_async_chat_wrapper),
            ("BaseCohere.chat_stream", cohere_stream_chat_wrapper),
            ("AsyncBaseCohere.chat_stream", cohere_async_stream_chat_wrapper),
        ]
        self.wrapped_methods = [
            {"module": "cohere.base_client", "name": name, "wrapper": func}
            for name, func in targets
        ]

    def instrument(self) -> None:
        """Apply every registered wrapper in place via wrapt."""
        for method in self.wrapped_methods:
            wrap_function_wrapper(method["module"], method["name"], method["wrapper"])

111 changes: 93 additions & 18 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ wrapt = "^1.16.0"
pydantic = ">=2,<3"
mistralai = { version = "^0.1.3", optional = true }
anthropic = { version = "^0.18.1", optional = true }
cohere = { version = "^5.2.5", optional = true }
huggingface-hub = { version = "^0.22.2", optional = true }
tiktoken = {version = "^0.6.0", optional = true}
aiohttp = {version = "^3.9.3", optional = true}
Expand All @@ -25,6 +26,7 @@ minijinja = {version = "^1.0.16", optional = true}
[tool.poetry.extras]
mistralai = ["mistralai"]
anthropic = ["anthropic"]
cohere = ["cohere"]
huggingface-hub = ["huggingface-hub", "tiktoken", "aiohttp", "minijinja"]


Expand Down
Loading

0 comments on commit 6376f59

Please sign in to comment.