Skip to content

Commit 40746ef

Browse files
feat(ai-monitoring): Cohere integration (#3055)
* Cohere integration. * Fix lint. * Fix bug with the model ID not being pulled. * Exclude known models from langchain. * Update tox.ini. * Remove a leftover print statement. * Apply suggestions from code review. Co-authored-by: Anton Pirker <[email protected]>
1 parent 1a32183 commit 40746ef

File tree

10 files changed

+523
-2
lines changed

10 files changed

+523
-2
lines changed

.github/workflows/test-integrations-data-processing.yml

+8
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ jobs:
5858
run: |
5959
set -x # print commands that are executed
6060
./scripts/runtox.sh "py${{ matrix.python-version }}-celery-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
61+
- name: Test cohere latest
62+
run: |
63+
set -x # print commands that are executed
64+
./scripts/runtox.sh "py${{ matrix.python-version }}-cohere-latest" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
6165
- name: Test huey latest
6266
run: |
6367
set -x # print commands that are executed
@@ -126,6 +130,10 @@ jobs:
126130
run: |
127131
set -x # print commands that are executed
128132
./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-celery" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
133+
- name: Test cohere pinned
134+
run: |
135+
set -x # print commands that are executed
136+
./scripts/runtox.sh --exclude-latest "py${{ matrix.python-version }}-cohere" --cov=tests --cov=sentry_sdk --cov-report= --cov-branch
129137
- name: Test huey pinned
130138
run: |
131139
set -x # print commands that are executed

mypy.ini

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ warn_unused_ignores = True
2525
;
2626
; Do not use wildcards in module paths, otherwise added modules will
2727
; automatically have the same set of relaxed rules as the rest
28-
28+
[mypy-cohere.*]
29+
ignore_missing_imports = True
2930
[mypy-django.*]
3031
ignore_missing_imports = True
3132
[mypy-pyramid.*]

scripts/split-tox-gh-actions/split-tox-gh-actions.py

+1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"arq",
7171
"beam",
7272
"celery",
73+
"cohere",
7374
"huey",
7475
"langchain",
7576
"openai",

sentry_sdk/consts.py

+33
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ class SPANDATA:
9191
See: https://develop.sentry.dev/sdk/performance/span-data-conventions/
9292
"""
9393

94+
AI_FREQUENCY_PENALTY = "ai.frequency_penalty"
95+
"""
96+
Used to reduce repetitiveness of generated tokens.
97+
Example: 0.5
98+
"""
99+
100+
AI_PRESENCE_PENALTY = "ai.presence_penalty"
101+
"""
102+
Used to reduce repetitiveness of generated tokens.
103+
Example: 0.5
104+
"""
105+
94106
AI_INPUT_MESSAGES = "ai.input_messages"
95107
"""
96108
The input messages to an LLM call.
@@ -164,12 +176,31 @@ class SPANDATA:
164176
For an AI model call, the logit bias
165177
"""
166178

179+
AI_PREAMBLE = "ai.preamble"
180+
"""
181+
For an AI model call, the preamble parameter.
182+
Preambles are a part of the prompt used to adjust the model's overall behavior and conversation style.
183+
Example: "You are now a clown."
184+
"""
185+
186+
AI_RAW_PROMPTING = "ai.raw_prompting"
187+
"""
188+
Minimize pre-processing done to the prompt sent to the LLM.
189+
Example: true
190+
"""
191+
167192
AI_RESPONSES = "ai.responses"
168193
"""
169194
The responses to an AI model call. Always as a list.
170195
Example: ["hello", "world"]
171196
"""
172197

198+
AI_SEED = "ai.seed"
199+
"""
200+
The seed, ideally models given the same seed and same other parameters will produce the exact same output.
201+
Example: 123.45
202+
"""
203+
173204
DB_NAME = "db.name"
174205
"""
175206
The name of the database being accessed. For commands that switch the database, this should be set to the target database (even if the command fails).
@@ -298,6 +329,8 @@ class SPANDATA:
298329
class OP:
299330
ANTHROPIC_MESSAGES_CREATE = "ai.messages.create.anthropic"
300331
CACHE_GET_ITEM = "cache.get_item"
332+
COHERE_CHAT_COMPLETIONS_CREATE = "ai.chat_completions.create.cohere"
333+
COHERE_EMBEDDINGS_CREATE = "ai.embeddings.create.cohere"
301334
DB = "db"
302335
DB_REDIS = "db.redis"
303336
EVENT_DJANGO = "event.django"

sentry_sdk/integrations/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def iter_default_integrations(with_auto_enabling_integrations):
7878
"sentry_sdk.integrations.celery.CeleryIntegration",
7979
"sentry_sdk.integrations.chalice.ChaliceIntegration",
8080
"sentry_sdk.integrations.clickhouse_driver.ClickhouseDriverIntegration",
81+
"sentry_sdk.integrations.cohere.CohereIntegration",
8182
"sentry_sdk.integrations.django.DjangoIntegration",
8283
"sentry_sdk.integrations.falcon.FalconIntegration",
8384
"sentry_sdk.integrations.fastapi.FastApiIntegration",

sentry_sdk/integrations/cohere.py

+257
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
from functools import wraps
2+
3+
from sentry_sdk import consts
4+
from sentry_sdk._types import TYPE_CHECKING
5+
from sentry_sdk.ai.monitoring import record_token_usage
6+
from sentry_sdk.consts import SPANDATA
7+
from sentry_sdk.ai.utils import set_data_normalized
8+
9+
if TYPE_CHECKING:
10+
from typing import Any, Callable, Iterator
11+
from sentry_sdk.tracing import Span
12+
13+
import sentry_sdk
14+
from sentry_sdk.scope import should_send_default_pii
15+
from sentry_sdk.integrations import DidNotEnable, Integration
16+
from sentry_sdk.utils import (
17+
capture_internal_exceptions,
18+
event_from_exception,
19+
ensure_integration_enabled,
20+
)
21+
22+
try:
23+
from cohere.client import Client
24+
from cohere.base_client import BaseCohere
25+
from cohere import ChatStreamEndEvent, NonStreamedChatResponse
26+
27+
if TYPE_CHECKING:
28+
from cohere import StreamedChatResponse
29+
except ImportError:
30+
raise DidNotEnable("Cohere not installed")
31+
32+
33+
# Chat kwargs recorded on every span, regardless of PII settings
# (kwarg name -> span data key).
COLLECTED_CHAT_PARAMS = {
    "model": SPANDATA.AI_MODEL_ID,
    "k": SPANDATA.AI_TOP_K,
    "p": SPANDATA.AI_TOP_P,
    "seed": SPANDATA.AI_SEED,
    "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
    "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
    "raw_prompting": SPANDATA.AI_RAW_PROMPTING,
}

# Chat kwargs that may contain PII; only recorded when sending default PII
# is enabled and the integration allows prompt collection.
COLLECTED_PII_CHAT_PARAMS = {
    "tools": SPANDATA.AI_TOOLS,
    "preamble": SPANDATA.AI_PREAMBLE,
}

# Response attributes copied onto the span unconditionally.
# NOTE(review): collect_chat_response_fields iterates only the keys and
# builds the span key as "ai." + attr, so the values here are unused.
COLLECTED_CHAT_RESP_ATTRS = {
    "generation_id": "ai.generation_id",
    "is_search_required": "ai.is_search_required",
    "finish_reason": "ai.finish_reason",
}

# Response attributes that may contain PII; only copied when PII is allowed.
# Same note as above: only the keys are used at collection time.
COLLECTED_PII_CHAT_RESP_ATTRS = {
    "citations": "ai.citations",
    "documents": "ai.documents",
    "search_queries": "ai.search_queries",
    "search_results": "ai.search_results",
    "tool_calls": "ai.tool_calls",
}
61+
62+
63+
class CohereIntegration(Integration):
    """Instruments the Cohere client for AI monitoring."""

    identifier = "cohere"

    def __init__(self, include_prompts=True):
        # type: (CohereIntegration, bool) -> None
        # When False, prompts and responses (potential PII) are never
        # attached to spans, even if send_default_pii is enabled.
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        # Monkeypatch the Cohere entry points once per process: both chat
        # variants share one wrapper, distinguished by the streaming flag.
        BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
        BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
        Client.embed = _wrap_embed(Client.embed)
76+
77+
78+
def _capture_exception(exc):
    # type: (Any) -> None
    """Report an exception raised by a Cohere call to Sentry.

    Marks the event as unhandled with mechanism type "cohere".
    """
    client_options = sentry_sdk.get_client().options
    event, hint = event_from_exception(
        exc,
        client_options=client_options,
        mechanism={"type": "cohere", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)
86+
87+
88+
def _wrap_chat(f, streaming):
    # type: (Callable[..., Any], bool) -> Callable[..., Any]
    """Wrap ``BaseCohere.chat`` / ``BaseCohere.chat_stream`` with a span.

    :param f: the original chat method.
    :param streaming: True when wrapping ``chat_stream`` (the result is an
        iterator of events), False for the non-streamed ``chat``.
    :return: the wrapped method with an unchanged call signature.
    """

    def collect_chat_response_fields(span, res, include_pii):
        # type: (Span, NonStreamedChatResponse, bool) -> None
        """Copy response attributes and token counts onto the span."""
        if include_pii:
            if hasattr(res, "text"):
                set_data_normalized(
                    span,
                    SPANDATA.AI_RESPONSES,
                    [res.text],
                )
            for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
                if hasattr(res, pii_attr):
                    set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))

        for attr in COLLECTED_CHAT_RESP_ATTRS:
            if hasattr(res, attr):
                set_data_normalized(span, "ai." + attr, getattr(res, attr))

        if hasattr(res, "meta"):
            # Prefer billed_units over raw token counts when both exist.
            if hasattr(res.meta, "billed_units"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.billed_units.input_tokens,
                    completion_tokens=res.meta.billed_units.output_tokens,
                )
            elif hasattr(res.meta, "tokens"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.tokens.input_tokens,
                    completion_tokens=res.meta.tokens.output_tokens,
                )

            if hasattr(res.meta, "warnings"):
                set_data_normalized(span, "ai.warnings", res.meta.warnings)

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_chat(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        # Only instrument the simple keyword-message form; anything else is
        # passed through untouched.
        if "message" not in kwargs:
            return f(*args, **kwargs)

        if not isinstance(kwargs.get("message"), str):
            return f(*args, **kwargs)

        message = kwargs.get("message")

        # The span is entered manually because, in the streaming case, it
        # must outlive this function and be closed by the iterator wrapper.
        span = sentry_sdk.start_span(
            op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
            description="cohere.client.Chat",
        )
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        integration = sentry_sdk.get_client().get_integration(CohereIntegration)

        with capture_internal_exceptions():
            if should_send_default_pii() and integration.include_prompts:
                # Record chat history plus the new user message, normalized
                # to {"role", "content"} dicts.
                set_data_normalized(
                    span,
                    SPANDATA.AI_INPUT_MESSAGES,
                    list(
                        map(
                            lambda x: {
                                "role": getattr(x, "role", "").lower(),
                                "content": getattr(x, "message", ""),
                            },
                            kwargs.get("chat_history", []),
                        )
                    )
                    + [{"role": "user", "content": message}],
                )
                for k, v in COLLECTED_PII_CHAT_PARAMS.items():
                    if k in kwargs:
                        set_data_normalized(span, v, kwargs[k])

            for k, v in COLLECTED_CHAT_PARAMS.items():
                if k in kwargs:
                    set_data_normalized(span, v, kwargs[k])
            # BUGFIX: was hard-coded to False, so chat_stream spans were
            # tagged as non-streaming; record the actual mode instead.
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

            if streaming:
                old_iterator = res

                def new_iterator():
                    # type: () -> Iterator[StreamedChatResponse]
                    # Pass events through; harvest response fields from the
                    # terminal ChatStreamEndEvent, then close the span.
                    # NOTE(review): if the consumer abandons the iterator
                    # early, the span is never exited — confirm acceptable.
                    with capture_internal_exceptions():
                        for x in old_iterator:
                            if isinstance(x, ChatStreamEndEvent):
                                collect_chat_response_fields(
                                    span,
                                    x.response,
                                    include_pii=should_send_default_pii()
                                    and integration.include_prompts,
                                )
                            yield x

                    span.__exit__(None, None, None)

                return new_iterator()
            elif isinstance(res, NonStreamedChatResponse):
                collect_chat_response_fields(
                    span,
                    res,
                    include_pii=should_send_default_pii()
                    and integration.include_prompts,
                )
                span.__exit__(None, None, None)
            else:
                set_data_normalized(span, "unknown_response", True)
                span.__exit__(None, None, None)
            return res

    return new_chat
210+
211+
212+
def _wrap_embed(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``Client.embed`` so each call is traced as an embeddings span."""

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_embed(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        with sentry_sdk.start_span(
            op=consts.OP.COHERE_EMBEDDINGS_CREATE,
            description="Cohere Embedding Creation",
        ) as span:
            integration = sentry_sdk.get_client().get_integration(CohereIntegration)
            pii_allowed = should_send_default_pii() and integration.include_prompts

            if "texts" in kwargs and pii_allowed:
                texts = kwargs["texts"]
                # A bare string and a list of strings are recorded under
                # different keys; other shapes are ignored.
                if isinstance(texts, str):
                    set_data_normalized(span, "ai.texts", [texts])
                elif (
                    isinstance(texts, list)
                    and len(texts) > 0
                    and isinstance(texts[0], str)
                ):
                    set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, texts)

            if "model" in kwargs:
                set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])

            try:
                res = f(*args, **kwargs)
            except Exception as e:
                _capture_exception(e)
                raise e from None

            # Embeddings only consume input tokens, so billed input tokens
            # double as the total.
            if (
                hasattr(res, "meta")
                and hasattr(res.meta, "billed_units")
                and hasattr(res.meta.billed_units, "input_tokens")
            ):
                input_tokens = res.meta.billed_units.input_tokens
                record_token_usage(
                    span,
                    prompt_tokens=input_tokens,
                    total_tokens=input_tokens,
                )
            return res

    return new_embed

sentry_sdk/integrations/langchain.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,12 @@ def count_tokens(s):
6363

6464
# To avoid double collecting tokens, we do *not* measure
6565
# token counts for models for which we have an explicit integration
66-
NO_COLLECT_TOKEN_MODELS = ["openai-chat"] # TODO add huggingface and anthropic
66+
NO_COLLECT_TOKEN_MODELS = [
67+
"openai-chat",
68+
"anthropic-chat",
69+
"cohere-chat",
70+
"huggingface_endpoint",
71+
]
6772

6873

6974
class LangchainIntegration(Integration):
@@ -216,6 +221,7 @@ def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
216221
watched_span.no_collect_tokens = any(
217222
x in all_params.get("_type", "") for x in NO_COLLECT_TOKEN_MODELS
218223
)
224+
219225
if not model and "anthropic" in all_params.get("_type"):
220226
model = "claude-2"
221227
if model:

tests/integrations/cohere/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
import pytest
2+
3+
pytest.importorskip("cohere")

0 commit comments

Comments
 (0)