Skip to content

Commit

Permalink
pre-commit fix
Browse files Browse the repository at this point in the history
  • Loading branch information
adrienbanse committed Apr 12, 2024
1 parent 8cea714 commit 25e7f68
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
1 change: 1 addition & 0 deletions ecologits/impacts/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel


class Impact(BaseModel):
type: str
name: str
Expand Down
24 changes: 12 additions & 12 deletions ecologits/tracers/cohere_tracer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import time
from typing import Callable, Any, Iterator, AsyncIterator
from typing import Any, AsyncIterator, Callable, Iterator

from cohere import Client, AsyncClient
from cohere import AsyncClient, Client
from cohere.types.non_streamed_chat_response import NonStreamedChatResponse as _NonStreamedChatResponse
from cohere.types.streamed_chat_response import StreamedChatResponse
from cohere.types.streamed_chat_response import StreamedChatResponse_StreamEnd as _StreamedChatResponse_StreamEnd
Expand All @@ -16,14 +16,14 @@ class NonStreamedChatResponse(_NonStreamedChatResponse):
impacts: Impacts

class Config:
arbitrary_types_allowed = True
arbitrary_types_allowed = True


class StreamedChatResponse_StreamEnd(_StreamedChatResponse_StreamEnd):
class StreamedChatResponse_StreamEnd(_StreamedChatResponse_StreamEnd): # noqa: N801
impacts: Impacts

class Config:
arbitrary_types_allowed = True
arbitrary_types_allowed = True


def cohere_chat_wrapper(
Expand All @@ -44,7 +44,7 @@ def cohere_chat_wrapper(


async def cohere_async_chat_wrapper(
wrapped: Callable, instance: AsyncClient, args: Any, kwargs: Any # noqa: ARG001
wrapped: Callable, instance: AsyncClient, args: Any, kwargs: Any # noqa: ARG001
) -> NonStreamedChatResponse:
timer_start = time.perf_counter()
response = await wrapped(*args, **kwargs)
Expand All @@ -61,7 +61,7 @@ async def cohere_async_chat_wrapper(


def cohere_stream_chat_wrapper(
wrapped: Callable, instance: Client, args: Any, kwargs: Any
wrapped: Callable, instance: Client, args: Any, kwargs: Any # noqa: ARG001
) -> Iterator[StreamedChatResponse]:
model_name = kwargs.get("model", "command-r")
timer_start = time.perf_counter()
Expand All @@ -82,7 +82,7 @@ def cohere_stream_chat_wrapper(


async def cohere_async_stream_chat_wrapper(
wrapped: Callable, instance: AsyncClient, args: Any, kwargs: Any
wrapped: Callable, instance: AsyncClient, args: Any, kwargs: Any # noqa: ARG001
) -> AsyncIterator[StreamedChatResponse]:
model_name = kwargs.get("model", "command-r")
timer_start = time.perf_counter()
Expand All @@ -109,22 +109,22 @@ def __init__(self) -> None:
"module": "cohere.base_client",
"name": "BaseCohere.chat",
"wrapper": cohere_chat_wrapper,
},
},
{
"module": "cohere.base_client",
"name": "AsyncBaseCohere.chat",
"wrapper": cohere_async_chat_wrapper,
},
},
{
"module": "cohere.base_client",
"name": "BaseCohere.chat_stream",
"wrapper": cohere_stream_chat_wrapper,
},
},
{
"module": "cohere.base_client",
"name": "AsyncBaseCohere.chat_stream",
"wrapper": cohere_async_stream_chat_wrapper,
},
},
]

def instrument(self) -> None:
Expand Down

0 comments on commit 25e7f68

Please sign in to comment.