|
| 1 | +from functools import wraps |
| 2 | + |
| 3 | +from sentry_sdk import consts |
| 4 | +from sentry_sdk._types import TYPE_CHECKING |
| 5 | +from sentry_sdk.ai.monitoring import record_token_usage |
| 6 | +from sentry_sdk.consts import SPANDATA |
| 7 | +from sentry_sdk.ai.utils import set_data_normalized |
| 8 | + |
| 9 | +if TYPE_CHECKING: |
| 10 | + from typing import Any, Callable, Iterator |
| 11 | + from sentry_sdk.tracing import Span |
| 12 | + |
| 13 | +import sentry_sdk |
| 14 | +from sentry_sdk.scope import should_send_default_pii |
| 15 | +from sentry_sdk.integrations import DidNotEnable, Integration |
| 16 | +from sentry_sdk.utils import ( |
| 17 | + capture_internal_exceptions, |
| 18 | + event_from_exception, |
| 19 | + ensure_integration_enabled, |
| 20 | +) |
| 21 | + |
# Import the Cohere SDK pieces this integration patches.  If the package is
# not installed, refuse to enable the integration up front (DidNotEnable)
# rather than failing later inside setup_once().
try:
    from cohere.client import Client
    from cohere.base_client import BaseCohere
    from cohere import ChatStreamEndEvent, NonStreamedChatResponse

    # Only needed for type annotations; avoid the runtime import otherwise.
    if TYPE_CHECKING:
        from cohere import StreamedChatResponse
except ImportError:
    raise DidNotEnable("Cohere not installed")
| 31 | + |
| 32 | + |
# Chat request kwargs that are always safe to copy onto the span,
# mapped to their SPANDATA keys (no PII).
COLLECTED_CHAT_PARAMS = {
    "model": SPANDATA.AI_MODEL_ID,
    "k": SPANDATA.AI_TOP_K,
    "p": SPANDATA.AI_TOP_P,
    "seed": SPANDATA.AI_SEED,
    "frequency_penalty": SPANDATA.AI_FREQUENCY_PENALTY,
    "presence_penalty": SPANDATA.AI_PRESENCE_PENALTY,
    "raw_prompting": SPANDATA.AI_RAW_PROMPTING,
}

# Chat request kwargs that may contain user content; only recorded when
# PII capture is allowed (send_default_pii AND include_prompts).
COLLECTED_PII_CHAT_PARAMS = {
    "tools": SPANDATA.AI_TOOLS,
    "preamble": SPANDATA.AI_PREAMBLE,
}

# Chat *response* attributes that are safe to record unconditionally,
# mapped to their span data keys.
COLLECTED_CHAT_RESP_ATTRS = {
    "generation_id": "ai.generation_id",
    "is_search_required": "ai.is_search_required",
    "finish_reason": "ai.finish_reason",
}

# Chat response attributes that may contain user content; recorded only
# when PII capture is allowed.
COLLECTED_PII_CHAT_RESP_ATTRS = {
    "citations": "ai.citations",
    "documents": "ai.documents",
    "search_queries": "ai.search_queries",
    "search_results": "ai.search_results",
    "tool_calls": "ai.tool_calls",
}
| 61 | + |
| 62 | + |
class CohereIntegration(Integration):
    """Sentry integration for the Cohere client library.

    Instruments chat (streaming and non-streaming) and embedding calls by
    monkey-patching the Cohere client classes.
    """

    identifier = "cohere"

    def __init__(self, include_prompts=True):
        # type: (CohereIntegration, bool) -> None
        # Controls whether prompt/response content (PII) is attached to spans.
        self.include_prompts = include_prompts

    @staticmethod
    def setup_once():
        # type: () -> None
        # Replace the Cohere entry points with instrumented wrappers.
        BaseCohere.chat_stream = _wrap_chat(BaseCohere.chat_stream, streaming=True)
        BaseCohere.chat = _wrap_chat(BaseCohere.chat, streaming=False)
        Client.embed = _wrap_embed(Client.embed)
| 77 | + |
def _capture_exception(exc):
    # type: (Any) -> None
    """Report *exc* to Sentry, tagged with the cohere mechanism (unhandled)."""
    client = sentry_sdk.get_client()
    event, hint = event_from_exception(
        exc,
        client_options=client.options,
        mechanism={"type": "cohere", "handled": False},
    )
    sentry_sdk.capture_event(event, hint=hint)
| 86 | + |
| 87 | + |
def _wrap_chat(f, streaming):
    # type: (Callable[..., Any], bool) -> Callable[..., Any]
    """Wrap ``BaseCohere.chat`` / ``BaseCohere.chat_stream`` with a Sentry span.

    ``streaming`` tells the wrapper whether ``f`` returns an event iterator
    (``chat_stream``) or a ``NonStreamedChatResponse`` (``chat``); the span
    lifetime and response handling differ between the two.
    """

    def collect_chat_response_fields(span, res, include_pii):
        # type: (Span, NonStreamedChatResponse, bool) -> None
        # Copy interesting response attributes onto the span.  PII-bearing
        # fields (response text, citations, tool calls, ...) are recorded
        # only when include_pii is set.
        if include_pii:
            if hasattr(res, "text"):
                set_data_normalized(
                    span,
                    SPANDATA.AI_RESPONSES,
                    [res.text],
                )
            for pii_attr in COLLECTED_PII_CHAT_RESP_ATTRS:
                if hasattr(res, pii_attr):
                    set_data_normalized(span, "ai." + pii_attr, getattr(res, pii_attr))

        for attr in COLLECTED_CHAT_RESP_ATTRS:
            if hasattr(res, attr):
                set_data_normalized(span, "ai." + attr, getattr(res, attr))

        if hasattr(res, "meta"):
            # Token accounting: prefer billed units, fall back to raw counts.
            if hasattr(res.meta, "billed_units"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.billed_units.input_tokens,
                    completion_tokens=res.meta.billed_units.output_tokens,
                )
            elif hasattr(res.meta, "tokens"):
                record_token_usage(
                    span,
                    prompt_tokens=res.meta.tokens.input_tokens,
                    completion_tokens=res.meta.tokens.output_tokens,
                )

            if hasattr(res.meta, "warnings"):
                set_data_normalized(span, "ai.warnings", res.meta.warnings)

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_chat(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        # Only instrument calls made with a plain-string `message` kwarg;
        # anything else passes straight through uninstrumented.
        if "message" not in kwargs:
            return f(*args, **kwargs)

        if not isinstance(kwargs.get("message"), str):
            return f(*args, **kwargs)

        message = kwargs.get("message")

        # The span is entered/exited manually (not via `with`) because for
        # streaming responses it must stay open until the returned iterator
        # is exhausted, which happens after new_chat returns.
        span = sentry_sdk.start_span(
            op=consts.OP.COHERE_CHAT_COMPLETIONS_CREATE,
            description="cohere.client.Chat",
        )
        span.__enter__()
        try:
            res = f(*args, **kwargs)
        except Exception as e:
            _capture_exception(e)
            span.__exit__(None, None, None)
            raise e from None

        integration = sentry_sdk.get_client().get_integration(CohereIntegration)

        with capture_internal_exceptions():
            if should_send_default_pii() and integration.include_prompts:
                # Record chat history plus the current user message as the
                # input messages (PII-gated).
                set_data_normalized(
                    span,
                    SPANDATA.AI_INPUT_MESSAGES,
                    list(
                        map(
                            lambda x: {
                                "role": getattr(x, "role", "").lower(),
                                "content": getattr(x, "message", ""),
                            },
                            kwargs.get("chat_history", []),
                        )
                    )
                    + [{"role": "user", "content": message}],
                )
                for k, v in COLLECTED_PII_CHAT_PARAMS.items():
                    if k in kwargs:
                        set_data_normalized(span, v, kwargs[k])

            for k, v in COLLECTED_CHAT_PARAMS.items():
                if k in kwargs:
                    set_data_normalized(span, v, kwargs[k])
            # BUGFIX: previously hard-coded False, which mis-labelled
            # chat_stream calls; record the actual mode of this wrapper.
            set_data_normalized(span, SPANDATA.AI_STREAMING, streaming)

            if streaming:
                old_iterator = res

                def new_iterator():
                    # type: () -> Iterator[StreamedChatResponse]

                    with capture_internal_exceptions():
                        for x in old_iterator:
                            # The final stream event carries the full
                            # response; harvest its fields onto the span.
                            if isinstance(x, ChatStreamEndEvent):
                                collect_chat_response_fields(
                                    span,
                                    x.response,
                                    include_pii=should_send_default_pii()
                                    and integration.include_prompts,
                                )
                            yield x

                    # Close the span only once the stream is exhausted.
                    span.__exit__(None, None, None)

                return new_iterator()
            elif isinstance(res, NonStreamedChatResponse):
                collect_chat_response_fields(
                    span,
                    res,
                    include_pii=should_send_default_pii()
                    and integration.include_prompts,
                )
                span.__exit__(None, None, None)
            else:
                # Unexpected response type — flag it but still close the span.
                set_data_normalized(span, "unknown_response", True)
                span.__exit__(None, None, None)
            return res

    return new_chat
| 210 | + |
| 211 | + |
def _wrap_embed(f):
    # type: (Callable[..., Any]) -> Callable[..., Any]
    """Wrap ``Client.embed`` so each embedding call is traced as a Sentry span."""

    @wraps(f)
    @ensure_integration_enabled(CohereIntegration, f)
    def new_embed(*args, **kwargs):
        # type: (*Any, **Any) -> Any
        with sentry_sdk.start_span(
            op=consts.OP.COHERE_EMBEDDINGS_CREATE,
            description="Cohere Embedding Creation",
        ) as span:
            integration = sentry_sdk.get_client().get_integration(CohereIntegration)

            # Input texts are PII — record them only when allowed.
            if (
                "texts" in kwargs
                and should_send_default_pii()
                and integration.include_prompts
            ):
                texts = kwargs["texts"]
                if isinstance(texts, str):
                    set_data_normalized(span, "ai.texts", [texts])
                elif (
                    isinstance(texts, list)
                    and len(texts) > 0
                    and isinstance(texts[0], str)
                ):
                    set_data_normalized(span, SPANDATA.AI_INPUT_MESSAGES, texts)

            if "model" in kwargs:
                set_data_normalized(span, SPANDATA.AI_MODEL_ID, kwargs["model"])

            try:
                res = f(*args, **kwargs)
            except Exception as e:
                _capture_exception(e)
                raise e from None

            # Token accounting: embeddings only report input tokens, so the
            # same number doubles as the total.
            billed = getattr(getattr(res, "meta", None), "billed_units", None)
            if billed is not None and hasattr(billed, "input_tokens"):
                record_token_usage(
                    span,
                    prompt_tokens=billed.input_tokens,
                    total_tokens=billed.input_tokens,
                )
            return res

    return new_embed
0 commit comments