fix(anthropic): add thinking as a separate completion message #2780

Merged · 2 commits · Mar 24, 2025
```diff
@@ -231,6 +231,28 @@ def _set_span_completions(span, response):
         # but the API allows for multiple text blocks, so concatenate them
         if content_block_type == "text":
             text += content.text
+        elif content_block_type == "thinking":
+            content = dict(content)
+            # override the role to thinking
+            set_span_attribute(
+                span,
+                f"{prefix}.role",
+                "thinking",
+            )
+            set_span_attribute(
+                span,
+                f"{prefix}.content",
+                content.get("thinking"),
+            )
+            # increment the index for subsequent content blocks
+            index += 1
+            prefix = f"{SpanAttributes.LLM_COMPLETIONS}.{index}"
+            # set the role to the original role on the next completions
+            set_span_attribute(
+                span,
+                f"{prefix}.role",
+                response.get("role"),
+            )
         elif content_block_type == "tool_use":
             content = dict(content)
             set_span_attribute(
```
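The effect is that a thinking block is no longer folded into the assistant text: it becomes its own completion entry, and whatever follows is re-indexed under the response's original role. A minimal sketch of the resulting attribute layout, assuming `SpanAttributes.LLM_COMPLETIONS` resolves to `"gen_ai.completion"`:

```python
# Anthropic Messages API response with an extended-thinking block, followed
# by the span attributes the new branch would record.
response = {
    "role": "assistant",
    "content": [
        {"type": "thinking", "thinking": "Let me work through this..."},
        {"type": "text", "text": "The answer is 42."},
    ],
}

# Expected result of _set_span_completions(span, response):
#   gen_ai.completion.0.role    = "thinking"
#   gen_ai.completion.0.content = "Let me work through this..."
#   gen_ai.completion.1.role    = "assistant"   <- response["role"]
#   gen_ai.completion.1.content = "The answer is 42."
```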
```diff
@@ -269,17 +291,11 @@ async def _aset_token_usage(

     if usage := response.get("usage"):
         prompt_tokens = usage.input_tokens
+        cache_read_tokens = dict(usage).get("cache_read_input_tokens", 0) or 0
+        cache_creation_tokens = dict(usage).get("cache_creation_input_tokens", 0) or 0
     else:
         prompt_tokens = await acount_prompt_tokens_from_request(anthropic, request)
-
-    if usage := response.get("usage"):
-        cache_read_tokens = dict(usage).get("cache_read_input_tokens", 0)
-    else:
-        cache_read_tokens = 0
-
-    if usage := response.get("usage"):
-        cache_creation_tokens = dict(usage).get("cache_creation_input_tokens", 0)
-    else:
-        cache_creation_tokens = 0

     input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
```
```diff
@@ -360,17 +376,11 @@ def _set_token_usage(

     if usage := response.get("usage"):
         prompt_tokens = usage.input_tokens
+        cache_read_tokens = dict(usage).get("cache_read_input_tokens", 0) or 0
+        cache_creation_tokens = dict(usage).get("cache_creation_input_tokens", 0) or 0
     else:
         prompt_tokens = count_prompt_tokens_from_request(anthropic, request)
-
-    if usage := response.get("usage"):
-        cache_read_tokens = dict(usage).get("cache_read_input_tokens", 0)
-    else:
-        cache_read_tokens = 0
-
-    if usage := response.get("usage"):
-        cache_creation_tokens = dict(usage).get("cache_creation_input_tokens", 0)
-    else:
-        cache_creation_tokens = 0

     input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
```
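Both variants now read all three token counts from a single `usage` check, and the trailing `or 0` protects the sum below against cache fields that newer anthropic SDK versions report as an explicit `None` (see the review thread further down). A minimal sketch of the consolidated branch under those assumptions; the `Usage` class here is a hypothetical stand-in for the SDK's usage model, with `vars()` playing the role of `dict(usage)`:

```python
# Hypothetical stand-in for the SDK usage object; newer SDK versions can
# set the cache fields to None instead of omitting them.
class Usage:
    def __init__(self):
        self.input_tokens = 100
        self.cache_read_input_tokens = None
        self.cache_creation_input_tokens = None

usage = Usage()
prompt_tokens = usage.input_tokens
# vars() stands in for dict(usage) on the real model object
cache_read_tokens = vars(usage).get("cache_read_input_tokens", 0) or 0
cache_creation_tokens = vars(usage).get("cache_creation_input_tokens", 0) or 0

input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
print(input_tokens)  # 100; without `or 0` the sum would raise TypeError on None
```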
```diff
@@ -27,10 +27,13 @@ def _process_response_item(item, complete_response):
     elif item.type == "content_block_start":
         index = item.index
         if len(complete_response.get("events")) <= index:
-            complete_response["events"].append({"index": index, "text": ""})
-    elif item.type == "content_block_delta" and item.delta.type == "text_delta":
+            complete_response["events"].append({"index": index, "text": "", "type": item.content_block.type})
+    elif item.type == "content_block_delta" and item.delta.type in ["thinking_delta", "text_delta"]:
         index = item.index
-        complete_response.get("events")[index]["text"] += item.delta.text
+        if item.delta.type == 'thinking_delta':
+            complete_response["events"][index]["text"] += item.delta.thinking
+        elif item.delta.type == 'text_delta':
+            complete_response["events"][index]["text"] += item.delta.text
     elif item.type == "message_delta":
         for event in complete_response.get("events", []):
             event["finish_reason"] = item.delta.stop_reason
```
```diff
@@ -52,8 +55,8 @@ def _set_token_usage(
     token_histogram: Histogram = None,
     choice_counter: Counter = None,
 ):
-    cache_read_tokens = complete_response.get("usage", {}).get("cache_read_input_tokens", 0)
-    cache_creation_tokens = complete_response.get("usage", {}).get("cache_creation_input_tokens", 0)
+    cache_read_tokens = complete_response.get("usage", {}).get("cache_read_input_tokens", 0) or 0
+    cache_creation_tokens = complete_response.get("usage", {}).get("cache_creation_input_tokens", 0) or 0

     input_tokens = prompt_tokens + cache_read_tokens + cache_creation_tokens
     total_tokens = input_tokens + completion_tokens
```

Review thread on the `or 0` changes:

**Member:** nit: you don't need this since `get` returns 0 by default

**@dinmukhamedm** (Contributor Author, Mar 24, 2025): No, the thing is that the new anthropic SDK explicitly sets these fields to `None`, and so `get` returns a `None`, which fails below, where we try to add things up.

```python
>>> d = {"key": "val", "none_key": None}
>>> v = d.get("none_key", "fallback")
>>> print(v)
None
```

Same for all the changes below.

**Member:** Oh. I don't know if I hate Python more or Anthropic more. Thanks!

**Member** (on the `cache_creation_input_tokens` line): same
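In other words, the `.get()` default only applies when the key is missing; the appended `or 0` also normalizes a present-but-`None` value. A minimal sketch of the idiom:

```python
usage = {"cache_read_input_tokens": None}  # key present, value explicitly None
cache_read_tokens = usage.get("cache_read_input_tokens", 0) or 0
print(cache_read_tokens)  # 0, where the .get() default alone would give None
```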
```diff
@@ -116,6 +119,8 @@ def _set_completions(span, events):
             set_span_attribute(
                 span, f"{prefix}.finish_reason", event.get("finish_reason")
             )
+            role = "thinking" if event.get("type") == "thinking" else "assistant"
+            set_span_attribute(span, f"{prefix}.role", role)
             set_span_attribute(span, f"{prefix}.content", event.get("text"))
     except Exception as e:
         logger.warning("Failed to set completion attributes, error: %s", str(e))
```
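With the `type` field now carried on each streamed event (set by `_process_response_item` above), the role per completion index comes out as follows, again assuming `SpanAttributes.LLM_COMPLETIONS` resolves to `"gen_ai.completion"`:

```python
events = [
    {"index": 0, "type": "thinking", "text": "Reasoning...", "finish_reason": "end_turn"},
    {"index": 1, "type": "text", "text": "Final answer.", "finish_reason": "end_turn"},
]

# _set_completions(span, events) would record:
#   gen_ai.completion.0.role = "thinking"   (event type is "thinking")
#   gen_ai.completion.1.role = "assistant"  (any other event type)
```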
```diff
@@ -159,13 +164,13 @@ def build_from_streaming_response(
     completion_tokens = -1
     # prompt_usage
     if usage := complete_response.get("usage"):
-        prompt_tokens = usage.get("input_tokens", 0)
+        prompt_tokens = usage.get("input_tokens", 0) or 0
     else:
         prompt_tokens = count_prompt_tokens_from_request(instance, kwargs)

     # completion_usage
     if usage := complete_response.get("usage"):
-        completion_tokens = usage.get("output_tokens", 0)
+        completion_tokens = usage.get("output_tokens", 0) or 0
     else:
         completion_content = ""
         if complete_response.get("events"):
```

**Member** (on the `input_tokens` line): same?
```diff
@@ -174,7 +179,7 @@ def build_from_streaming_response(
             if event.get("text"):
                 completion_content += event.get("text")

-    if model_name:
+    if model_name and hasattr(instance, "count_tokens"):
         completion_tokens = instance.count_tokens(completion_content)

     _set_token_usage(
```
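The added `hasattr` check matters because not every client object the instrumentation may wrap exposes a `count_tokens` helper (newer anthropic SDK clients dropped the client-level method). A sketch of the guard's behavior, using a hypothetical client class:

```python
# Hypothetical client without a count_tokens helper, standing in for newer
# anthropic SDK clients; the guard leaves completion_tokens at its -1
# sentinel instead of raising AttributeError.
class FallbackClient:
    pass

instance = FallbackClient()
completion_tokens = -1
completion_content = "some streamed text"
model_name = "claude-3-7-sonnet-20250219"

if model_name and hasattr(instance, "count_tokens"):
    completion_tokens = instance.count_tokens(completion_content)

print(completion_tokens)  # -1: token count stays unset rather than crashing
```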
```diff
@@ -250,7 +255,7 @@ async def abuild_from_streaming_response(
             if event.get("text"):
                 completion_content += event.get("text")

-    if model_name:
+    if model_name and hasattr(instance, "count_tokens"):
         completion_tokens = instance.count_tokens(completion_content)

     _set_token_usage(
```