Skip to content

Commit 25ae7f9

Browse files
authored
Merge branch 'main' into fix/handle-circular-refs
2 parents 9aa575d + a39ca94 commit 25ae7f9

File tree

2 files changed

+696
-0
lines changed

2 files changed

+696
-0
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""SQLite-backed OpenTelemetry span exporter for local development."""
16+
17+
from __future__ import annotations
18+
19+
import json
20+
import logging
21+
import sqlite3
22+
import threading
23+
from typing import Any
24+
from typing import Iterable
25+
from typing import Optional
26+
from typing import Sequence
27+
28+
from opentelemetry.sdk.trace import ReadableSpan
29+
from opentelemetry.sdk.trace.export import SpanExporter
30+
from opentelemetry.sdk.trace.export import SpanExportResult
31+
from opentelemetry.trace import SpanContext
32+
from opentelemetry.trace import TraceFlags
33+
from opentelemetry.trace import TraceState
34+
35+
logger = logging.getLogger("google_adk." + __name__)
36+
37+
_CREATE_SPANS_TABLE = """
38+
CREATE TABLE IF NOT EXISTS spans (
39+
span_id TEXT PRIMARY KEY,
40+
trace_id TEXT NOT NULL,
41+
parent_span_id TEXT,
42+
name TEXT NOT NULL,
43+
start_time_unix_nano INTEGER,
44+
end_time_unix_nano INTEGER,
45+
session_id TEXT,
46+
invocation_id TEXT,
47+
attributes_json TEXT
48+
);
49+
"""
50+
51+
_CREATE_SESSION_INDEX = """
52+
CREATE INDEX IF NOT EXISTS spans_session_id_idx ON spans(session_id);
53+
"""
54+
55+
_CREATE_TRACE_INDEX = """
56+
CREATE INDEX IF NOT EXISTS spans_trace_id_idx ON spans(trace_id);
57+
"""
58+
59+
_INSERT_SPAN = """
60+
INSERT OR REPLACE INTO spans (
61+
span_id,
62+
trace_id,
63+
parent_span_id,
64+
name,
65+
start_time_unix_nano,
66+
end_time_unix_nano,
67+
session_id,
68+
invocation_id,
69+
attributes_json
70+
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?);
71+
"""
72+
73+
_DEFAULT_TIMEOUT_SECONDS = 30.0
74+
75+
76+
class SqliteSpanExporter(SpanExporter):
77+
"""Exports spans to a local SQLite database.
78+
79+
This is intended for local development (e.g. `adk web`) to allow reloading
80+
traces for older sessions after process restart.
81+
"""
82+
83+
def __init__(self, *, db_path: str):
84+
self._db_path = db_path
85+
self._lock = threading.Lock()
86+
self._conn: Optional[sqlite3.Connection] = None
87+
self._ensure_schema()
88+
89+
def _get_connection(self) -> sqlite3.Connection:
90+
if self._conn is None:
91+
self._conn = sqlite3.connect(
92+
self._db_path,
93+
timeout=_DEFAULT_TIMEOUT_SECONDS,
94+
check_same_thread=False,
95+
)
96+
self._conn.row_factory = sqlite3.Row
97+
return self._conn
98+
99+
def _ensure_schema(self) -> None:
100+
with self._lock:
101+
conn = self._get_connection()
102+
conn.execute(_CREATE_SPANS_TABLE)
103+
conn.execute(_CREATE_SESSION_INDEX)
104+
conn.execute(_CREATE_TRACE_INDEX)
105+
conn.commit()
106+
107+
def _serialize_attributes(self, attributes: dict[str, Any]) -> str:
108+
try:
109+
return json.dumps(
110+
attributes,
111+
ensure_ascii=False,
112+
default=lambda o: "<not serializable>",
113+
)
114+
except (TypeError, ValueError) as e:
115+
logger.debug("Failed to serialize span attributes: %r", e)
116+
return "{}"
117+
118+
def _deserialize_attributes(self, attributes_json: Any) -> dict[str, Any]:
119+
if not attributes_json:
120+
return {}
121+
try:
122+
attributes = json.loads(attributes_json)
123+
except (json.JSONDecodeError, TypeError) as e:
124+
logger.debug("Failed to deserialize span attributes: %r", e)
125+
return {}
126+
return attributes if isinstance(attributes, dict) else {}
127+
128+
def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
129+
try:
130+
with self._lock:
131+
conn = self._get_connection()
132+
rows: list[tuple[Any, ...]] = []
133+
for span in spans:
134+
attributes = dict(span.attributes) if span.attributes else {}
135+
session_id = attributes.get(
136+
"gcp.vertex.agent.session_id"
137+
) or attributes.get("gen_ai.conversation.id")
138+
invocation_id = attributes.get("gcp.vertex.agent.invocation_id")
139+
140+
parent_span_id = None
141+
if span.parent is not None:
142+
parent_span_id = format(span.parent.span_id, "016x")
143+
144+
rows.append((
145+
format(span.context.span_id, "016x"),
146+
format(span.context.trace_id, "032x"),
147+
parent_span_id,
148+
span.name,
149+
span.start_time,
150+
span.end_time,
151+
session_id,
152+
invocation_id,
153+
self._serialize_attributes(attributes),
154+
))
155+
conn.executemany(_INSERT_SPAN, rows)
156+
conn.commit()
157+
return SpanExportResult.SUCCESS
158+
except Exception as e: # pylint: disable=broad-exception-caught
159+
logger.warning("Failed to export spans to SQLite: %s", e)
160+
return SpanExportResult.FAILURE
161+
162+
def shutdown(self) -> None:
163+
with self._lock:
164+
if self._conn is not None:
165+
self._conn.close()
166+
self._conn = None
167+
168+
def force_flush(self, timeout_millis: int = 30000) -> bool:
169+
return True
170+
171+
def _query(self, sql: str, params: Iterable[Any]) -> list[sqlite3.Row]:
172+
with self._lock:
173+
conn = self._get_connection()
174+
cur = conn.execute(sql, tuple(params))
175+
return list(cur.fetchall())
176+
177+
def _row_to_readable_span(self, row: sqlite3.Row) -> ReadableSpan:
178+
trace_id_hex = row["trace_id"]
179+
span_id_hex = row["span_id"]
180+
trace_id = int(str(trace_id_hex), 16)
181+
span_id = int(str(span_id_hex), 16)
182+
trace_state = TraceState()
183+
trace_flags = TraceFlags(TraceFlags.SAMPLED)
184+
context = SpanContext(
185+
trace_id=trace_id,
186+
span_id=span_id,
187+
is_remote=False,
188+
trace_flags=trace_flags,
189+
trace_state=trace_state,
190+
)
191+
192+
parent: SpanContext | None = None
193+
parent_span_id_hex = row["parent_span_id"]
194+
if parent_span_id_hex:
195+
parent = SpanContext(
196+
trace_id=trace_id,
197+
span_id=int(str(parent_span_id_hex), 16),
198+
is_remote=False,
199+
trace_flags=trace_flags,
200+
trace_state=trace_state,
201+
)
202+
203+
attributes = self._deserialize_attributes(row["attributes_json"])
204+
return ReadableSpan(
205+
name=row["name"] or "",
206+
context=context,
207+
parent=parent,
208+
attributes=attributes,
209+
start_time=row["start_time_unix_nano"],
210+
end_time=row["end_time_unix_nano"],
211+
)
212+
213+
def get_all_spans_for_session(self, session_id: str) -> list[ReadableSpan]:
214+
"""Returns all spans for a session (full trace trees).
215+
216+
We first find trace_ids associated with the session, then return all spans
217+
for those trace_ids. This works even if some spans are missing session_id
218+
attributes (e.g. parent spans).
219+
"""
220+
trace_rows = self._query(
221+
"SELECT DISTINCT trace_id FROM spans WHERE session_id = ?",
222+
(session_id,),
223+
)
224+
trace_ids = [r["trace_id"] for r in trace_rows if r["trace_id"]]
225+
if not trace_ids:
226+
return []
227+
228+
placeholders = ",".join("?" for _ in trace_ids)
229+
rows = self._query(
230+
f"SELECT * FROM spans WHERE trace_id IN ({placeholders}) "
231+
"ORDER BY start_time_unix_nano",
232+
trace_ids,
233+
)
234+
return [self._row_to_readable_span(row) for row in rows]

0 commit comments

Comments
 (0)