Skip to content

Commit

Permalink
fix anthropic stream
Browse files Browse the repository at this point in the history
  • Loading branch information
adrienbanse committed Apr 10, 2024
1 parent e4b271a commit ba766d9
Showing 1 changed file with 40 additions and 30 deletions.
70 changes: 40 additions & 30 deletions genai_impact/tracers/anthropic_tracer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Iterator
from types import TracebackType
from typing import Any, Callable, Iterator, Generic, TypeVar
from typing_extensions import override

from wrapt import wrap_function_wrapper
Expand All @@ -23,20 +24,50 @@ class Message(_Message):
impacts: Impacts

class MessageStream(_MessageStream):
impacts: Impacts

@override
def __stream_text__(self) -> Iterator[str]:
return self.stream_save
output_tokens = 0
for chunk in self:
if type(chunk) is MessageStartEvent:
message = chunk.message
model = models.find_model(provider="anthropic", model_name=message.model)
output_tokens += message.usage.output_tokens
elif type(chunk) is MessageDeltaEvent:
output_tokens += chunk.usage.output_tokens
elif chunk.type == "content_block_delta" and chunk.delta.type == "text_delta":
yield chunk.delta.text
model_size = model.active_parameters or model.active_parameters_range
impacts = compute_llm_impact(
model_parameter_count=model_size, output_token_count=output_tokens
)
self.impacts = impacts

def __init__(self, parent, impacts) -> None:
self.stream_save = parent.text_stream
def __init__(self, parent):
super().__init__(
cast_to = parent._cast_to,
response = parent.response,
client = parent._client
)
self.impacts = impacts

MessageStreamT = TypeVar("MessageStreamT", bound=MessageStream)
class MessageStreamManager(Generic[MessageStreamT]):
def __init__(self, api_request: Callable[[], MessageStreamT]) -> None:
self.__stream: MessageStreamT | None = None
self.__api_request = api_request

def __enter__(self) -> MessageStreamT:
self.__stream = self.__api_request()
self.__stream = MessageStream(self.__stream)
return self.__stream

def __exit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self.__stream is not None:
self.__stream.close()

def compute_impacts_and_return_response(response: Any) -> Message:
model = models.find_model(provider="anthropic", model_name=response.model)
Expand All @@ -63,33 +94,12 @@ async def anthropic_async_chat_wrapper(
) -> Message:
response = await wrapped(*args, **kwargs)
return compute_impacts_and_return_response(response)

def compute_impacts_and_return_stream_response(response: Any) -> MessageStream:
output_tokens = 0
with response as stream:
for i, event in enumerate(stream):
if i == 0:
if type(event) is MessageStartEvent:
message = event.message
model = models.find_model(provider="anthropic", model_name=message.model)
output_tokens += message.usage.output_tokens
else:
print("Stream is not initialized with MessageStartEvent")
return stream
elif type(event) is MessageDeltaEvent:
output_tokens += event.usage.output_tokens
model_size = model.active_parameters or model.active_parameters_range
impacts = compute_llm_impact(
model_parameter_count=model_size, output_token_count=output_tokens
)

return MessageStream(stream, impacts)

def anthropic_stream_chat_wrapper(
wrapped: Callable, instance: _Anthropic, args: Any, kwargs: Any # noqa: ARG001
) -> MessageStream:
) -> MessageStreamManager:
response = wrapped(*args, **kwargs)
return compute_impacts_and_return_stream_response(response)
return MessageStreamManager(response._MessageStreamManager__api_request)

class AnthropicInstrumentor:
def __init__(self) -> None:
Expand Down

0 comments on commit ba766d9

Please sign in to comment.