Skip to content

Commit bb89466

Browse files
GWealecopybara-github
authored andcommitted
chore: Improve type hints and handle None values in ADK utils
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 866025998
1 parent adbc37f commit bb89466

File tree

8 files changed

+48
-36
lines changed

8 files changed

+48
-36
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,6 @@ asyncio_mode = "auto"
219219
python_version = "3.10"
220220
exclude = ["tests/", "contributing/samples/"]
221221
plugins = ["pydantic.mypy"]
222-
# Start with non-strict mode, and swtich to strict mode later.
223222
strict = true
224223
disable_error_code = ["import-not-found", "import-untyped", "unused-ignore"]
225224
follow_imports = "skip"

src/google/adk/a2a/converters/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,9 @@ def _to_a2a_context_id(app_name: str, user_id: str, session_id: str) -> str:
5959
)
6060

6161

62-
def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
62+
def _from_a2a_context_id(
63+
context_id: str | None,
64+
) -> tuple[str, str, str] | tuple[None, None, None]:
6365
"""Converts an A2A context id to app name, user id and session id.
6466
if context_id is None, return None, None, None
6567
if context_id is not None, but not in the format of
@@ -69,7 +71,7 @@ def _from_a2a_context_id(context_id: str) -> tuple[str, str, str]:
6971
context_id: The A2A context id.
7072
7173
Returns:
72-
The app name, user id and session id.
74+
The app name, user id and session id, or (None, None, None) if invalid.
7375
"""
7476
if not context_id:
7577
return None, None, None

src/google/adk/models/cache_metadata.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,12 @@ def __str__(self) -> str:
113113
f"fingerprint={self.fingerprint[:8]}..."
114114
)
115115
cache_id = self.cache_name.split("/")[-1]
116+
if self.expire_time is None:
117+
return (
118+
f"Cache {cache_id}: used {self.invocations_used} invocations, "
119+
f"cached {self.contents_count} contents, "
120+
"expires unknown"
121+
)
116122
time_until_expiry_minutes = (self.expire_time - time.time()) / 60
117123
return (
118124
f"Cache {cache_id}: used {self.invocations_used} invocations, "

src/google/adk/utils/_client_labels_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from collections.abc import Iterator
1718
from contextlib import contextmanager
1819
import contextvars
1920
import os
@@ -32,7 +33,7 @@
3233
"""Label used to denote calls emerging to external system as a part of Evals."""
3334

3435
# The ContextVar holds client label collected for the current request.
35-
_LABEL_CONTEXT: contextvars.ContextVar[str] = contextvars.ContextVar(
36+
_LABEL_CONTEXT: contextvars.ContextVar[str | None] = contextvars.ContextVar(
3637
"_LABEL_CONTEXT", default=None
3738
)
3839

@@ -49,7 +50,7 @@ def _get_default_labels() -> List[str]:
4950

5051

5152
@contextmanager
52-
def client_label_context(client_label: str):
53+
def client_label_context(client_label: str) -> Iterator[None]:
5354
"""Runs the operation within the context of the given client label."""
5455
current_client_label = _LABEL_CONTEXT.get()
5556

src/google/adk/utils/content_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@
1919

2020
def is_audio_part(part: types.Part) -> bool:
2121
return (
22-
part.inline_data
23-
and part.inline_data.mime_type
22+
part.inline_data is not None
23+
and part.inline_data.mime_type is not None
2424
and part.inline_data.mime_type.startswith('audio/')
2525
) or (
26-
part.file_data
27-
and part.file_data.mime_type
26+
part.file_data is not None
27+
and part.file_data.mime_type is not None
2828
and part.file_data.mime_type.startswith('audio/')
2929
)
3030

src/google/adk/utils/feature_decorator.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,16 @@
1414

1515
from __future__ import annotations
1616

17+
from collections.abc import Callable
1718
import functools
1819
import os
19-
from typing import Callable
20+
from typing import Any
2021
from typing import cast
2122
from typing import Optional
2223
from typing import TypeVar
23-
from typing import Union
2424
import warnings
2525

26-
T = TypeVar("T", bound=Union[Callable, type])
26+
T = TypeVar("T")
2727

2828

2929
def _is_truthy_env(var_name: str) -> bool:
@@ -39,8 +39,8 @@ def _make_feature_decorator(
3939
default_message: str,
4040
block_usage: bool = False,
4141
bypass_env_var: Optional[str] = None,
42-
) -> Callable:
43-
def decorator_factory(message_or_obj=None):
42+
) -> Callable[..., Any]:
43+
def decorator_factory(message_or_obj: Any = None) -> Any:
4444
# Case 1: Used as @decorator without parentheses
4545
# message_or_obj is the decorated class/function
4646
if message_or_obj is not None and (
@@ -68,10 +68,11 @@ def decorator(obj: T) -> T:
6868
msg = f"[{label.upper()}] {obj_name}: {message}"
6969

7070
if isinstance(obj, type): # decorating a class
71-
orig_init = obj.__init__
71+
cls = cast(type[Any], obj)
72+
orig_init = cast(Any, cls).__init__
7273

7374
@functools.wraps(orig_init)
74-
def new_init(self, *args, **kwargs):
75+
def new_init(self: Any, *args: Any, **kwargs: Any) -> Any:
7576
# Check if usage should be bypassed via environment variable at call time
7677
should_bypass = bypass_env_var is not None and _is_truthy_env(
7778
bypass_env_var
@@ -86,13 +87,14 @@ def new_init(self, *args, **kwargs):
8687
warnings.warn(msg, category=UserWarning, stacklevel=2)
8788
return orig_init(self, *args, **kwargs)
8889

89-
obj.__init__ = new_init # type: ignore[attr-defined]
90-
return cast(T, obj)
90+
cast(Any, cls).__init__ = new_init
91+
return cast(T, cls)
9192

9293
elif callable(obj): # decorating a function or method
94+
func = cast(Callable[..., Any], obj)
9395

94-
@functools.wraps(obj)
95-
def wrapper(*args, **kwargs):
96+
@functools.wraps(func)
97+
def wrapper(*args: Any, **kwargs: Any) -> Any:
9698
# Check if usage should be bypassed via environment variable at call time
9799
should_bypass = bypass_env_var is not None and _is_truthy_env(
98100
bypass_env_var
@@ -105,7 +107,7 @@ def wrapper(*args, **kwargs):
105107
raise RuntimeError(msg)
106108
else:
107109
warnings.warn(msg, category=UserWarning, stacklevel=2)
108-
return obj(*args, **kwargs)
110+
return func(*args, **kwargs)
109111

110112
return cast(T, wrapper)
111113

src/google/adk/utils/output_schema_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from .variant_utils import GoogleLLMVariant
2929

3030

31-
def can_use_output_schema_with_tools(model: Union[str, BaseLlm]):
31+
def can_use_output_schema_with_tools(model: Union[str, BaseLlm]) -> bool:
3232
"""Returns True if output schema with tools is supported."""
3333
model_string = model if isinstance(model, str) else model.model
3434

src/google/adk/utils/streaming_utils.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class StreamingResponseAggregator:
3232
individual (partial) model responses, as well as for aggregated content.
3333
"""
3434

35-
def __init__(self):
35+
def __init__(self) -> None:
3636
self._text = ''
3737
self._thought_text = ''
3838
self._usage_metadata = None
@@ -48,9 +48,9 @@ def __init__(self):
4848
self._current_fc_name: Optional[str] = None
4949
self._current_fc_args: dict[str, Any] = {}
5050
self._current_fc_id: Optional[str] = None
51-
self._current_thought_signature: Optional[str] = None
51+
self._current_thought_signature: Optional[bytes] = None
5252

53-
def _flush_text_buffer_to_sequence(self):
53+
def _flush_text_buffer_to_sequence(self) -> None:
5454
"""Flush current text buffer to parts sequence.
5555
5656
This helper is used in progressive SSE mode to maintain part ordering.
@@ -70,7 +70,7 @@ def _flush_text_buffer_to_sequence(self):
7070

7171
def _get_value_from_partial_arg(
7272
self, partial_arg: types.PartialArg, json_path: str
73-
):
73+
) -> tuple[Any, bool]:
7474
"""Extract value from a partial argument.
7575
7676
Args:
@@ -80,7 +80,7 @@ def _get_value_from_partial_arg(
8080
Returns:
8181
Tuple of (value, has_value) where has_value indicates if a value exists
8282
"""
83-
value = None
83+
value: Any = None
8484
has_value = False
8585

8686
if partial_arg.string_value is not None:
@@ -95,12 +95,11 @@ def _get_value_from_partial_arg(
9595
path_parts = path_without_prefix.split('.')
9696

9797
# Try to get existing value
98-
existing_value = self._current_fc_args
98+
existing_value: Any = self._current_fc_args
9999
for part in path_parts:
100100
if isinstance(existing_value, dict) and part in existing_value:
101101
existing_value = existing_value[part]
102102
else:
103-
existing_value = None
104103
break
105104

106105
# Append to existing string or set new value
@@ -121,7 +120,7 @@ def _get_value_from_partial_arg(
121120

122121
return value, has_value
123122

124-
def _set_value_by_json_path(self, json_path: str, value: Any):
123+
def _set_value_by_json_path(self, json_path: str, value: Any) -> None:
125124
"""Set a value in _current_fc_args using JSONPath notation.
126125
127126
Args:
@@ -147,7 +146,7 @@ def _set_value_by_json_path(self, json_path: str, value: Any):
147146
# Set the final value
148147
current[path_parts[-1]] = value
149148

150-
def _flush_function_call_to_sequence(self):
149+
def _flush_function_call_to_sequence(self) -> None:
151150
"""Flush current function call to parts sequence.
152151
153152
This creates a complete FunctionCall part from accumulated partial args.
@@ -175,7 +174,7 @@ def _flush_function_call_to_sequence(self):
175174
self._current_fc_id = None
176175
self._current_thought_signature = None
177176

178-
def _process_streaming_function_call(self, fc: types.FunctionCall):
177+
def _process_streaming_function_call(self, fc: types.FunctionCall) -> None:
179178
"""Process a streaming function call with partialArgs.
180179
181180
Args:
@@ -208,14 +207,14 @@ def _process_streaming_function_call(self, fc: types.FunctionCall):
208207
self._flush_text_buffer_to_sequence()
209208
self._flush_function_call_to_sequence()
210209

211-
def _process_function_call_part(self, part: types.Part):
210+
def _process_function_call_part(self, part: types.Part) -> None:
212211
"""Process a function call part (streaming or non-streaming).
213212
214213
Args:
215214
part: The part containing a function call
216215
"""
217216
fc = part.function_call
218-
if not fc:
217+
if fc is None:
219218
return
220219

221220
# Check if this is a streaming FC (has partialArgs or will_continue=True)
@@ -298,10 +297,11 @@ async def process_response(
298297
and llm_response.content.parts[0].text
299298
):
300299
part0 = llm_response.content.parts[0]
300+
part_text = part0.text or ''
301301
if part0.thought:
302-
self._thought_text += part0.text
302+
self._thought_text += part_text
303303
else:
304-
self._text += part0.text
304+
self._text += part_text
305305
llm_response.partial = True
306306
elif (self._thought_text or self._text) and (
307307
not llm_response.content
@@ -382,3 +382,5 @@ def close(self) -> Optional[LlmResponse]:
382382
else candidate.finish_message,
383383
usage_metadata=self._usage_metadata,
384384
)
385+
386+
return None

0 commit comments

Comments
 (0)