Skip to content

Commit 05d8f6f

Browse files
committed
fix: improve type hint for memory_service parameter
1 parent 0b2be2f commit 05d8f6f

2 files changed

Lines changed: 74 additions & 1 deletion

File tree

src/google/adk/cli/fast_api.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
from ..auth.credential_service.in_memory_credential_service import InMemoryCredentialService
3939
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
4040
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
41+
from ..memory.base_memory_service import BaseMemoryService
42+
from ..memory.in_memory_memory_service import InMemoryMemoryService
43+
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
4144
from ..runners import Runner
4245
from .adk_web_server import AdkWebServer
4346
from .service_registry import load_services_module
@@ -79,7 +82,7 @@ def get_fast_api_app(
7982
artifact_service_uri: Optional[str] = None,
8083
memory_service_uri: Optional[str] = None,
8184
use_local_storage: bool = True,
82-
memory_service: Optional[Any] = None,
85+
memory_service: Optional[BaseMemoryService] = None,
8386
eval_storage_uri: Optional[str] = None,
8487
allow_origins: Optional[list[str]] = None,
8588
web: bool,

tests/unittests/cli/test_fast_api.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,5 +1717,75 @@ async def run_async_session_not_found(self, **kwargs):
17171717
assert "Session not found" in response.json()["detail"]
17181718

17191719

1720+
def test_get_fast_api_app_with_custom_memory_service(
1721+
mock_session_service,
1722+
mock_artifact_service,
1723+
mock_agent_loader,
1724+
mock_eval_sets_manager,
1725+
mock_eval_set_results_manager,
1726+
):
1727+
"""Test that custom memory_service is used directly when provided."""
1728+
custom_memory_service = MagicMock()
1729+
1730+
with (
1731+
patch.object(signal, "signal", autospec=True, return_value=None),
1732+
patch.object(
1733+
fast_api_module,
1734+
"create_session_service_from_options",
1735+
autospec=True,
1736+
return_value=mock_session_service,
1737+
),
1738+
patch.object(
1739+
fast_api_module,
1740+
"create_artifact_service_from_options",
1741+
autospec=True,
1742+
return_value=mock_artifact_service,
1743+
),
1744+
patch.object(
1745+
fast_api_module,
1746+
"create_memory_service_from_options",
1747+
autospec=True,
1748+
) as mock_create_memory_service,
1749+
patch.object(
1750+
fast_api_module,
1751+
"AgentLoader",
1752+
autospec=True,
1753+
return_value=mock_agent_loader,
1754+
),
1755+
patch.object(
1756+
fast_api_module,
1757+
"LocalEvalSetsManager",
1758+
autospec=True,
1759+
return_value=mock_eval_sets_manager,
1760+
),
1761+
patch.object(
1762+
fast_api_module,
1763+
"LocalEvalSetResultsManager",
1764+
autospec=True,
1765+
return_value=mock_eval_set_results_manager,
1766+
),
1767+
patch.object(
1768+
fast_api_module,
1769+
"load_services_module",
1770+
autospec=True,
1771+
return_value=None,
1772+
),
1773+
):
1774+
app = get_fast_api_app(
1775+
agents_dir=".",
1776+
web=True,
1777+
session_service_uri="",
1778+
artifact_service_uri="",
1779+
memory_service_uri="",
1780+
memory_service=custom_memory_service,
1781+
allow_origins=["*"],
1782+
a2a=False,
1783+
host="127.0.0.1",
1784+
port=8000,
1785+
)
1786+
1787+
mock_create_memory_service.assert_not_called()
1788+
1789+
17201790
if __name__ == "__main__":
17211791
pytest.main(["-xvs", __file__])

0 commit comments

Comments
 (0)