diff --git a/logfire/_internal/cli/__init__.py b/logfire/_internal/cli/__init__.py index cb49be100..7c719f1da 100644 --- a/logfire/_internal/cli/__init__.py +++ b/logfire/_internal/cli/__init__.py @@ -25,6 +25,7 @@ from ..client import LogfireClient from ..config import REGIONS, LogfireCredentials, get_base_url_from_token from ..config_params import ParamManager +from ..server_response import install_logfire_response_hook from ..tracer import SDKTracerProvider from .auth import parse_auth, parse_logout from .prompt import parse_prompt @@ -434,8 +435,9 @@ def log_trace_id(response: requests.Response, context: ContextCarrier, *args: An else: with tracer.start_as_current_span('logfire._internal.cli'), requests.Session() as session: context = get_context() - session.hooks = {'response': functools.partial(log_trace_id, context=context)} + session.hooks = {'response': [functools.partial(log_trace_id, context=context)]} session.headers.update(context) + install_logfire_response_hook(session) namespace._session = session namespace.func(namespace) diff --git a/logfire/_internal/client.py b/logfire/_internal/client.py index 9bd9c77de..ac20ffbaa 100644 --- a/logfire/_internal/client.py +++ b/logfire/_internal/client.py @@ -10,6 +10,7 @@ from logfire.version import VERSION from .auth import UserToken, UserTokenCollection +from .server_response import TransportResponseHook, install_logfire_response_hook from .utils import UnexpectedResponse UA_HEADER = f'logfire/{VERSION}' @@ -29,18 +30,29 @@ class LogfireClient: Args: user_token: The user token to use when authenticating against the API. + transport_response_hook: Optional override for the API response hook (see + `AdvancedOptions.transport_response_hook`). """ - def __init__(self, user_token: UserToken) -> None: + def __init__( + self, + user_token: UserToken, + transport_response_hook: TransportResponseHook | None = None, + ) -> None: if user_token.is_expired: raise RuntimeError('The provided user token is expired') self.base_url = user_token.base_url self._token = user_token.token self._session = Session() self._session.headers.update({'Authorization': self._token, 'User-Agent': UA_HEADER}) + install_logfire_response_hook(self._session, transport_response_hook) @classmethod - def from_url(cls, base_url: str | None) -> Self: + def from_url( + cls, + base_url: str | None, + transport_response_hook: TransportResponseHook | None = None, + ) -> Self: """Create a client from the provided base URL. Args: @@ -48,8 +60,13 @@ def from_url(cls, base_url: str | None) -> Self: the user into selecting a token from the token collection (or, if only one available, use it directly). The token collection will be created from the `~/.logfire/default.toml` file (or an empty one if no such file exists). + transport_response_hook: Optional override for the API response hook (see + `AdvancedOptions.transport_response_hook`). """ - return cls(user_token=UserTokenCollection().get_token(base_url)) + return cls( + user_token=UserTokenCollection().get_token(base_url), + transport_response_hook=transport_response_hook, + ) def _get_raw(self, endpoint: str, params: dict[str, Any] | None = None) -> Response: response = self._session.get(urljoin(self.base_url, endpoint), params=params) diff --git a/logfire/_internal/config.py b/logfire/_internal/config.py index 42209f53e..ffce4bf99 100644 --- a/logfire/_internal/config.py +++ b/logfire/_internal/config.py @@ -106,6 +106,7 @@ from .logs import ProxyLoggerProvider from .metrics import ProxyMeterProvider from .scrubbing import NOOP_SCRUBBER, BaseScrubber, Scrubber, ScrubbingOptions +from .server_response import TransportResponseHook, install_logfire_response_hook from .stack_info import warn_at_user_stacklevel from .tracer import OPEN_SPANS, PendingSpanProcessor, ProxyTracerProvider from .utils import ( @@ -201,6 +202,29 @@ class AdvancedOptions: serialized configuration sent to child processes. See the [distributed tracing guide](https://logfire.pydantic.dev/docs/how-to-guides/distributed-tracing/#thread-and-pool-executors) for more details. """ + transport_response_hook: TransportResponseHook | None = None + """Optional callback invoked for every HTTP response received from the Logfire API. + + This applies to OTLP exports, credential / project initialisation, and the remote + variables provider. The default surfaces `X-Logfire-Warning` and `X-Logfire-Error` + headers as `LogfireServerWarning` / `LogfireServerError`. + + Setting this replaces the default; pass `lambda response: None` to opt out entirely, + or compose your own logic on top of `process_logfire_response_headers`: + + ```python skip-run="true" skip-reason="needs metric/logfire setup" + from logfire._internal.server_response import process_logfire_response_headers + + def hook(response): + my_metric.inc(response.status_code) + process_logfire_response_headers(response) + + logfire.configure(advanced=AdvancedOptions(transport_response_hook=hook)) + ``` + + Raise from the hook to abort the calling code path. + """ + def generate_base_url(self, token: str) -> str: if self.base_url is not None: return self.base_url @@ -1078,7 +1102,7 @@ def add_span_processor(span_processor: SpanProcessor) -> None: # If we don't have tokens or credentials from a file, # try initializing a new project and writing a new creds file. # note, we only do this if `send_to_logfire` is explicitly `True`, not 'if-token-present' - client = LogfireClient.from_url(self.advanced.base_url) + client = LogfireClient.from_url(self.advanced.base_url, self.advanced.transport_response_hook) credentials = LogfireCredentials.initialize_project(client=client) credentials.write_creds_file(self.data_dir) @@ -1129,6 +1153,7 @@ def check_tokens(): base_url = self.advanced.generate_base_url(token) headers = {'User-Agent': f'logfire/{VERSION}', 'Authorization': token} session = OTLPExporterHttpSession() + install_logfire_response_hook(session, self.advanced.transport_response_hook) span_exporter = BodySizeCheckingOTLPSpanExporter( endpoint=urljoin(base_url, '/v1/traces'), session=session, @@ -1305,6 +1330,7 @@ def fix_pid(): # pragma: no cover base_url=base_url, token=self.api_key, options=self.variables, + transport_response_hook=self.advanced.transport_response_hook, ) multi_log_processor = SynchronousMultiLogRecordProcessor() for processor in log_record_processors: @@ -1437,6 +1463,7 @@ def _lazy_init_variable_provider(self) -> VariableProvider: base_url=base_url, token=api_key, options=options, + transport_response_hook=self.advanced.transport_response_hook, ) self._variable_provider = provider provider.start(Logfire(config=self)) @@ -1453,7 +1480,9 @@ def warn_if_not_initialized(self, message: str): ) def _initialize_credentials_from_token(self, token: str) -> LogfireCredentials | None: - return LogfireCredentials.from_token(token, requests.Session(), self.advanced.generate_base_url(token)) + session = requests.Session() + install_logfire_response_hook(session, self.advanced.transport_response_hook) + return LogfireCredentials.from_token(token, session, self.advanced.generate_base_url(token)) def _ensure_flush_after_aws_lambda(self): """Ensure that `force_flush` is called after an AWS Lambda invocation. diff --git a/logfire/_internal/server_response.py b/logfire/_internal/server_response.py new file mode 100644 index 000000000..bb1e6b4c0 --- /dev/null +++ b/logfire/_internal/server_response.py @@ -0,0 +1,64 @@ +"""Surface out-of-band signals the Logfire backend wants every SDK request to know about. + +The server attaches custom headers to API responses: + +* `X-Logfire-Warning`: an out-of-band warning the server wants the user to see. + Surfaced via `warnings.warn(..., LogfireServerWarning)`. Python's standard + "default" filter dedupes identical messages, so a chatty server only warns once. +* `X-Logfire-Error`: an out-of-band error the server wants the SDK to raise. + Always raised as `LogfireServerError`. Callers that want to keep working past + it (the OTLP pipeline, the variables provider) already swallow exceptions from + their HTTP calls; CRUD/CLI propagate the error to the user. + +`install_logfire_response_hook(session)` wires this into a `requests.Session` as +a response hook so every Logfire-bound HTTP response is inspected. Callers can +pass a custom `hook` to replace the default behaviour (see +`AdvancedOptions.transport_response_hook`). +""" + +from __future__ import annotations + +import warnings +from typing import Any, Callable + +import requests + +from logfire.exceptions import LogfireServerError, LogfireServerWarning + +WARNING_HEADER_NAME = 'X-Logfire-Warning' +ERROR_HEADER_NAME = 'X-Logfire-Error' + +TransportResponseHook = Callable[[requests.Response], object] +"""Callable invoked for every Logfire API response received by the SDK. + +The return value is ignored; raise to abort the call. +""" + + +def process_logfire_response_headers(response: requests.Response) -> None: + """Default transport response hook: surface `X-Logfire-Warning` / `X-Logfire-Error` headers.""" + warning_message = response.headers.get(WARNING_HEADER_NAME) + if warning_message: + warnings.warn(warning_message, LogfireServerWarning, stacklevel=2) + error_message = response.headers.get(ERROR_HEADER_NAME) + if error_message: + raise LogfireServerError(error_message) + + +def install_logfire_response_hook( + session: requests.Session, + hook: TransportResponseHook | None = None, +) -> None: + """Install a `requests` response hook on `session` for every Logfire API response. + + `hook` defaults to `process_logfire_response_headers`. Pass a custom callable + to replace the default behaviour (e.g. opt out by passing `lambda response: None`). + """ + user_hook = hook if hook is not None else process_logfire_response_headers + + def _hook(response: requests.Response, *_args: Any, **_kwargs: Any) -> requests.Response: + user_hook(response) + return response + + response_hooks: list[Any] = session.hooks.setdefault('response', []) + response_hooks.append(_hook) diff --git a/logfire/exceptions.py b/logfire/exceptions.py index 617fba04f..532fb2720 100644 --- a/logfire/exceptions.py +++ b/logfire/exceptions.py @@ -3,3 +3,11 @@ class LogfireConfigError(ValueError): """Error raised when there is a problem with the Logfire configuration.""" + + +class LogfireServerError(Exception): + """Error raised when the Logfire server returns an `X-Logfire-Error` header on a response.""" + + +class LogfireServerWarning(UserWarning): + """Warning emitted when the Logfire server returns an `X-Logfire-Warning` header on a response.""" diff --git a/logfire/variables/remote.py b/logfire/variables/remote.py index b780a658f..91034a96c 100644 --- a/logfire/variables/remote.py +++ b/logfire/variables/remote.py @@ -17,6 +17,7 @@ from logfire._internal.client import UA_HEADER from logfire._internal.config import VariablesOptions +from logfire._internal.server_response import TransportResponseHook, install_logfire_response_hook from logfire._internal.utils import UnexpectedResponse from logfire.variables.abstract import ( ResolvedVariable, @@ -54,21 +55,31 @@ class LogfireRemoteVariableProvider(VariableProvider): The threading implementation draws heavily from opentelemetry.sdk._shared_internal.BatchProcessor. """ - def __init__(self, base_url: str, token: str, options: VariablesOptions): + def __init__( + self, + base_url: str, + token: str, + options: VariablesOptions, + transport_response_hook: TransportResponseHook | None = None, + ): """Create a new remote variable provider. Args: base_url: The base URL of the Logfire API. token: Authentication token for the Logfire API. options: Options for retrieving remote variables. + transport_response_hook: Optional override for the API response hook + (see `AdvancedOptions.transport_response_hook`). """ block_before_first_resolve = options.block_before_first_resolve polling_interval = options.polling_interval self._base_url = base_url self._token = token + self._transport_response_hook = transport_response_hook self._session = Session() self._session.headers.update({'Authorization': f'bearer {token}', 'User-Agent': UA_HEADER}) + install_logfire_response_hook(self._session, transport_response_hook) self._timeout = options.timeout self._block_before_first_fetch = block_before_first_resolve self._polling_interval: timedelta = ( @@ -197,6 +208,7 @@ def _sse_listener(self): # pragma: no cover 'Cache-Control': 'no-cache', } ) + install_logfire_response_hook(sse_session, self._transport_response_hook) # Open streaming connection response = sse_session.get(sse_url, stream=True, timeout=(10, None)) diff --git a/tests/test_server_response.py b/tests/test_server_response.py new file mode 100644 index 000000000..6af1ddeee --- /dev/null +++ b/tests/test_server_response.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import warnings + +import pytest +import requests +import requests_mock +from inline_snapshot import snapshot + +from logfire._internal.server_response import ( + ERROR_HEADER_NAME, + WARNING_HEADER_NAME, + process_logfire_response_headers, +) +from logfire.exceptions import LogfireServerError, LogfireServerWarning + + +def test_process_response_warning_header_emits_warning(): + response = requests.Response() + response.headers[WARNING_HEADER_NAME] = 'The /foo/bar endpoint is deprecated, please use /bar/baz' + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + process_logfire_response_headers(response) + assert [(w.category, str(w.message)) for w in caught] == snapshot( + [(LogfireServerWarning, 'The /foo/bar endpoint is deprecated, please use /bar/baz')] + ) + + +def test_process_response_warning_header_dedupes(): + """Python's default `warnings` filter should fold repeats of the same message into one entry.""" + response = requests.Response() + response.headers[WARNING_HEADER_NAME] = 'a duplicated warning' + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('default') + for _ in range(5): + process_logfire_response_headers(response) + messages = [str(w.message) for w in caught] + assert messages == ['a duplicated warning'] + + +def test_process_response_error_header_raises(): + response = requests.Response() + response.headers[ERROR_HEADER_NAME] = 'something is wrong' + with pytest.raises(LogfireServerError, match='something is wrong'): + process_logfire_response_headers(response) + + +def test_response_hook_installed_on_logfire_client(): + from logfire._internal.auth import UserToken + from logfire._internal.client import LogfireClient + + token = UserToken( + token='pylf_v1_us_xxx', + base_url='https://logfire-us.pydantic.dev', + expiration='2099-12-31T23:59:59', + ) + client = LogfireClient(user_token=token) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + headers={WARNING_HEADER_NAME: 'deprecated endpoint'}, + ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + client.get_user_information() + + assert any(isinstance(w.message, LogfireServerWarning) for w in caught) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + headers={ERROR_HEADER_NAME: 'no longer supported'}, + ) + with pytest.raises(LogfireServerError, match='no longer supported'): + client.get_user_information() + + +def test_custom_transport_response_hook_replaces_default(): + """A custom hook replaces the built-in header processor entirely.""" + from logfire._internal.auth import UserToken + from logfire._internal.client import LogfireClient + + seen: list[requests.Response] = [] + + def my_hook(response: requests.Response) -> None: + seen.append(response) + + token = UserToken( + token='pylf_v1_us_xxx', + base_url='https://logfire-us.pydantic.dev', + expiration='2099-12-31T23:59:59', + ) + client = LogfireClient(user_token=token, transport_response_hook=my_hook) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + # Both headers set: default would warn AND raise; custom hook ignores them. + headers={WARNING_HEADER_NAME: 'deprecated', ERROR_HEADER_NAME: 'broken'}, + ) + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter('always') + client.get_user_information() + + assert len(seen) == 1 + assert not any(isinstance(w.message, LogfireServerWarning) for w in caught) + + +def test_transport_response_hook_can_opt_out(): + """`lambda response: None` disables both warnings and errors.""" + from logfire._internal.auth import UserToken + from logfire._internal.client import LogfireClient + + token = UserToken( + token='pylf_v1_us_xxx', + base_url='https://logfire-us.pydantic.dev', + expiration='2099-12-31T23:59:59', + ) + client = LogfireClient(user_token=token, transport_response_hook=lambda response: None) + + with requests_mock.Mocker() as m: + m.get( + 'https://logfire-us.pydantic.dev/v1/account/me', + json={'name': 'me'}, + headers={ERROR_HEADER_NAME: 'no longer supported'}, + ) + # No exception raised. + assert client.get_user_information() == {'name': 'me'}