Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
264 changes: 258 additions & 6 deletions logfire/experimental/query_client.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from __future__ import annotations

import json
import platform
from datetime import datetime
from collections.abc import Generator
from contextlib import contextmanager
from datetime import datetime, timezone
from types import TracebackType
from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, TypeVar
from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, TypeVar, overload

from typing_extensions import Self
from typing_extensions import NotRequired, Self, TypeAlias, deprecated

from logfire import VERSION
from logfire._internal.config import get_base_url_from_token
Expand Down Expand Up @@ -74,6 +77,62 @@ class RowQueryResults(TypedDict):
rows: list[dict[str, Any]]


class RowQueryResultsV2(TypedDict):
    """The row-oriented results of a JSON-format query."""

    # Per-column metadata, mapped to the legacy `ColumnDetails` shape by `_map_v2_result`.
    columns: list[ColumnDetails]
    # One dict per result row.
    rows: list[dict[str, Any]]


class RowQueryResultsV2Explained(RowQueryResultsV2):
    """Row-oriented query results extended with the query plans returned when ``explain=True``."""

    logical_plan: Any
    physical_plan: Any
    physical_plan_with_metrics: Any


class StreamSchemaMessage(TypedDict):
    """First line of the NDJSON stream (omitted when ``include_schema=False``)."""

    type: Literal['schema']
    # Schema of the result set; presumably contains a 'fields' list like the JSON
    # result schema consumed by `_map_v2_result` — confirm against the backend.
    schema: dict[str, Any]


class StreamExplainMessage(TypedDict):
    """Emitted when ``explain=True``, after ``schema`` and before any ``data``."""

    type: Literal['explain']
    # Each plan may be independently absent (NotRequired).
    logical_plan: NotRequired[Any]
    physical_plan: NotRequired[Any]


class StreamDataMessage(TypedDict):
    """One per Arrow record batch; repeats."""

    type: Literal['data']
    # Rows decoded from this batch, one dict per row.
    rows: list[dict[str, Any]]


class StreamErrorMessage(TypedDict):
    """Emitted when a record batch fails. Always followed by an ``end`` message."""

    type: Literal['error']
    # Human-readable description of the failure.
    message: str


class StreamEndMessage(TypedDict):
    """Final line of the stream. ``error`` is set if the stream failed mid-flight."""

    type: Literal['end']
    # Presumably the total rows delivered across all `data` messages — confirm.
    row_count: int
    # Presumably only present when `explain=True` was requested — confirm.
    physical_plan_with_metrics: NotRequired[Any]
    # Only present when the stream failed partway through.
    error: NotRequired[str]


# Any single parsed line of the `/v2/query` NDJSON stream.
StreamMessage: TypeAlias = (
    StreamSchemaMessage | StreamExplainMessage | StreamDataMessage | StreamErrorMessage | StreamEndMessage
)


def _rows_to_columns(result: RowQueryResults) -> QueryResults:
"""Convert a row-oriented JSON query result to a column-oriented one."""
columns_by_name: dict[str, ColumnData] = {col['name']: {**col, 'values': []} for col in result['columns']}
Expand All @@ -83,11 +142,47 @@ def _rows_to_columns(result: RowQueryResults) -> QueryResults:
return {'columns': list(columns_by_name.values())}


_FF_DATA_TYPE_KEYS_TO_REMOVE = {'dict_id', 'dict_is_ordered', 'metadata'}


def _transform_fields_for_backwards_compatibility(obj: Any) -> Any:
"""Recursively removes all occurrences of _FF_DATA_TYPE_KEYS_TO_REMOVE as keys from arbitrary nesting within `obj`."""
if isinstance(obj, dict):
new_obj: dict[str, Any] = {}
for k, v in obj.items(): # type: ignore
if k in _FF_DATA_TYPE_KEYS_TO_REMOVE:
continue
if k == 'data_type':
k = 'datatype'
new_obj[k] = _transform_fields_for_backwards_compatibility(v)
return new_obj

elif isinstance(obj, list):
return [_transform_fields_for_backwards_compatibility(item) for item in obj] # type: ignore

else:
return obj


def _map_v2_result(obj: dict[str, Any]) -> RowQueryResultsV2 | RowQueryResultsV2Explained:
    """Map a raw `/v2/query` JSON payload onto the legacy row-oriented result shape.

    The schema fields are run through `_transform_fields_for_backwards_compatibility`;
    the three plan keys are copied over only when the payload was produced with
    ``explain=True`` (signalled by the presence of ``logical_plan``).
    """
    result: RowQueryResultsV2 | RowQueryResultsV2Explained = {
        'columns': _transform_fields_for_backwards_compatibility(obj['schema']['fields']),
        'rows': obj['data'],
    }
    if 'logical_plan' in obj:
        # When one plan key is present, all three are guaranteed to be present.
        result['logical_plan'] = obj['logical_plan']
        result['physical_plan'] = obj['physical_plan']
        result['physical_plan_with_metrics'] = obj['physical_plan_with_metrics']

    return result


# Client type used by the generic base class below (bound to BaseClient —
# presumably the shared base of the sync/async HTTP clients; confirm import).
T = TypeVar('T', bound=BaseClient)


# Accept-header values the query endpoints can be asked for.
_ACCEPT = Literal['application/json', 'application/vnd.apache.arrow.stream', 'text/csv']
# SDK/runtime identification string (SDK version, Python version, OS, arch).
_USER_AGENT = f'logfire-sdk-python/{VERSION} (Python {platform.python_version()}, os {platform.platform()}, arch {platform.machine()})'
# Fallback lower time bound applied when a query is issued without `min_timestamp`.
_MIN_DATETIME = datetime(2020, 1, 1, tzinfo=timezone.utc)


class _BaseLogfireQueryClient(Generic[T]):
Expand Down Expand Up @@ -119,6 +214,35 @@ def _build_query_params(
params['max_timestamp'] = max_timestamp.isoformat()
return params

def _build_v2_body(
self,
sql: str,
min_timestamp: datetime | None,
max_timestamp: datetime | None,
limit: int | None,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: bool = False,
include_schema: bool = True,
) -> dict[str, Any]:
body: dict[str, Any] = {'sql': sql, 'explain': explain, 'include_schema': include_schema}

if limit is not None:
body['limit'] = limit
body['min_timestamp'] = (min_timestamp or _MIN_DATETIME).isoformat()
if max_timestamp is not None:
body['max_timestamp'] = max_timestamp.isoformat()
if params is not None:
body['params'] = params
if timezone is not None:
body['timezone'] = timezone
if isinstance(deployment_environment, str):
deployment_environment = [deployment_environment]
if deployment_environment is not None:
body['deployment_environment'] = deployment_environment
return body

def handle_response_errors(self, response: Response) -> None:
if response.status_code == 400: # pragma: no cover
raise QueryExecutionError(response.json())
Expand Down Expand Up @@ -168,6 +292,7 @@ def info(self) -> ReadTokenInfo:
'The read token info response is missing required fields: organization_name or project_name'
)

@deprecated('query_json() is deprecated, use query_json_rows() instead', stacklevel=2)
def query_json(
self,
sql: str,
Expand All @@ -184,22 +309,90 @@ def query_json(
)
return _rows_to_columns(row_results)

# Note: on the next major version, move the keyword-only marker after `sql`:
@overload
@deprecated('Using query_json_rows() without a min_timestamp is deprecated')
def query_json_rows(
self,
sql: str,
min_timestamp: None = None,
max_timestamp: datetime | None = None,
limit: int | None = None,
*,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: Literal[True],
) -> RowQueryResultsV2Explained: ...

@overload
def query_json_rows(
self,
sql: str,
min_timestamp: datetime,
max_timestamp: datetime | None = None,
limit: int | None = None,
*,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: Literal[True],
) -> RowQueryResultsV2Explained: ...

@overload
@deprecated('Using query_json_rows() without a min_timestamp is deprecated')
def query_json_rows(
self,
sql: str,
min_timestamp: None = None,
max_timestamp: datetime | None = None,
limit: int | None = None,
*,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: Literal[False] = ...,
) -> RowQueryResultsV2: ...

@overload
def query_json_rows(
self,
sql: str,
min_timestamp: datetime,
max_timestamp: datetime | None = None,
limit: int | None = None,
*,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: Literal[False] = ...,
) -> RowQueryResultsV2: ...

def query_json_rows(
self,
sql: str,
min_timestamp: datetime | None = None,
max_timestamp: datetime | None = None,
limit: int | None = None,
) -> RowQueryResults:
*,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: bool = False,
) -> RowQueryResultsV2 | RowQueryResultsV2Explained:
"""Query Logfire data and return the results as a row-oriented dictionary."""
response = self._query(
response = self._query_v2(
accept='application/json',
sql=sql,
min_timestamp=min_timestamp,
max_timestamp=max_timestamp,
limit=limit,
params=params,
timezone=timezone,
deployment_environment=deployment_environment,
explain=explain,
)
return response.json()
return _map_v2_result(response.json())

def query_arrow( # pyright: ignore[reportUnknownParameterType]
self,
Expand Down Expand Up @@ -250,6 +443,37 @@ def query_csv(
)
return response.text

@contextmanager
def query_stream(
self,
sql: str,
min_timestamp: datetime | None = None,
max_timestamp: datetime | None = None,
limit: int | None = None,
*,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: bool = False,
include_schema: bool = True,
) -> Generator[Generator[StreamMessage]]:
body = self._build_v2_body(
sql=sql,
min_timestamp=min_timestamp,
max_timestamp=max_timestamp,
limit=limit,
params=params,
timezone=timezone,
deployment_environment=deployment_environment,
explain=explain,
include_schema=include_schema,
)
with self.client.stream('POST', '/v2/query', headers={'accept': 'application/x-ndjson'}, json=body) as response:
if response.status_code != 200:
response.read()
self.handle_response_errors(response)
yield from (json.loads(line) for line in response.iter_lines() if line)

def _query(
self,
accept: _ACCEPT,
Expand All @@ -265,6 +489,34 @@ def _query(
self.handle_response_errors(response)
return response

def _query_v2(
self,
*,
accept: _ACCEPT,
sql: str,
min_timestamp: datetime | None = None,
max_timestamp: datetime | None = None,
limit: int | None = None,
params: dict[str, str] | None = None,
timezone: str | None = None,
deployment_environment: str | list[str] | None = None,
explain: bool = False,
) -> Response:

body = self._build_v2_body(
sql=sql,
min_timestamp=min_timestamp,
max_timestamp=max_timestamp,
limit=limit,
params=params,
timezone=timezone,
deployment_environment=deployment_environment,
explain=explain,
)
response = self.client.post('/v2/query', headers={'accept': accept}, json=body)
self.handle_response_errors(response)
return response


class AsyncLogfireQueryClient(_BaseLogfireQueryClient[AsyncClient]):
"""An asynchronous client for querying Logfire data."""
Expand Down
Loading