diff --git a/README.md b/README.md index 438b25c0..5189aeef 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,7 @@ Below is a comprehensive table of all available tools, how to use them with an a | Tool | Agent Usage | Use Case | |------|-------------|----------| | a2a_client | `provider = A2AClientToolProvider(known_agent_urls=["http://localhost:9000"]); agent = Agent(tools=provider.tools)` | Discover and communicate with A2A-compliant agents, send messages between agents | +| a2a_registry_client | `provider = AgentRegistryToolProvider(registry_url="http://localhost:8000"); agent = Agent(tools=provider.tools)` | Find and communicate with agents through a centralized registry, search by skills, find best matches | | file_read | `agent.tool.file_read(path="path/to/file.txt")` | Reading configuration files, parsing code files, loading datasets | | file_write | `agent.tool.file_write(path="path/to/file.txt", content="file content")` | Writing results to files, creating new files, saving output data | | editor | `agent.tool.editor(command="view", path="path/to/file.py")` | Advanced file operations like syntax highlighting, pattern replacement, and multi-file edits | @@ -639,6 +640,30 @@ response = agent("discover available agents and send a greeting message") # - send_message(message_text, target_agent_url) to communicate ``` +### A2A Registry Client + +```python +from strands import Agent +from strands_tools.a2a_registry_client import AgentRegistryToolProvider + +# Initialize the registry client provider +provider = AgentRegistryToolProvider(registry_url="http://localhost:8000") +agent = Agent(tools=provider.tools) + +# Use natural language to interact with the agent registry +response = agent("find agents that can help with python development") +response = agent("send a message to an agent with data processor skills asking it to process this dataset") +response = agent("find the best agent for machine learning tasks and ask it to train a model") + +# The agent will automatically use the available tools: +# - registry_find_agents_by_skill(skill) to search by skills +# - registry_find_best_agent_for_task(skills) to find optimal matches +# - registry_send_message_to_agent(name, message) to communicate +# - registry_find_and_message_agent(skills, message) for combined operations +# - registry_get_all_agents() to list all available agents +# - registry_find_similar_agents(agent_id) to find similar agents +``` + ### Diagram ```python diff --git a/src/strands_tools/a2a_registry_client.py b/src/strands_tools/a2a_registry_client.py new file mode 100644 index 00000000..577a8d36 --- /dev/null +++ b/src/strands_tools/a2a_registry_client.py @@ -0,0 +1,491 @@ +""" +Agent Registry Tool Provider for Strands Agents. + +Provides tools for discovering and searching agents through a centralized registry. +""" + +import logging +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional +from uuid import uuid4 + +import httpx +from a2a.client import ClientConfig, ClientFactory +from a2a.types import AgentCard, Message, Part, Role, TextPart, TransportProtocol +from strands import tool +from strands.types.tools import AgentTool + +logger = logging.getLogger(__name__) + + +@dataclass +class AgentSearchCriteria: + required_skills: List[str] = None + preferred_skills: List[str] = None + min_version: str = None + max_response_time_ms: int = None + regions: List[str] = None + exclude_agents: List[str] = None + + +class AgentRegistryToolProvider: + """Agent Registry tool provider with both discovery and communication capabilities.""" + + def __init__( + self, + registry_url: str = "http://localhost:8000", + timeout: int = 30, + agent_auth=None, + transports: Dict[str, Callable] = None, + default_preferred_transport: str = TransportProtocol.http_json, + ): + self.registry_url = registry_url + self.timeout = timeout + self.agent_auth = agent_auth + self.transports = transports or {} + self._httpx_client: httpx.AsyncClient | None = None + self._client_factory: ClientFactory | None = None + self._request_id = 0 + self._agent_cache: Dict[str, AgentCard] = {} + self._default_preferred_transport = default_preferred_transport + + logger.info( + f"Initialized AgentRegistryToolProvider with registry_url={registry_url}, " + f"timeout={timeout}s, transports={list(self.transports.keys())}" + ) + + @property + def tools(self) -> List[AgentTool]: + """Extract all @tool decorated methods from this instance.""" + tools = [] + for attr_name in dir(self): + if attr_name == "tools": + continue + attr = getattr(self, attr_name) + if isinstance(attr, AgentTool): + tools.append(attr) + return tools + + async def _ensure_httpx_client(self) -> httpx.AsyncClient: + """Ensure the shared HTTP client is initialized.""" + if self._httpx_client is None: + self._httpx_client = httpx.AsyncClient(timeout=self.timeout) + return self._httpx_client + + async def _ensure_client_factory(self, agent_card: AgentCard = None) -> ClientFactory: + """Ensure the ClientFactory is initialized with proper auth.""" + httpx_client = await self._ensure_httpx_client() + + # Set authentication if agent card and auth class are provided + if agent_card and self.agent_auth: + logger.info( + f"🔐 Applying {self.agent_auth.__name__} authentication for agent {agent_card.name} at {agent_card.url}" + ) + httpx_client.auth = self.agent_auth(agent_card) + else: + logger.warning(f"⚠️ No authentication configured for agent {agent_card.name if agent_card else 'unknown'}") + + # Build supported transports list from registered transports + supported_transports = [TransportProtocol.http_json, TransportProtocol.jsonrpc] + list(self.transports.keys()) + + config = ClientConfig(httpx_client=httpx_client, streaming=False, supported_transports=supported_transports) + + client_factory = ClientFactory(config) + + # Register all configured transports + for transport_name, transport_factory in self.transports.items(): + client_factory.register(transport_name, transport_factory) + + return client_factory + + async def _get_agent_card_from_registry(self, agent_name: str) -> Optional[AgentCard]: + """Get agent card from registry and convert to A2A AgentCard.""" + logger.info(f"Getting agent card for {agent_name} from registry") + try: + result = await self._jsonrpc_request("get_agent", {"agent_id": agent_name}) + if result.get("found"): + agent_data = result.get("agent_card") + if agent_data: + logger.info(f"Successfully retrieved agent card for {agent_name}") + return AgentCard(**agent_data) + + logger.warning(f"No agent card found for {agent_name}") + return None + + except Exception as e: + error_msg = f"Failed to get agent card for {agent_name}: {e}" + logger.error(error_msg) + print(error_msg) + return None + + async def _send_message_to_agent_direct( + self, agent_data: dict, message_text: str, message_id: str = None + ) -> Dict[str, Any]: + """Send message to agent using agent data directly (bypassing registry lookup).""" + agent_name = agent_data.get("name", "unknown") + logger.info(f"Sending message directly to agent {agent_name}: {message_text[:100]}...") + try: + # Ensure agent data has all required AgentCard fields + agent_card_data = { + "name": agent_data.get("name", "unknown"), + "description": agent_data.get("description", ""), + "url": agent_data.get("url", ""), + "version": agent_data.get("version", "1.0.0"), + "protocol_version": agent_data.get("protocol_version", "0.3.0"), + "preferred_transport": agent_data.get("preferred_transport", self._default_preferred_transport), + "skills": agent_data.get("skills", []), + "capabilities": agent_data.get("capabilities", {}), + "default_input_modes": agent_data.get("default_input_modes", ["application/json"]), + "default_output_modes": agent_data.get("default_output_modes", ["application/json"]), + } + + # Convert agent data to AgentCard format + agent_card = AgentCard(**agent_card_data) + + # Create A2A client and send message with proper auth + logger.info( + f"🔧 Creating A2A client for agent {agent_name} with transport: {agent_card.preferred_transport}" + ) + client_factory = await self._ensure_client_factory(agent_card) + client = client_factory.create(agent_card) + logger.info(f"🔧 Created client type: {type(client).__name__}") + + if message_id is None: + message_id = uuid4().hex + + message = Message( + kind="message", + role=Role.user, + parts=[Part(TextPart(kind="text", text=message_text))], + message_id=message_id, + ) + + # Send message and get response - collect all events first + logger.debug(f"Sending message {message_id} to agent {agent_name}") + events = [] + async for event in client.send_message(message): + logger.info(f"Received event type: {type(event)} from agent {agent_name}") + events.append(event) + + # Process collected events + for event in events: + logger.info(f"Processing event type: {type(event)} from agent {agent_name}") + logger.info(f"Event content: {event}") + if isinstance(event, Message): + logger.info(f"Received message response from agent {agent_name}") + return { + "status": "success", + "response": event.model_dump(mode="python", exclude_none=True), + "agent_name": agent_name, + "message_id": message_id, + } + elif isinstance(event, tuple) and len(event) == 2: + task, update_event = event + logger.info(f"Received task response from agent {agent_name}") + return { + "status": "success", + "response": { + "task": task.model_dump(mode="python", exclude_none=True), + "update": update_event.model_dump(mode="python", exclude_none=True) + if update_event + else None, + }, + "agent_name": agent_name, + "message_id": message_id, + } + elif isinstance(event, dict): + logger.info(f"Received dict response from agent {agent_name}") + return {"status": "success", "response": event, "agent_name": agent_name, "message_id": message_id} + else: + logger.warning(f"Unknown event type {type(event)}: {event}") + + first_event_type = type(events[0]).__name__ if events else "None" + logger.warning( + f"No response received from agent {agent_name} - collected {len(events)} events, " + f"first event type: {first_event_type}" + ) + return { + "status": "error", + "error": "No response received from agent", + "agent_name": agent_name, + "message_id": message_id, + } + + except Exception as e: + logger.exception(f"Error sending message to agent {agent_name}") + return {"status": "error", "error": str(e), "agent_name": agent_name, "message_id": message_id} + + @tool + async def registry_send_message_to_agent( + self, agent_name: str, message_text: str, message_id: str = None + ) -> dict[str, Any]: + """ + Send a message to an agent found via registry. + + Args: + agent_name: Name of the agent in the registry + message_text: Message to send + message_id: Optional message ID + + Returns: + dict: Response from the agent + """ + logger.info(f"Sending message to agent {agent_name}: {message_text[:100]}...") + try: + # Get agent card from registry + agent_data = await self._get_agent_card_from_registry(agent_name) + if not agent_data: + error_msg = f"Agent {agent_name} not found in registry" + logger.error(error_msg) + print(error_msg) + return {"status": "error", "error": f"Agent {agent_name} not found in registry: {error_msg}"} + + # Convert AgentCard to dict if needed + agent_dict = agent_data.model_dump() if isinstance(agent_data, AgentCard) else agent_data + return await self._send_message_to_agent_direct(agent_dict, message_text, message_id) + + except Exception as e: + logger.exception(f"Error sending message to agent {agent_name}") + return {"status": "error", "error": str(e), "agent_name": agent_name, "message_id": message_id or "unknown"} + + @tool + async def registry_find_and_message_agent(self, required_skills: List[str], message_text: str) -> Dict[str, Any]: + """ + Find the best agent for a task and send it a message. + + Args: + required_skills: Skills the agent must have + message_text: Message to send to the selected agent + task_description: Optional task description for better matching + + Returns: + dict: Combined discovery and messaging result + """ + # First find the best agent + best_agent_result = await self.registry_find_best_agent_for_task(required_skills) + + if best_agent_result["status"] != "success" or not best_agent_result["best_agent"]: + return {"status": "error", "error": "No suitable agent found", "required_skills": required_skills} + + # Then send message to that agent using the agent data we already have + agent_data = best_agent_result["best_agent"] + message_result = await self._send_message_to_agent_direct(agent_data, message_text) + + return { + "status": "success", + "discovery_result": best_agent_result, + "message_result": message_result, + "selected_agent": agent_data["name"], + } + + def _next_id(self): + """Generate next JSON-RPC request ID.""" + self._request_id += 1 + return self._request_id + + async def _jsonrpc_request(self, method: str, params: dict = None) -> dict: + """Make a JSON-RPC 2.0 request.""" + logger.debug(f"Making JSON-RPC request: {method} with params: {params}") + client = await self._ensure_httpx_client() + payload = {"jsonrpc": "2.0", "method": method, "id": self._next_id()} + if params: + payload["params"] = params + + try: + response = await client.post( + f"{self.registry_url}/jsonrpc", json=payload, headers={"Content-Type": "application/json"} + ) + response.raise_for_status() + result = response.json() + + logger.debug(f"JSON-RPC response for {method}: status={response.status_code}") + + if "error" in result: + error_msg = f"JSON-RPC Error for {method}: {result['error']}" + logger.error(error_msg) + print(error_msg) + raise Exception(f"JSON-RPC Error: {result['error']}") + + return result.get("result", {}) + except Exception as e: + error_msg = f"Failed JSON-RPC request {method}: {e}" + logger.error(error_msg) + logger.exception(f"Full stack trace for {method} failure:") + print(error_msg) + raise + + @tool + async def registry_find_agents_by_skill(self, skill_id: str) -> Dict[str, Any]: + """ + Find all agents that have a specific skill. + + Args: + skill_id: The skill identifier to search for + + Returns: + dict: Search results including agents list and metadata + """ + try: + result = await self._jsonrpc_request("search_agents", {"query": skill_id}) + return { + "status": "success", + "agents": result.get("agents", []), + "skill_searched": skill_id, + "total_count": len(result.get("agents", [])), + } + except Exception as e: + error_msg = f"Error finding agents by skill {skill_id}: {e}" + logger.exception(error_msg) + print(error_msg) + return {"status": "error", "error": str(e), "skill_searched": skill_id} + + @tool + async def registry_get_all_agents(self) -> Dict[str, Any]: + """ + Get all registered agents from the registry. + + Returns: + dict: All registered agents with their capabilities + """ + try: + result = await self._jsonrpc_request("list_agents") + agents = result.get("agents", []) + return {"status": "success", "agents": agents, "total_count": len(agents)} + except Exception as e: + error_msg = f"Error getting all agents: {e}" + logger.exception(error_msg) + print(error_msg) + return {"status": "error", "error": str(e)} + + @tool + async def registry_find_best_agent_for_task(self, required_skills: List[str]) -> Dict[str, Any]: + """ + Find the best agent for a specific task based on required skills. + + Args: + required_skills: List of skills the agent must have + + Returns: + dict: Best matching agent or None if no match found + """ + logger.info(f"Finding best agent for required skills: {required_skills}") + try: + # Get all agents first + all_agents_result = await self._jsonrpc_request("list_agents") + all_agents = all_agents_result.get("agents", []) + logger.debug(f"Found {len(all_agents)} total agents in registry") + + # Filter agents that have all required skills + compatible_agents = [] + for agent in all_agents: + # Handle both 'id' and 'name' fields for backward compatibility + agent_skills = set() + for skill in agent.get("skills", []): + if "id" in skill: + agent_skills.add(skill["id"]) + if "name" in skill: + agent_skills.add(skill["name"]) + + if all(skill in agent_skills for skill in required_skills): + compatible_agents.append(agent) + logger.debug(f"Agent {agent.get('name')} is compatible with required skills") + else: + logger.debug( + f"Agent {agent.get('name')} skills {agent_skills} don't match required {required_skills}" + ) + + logger.info(f"Found {len(compatible_agents)} compatible agents") + if not compatible_agents: + logger.warning(f"No agents found with all required skills: {required_skills}") + return { + "status": "success", + "best_agent": None, + "message": "No agents found with all required skills", + "required_skills": required_skills, + } + + # Simple ranking by skill count + best_agent = max(compatible_agents, key=lambda x: len(x.get("skills", []))) + logger.info( + f"Selected best agent: {best_agent.get('name')} with {len(best_agent.get('skills', []))} skills" + ) + + return { + "status": "success", + "best_agent": best_agent, + "required_skills": required_skills, + "total_compatible": len(compatible_agents), + } + + except Exception as e: + error_msg = f"Error finding best agent for task: {e}" + logger.exception(error_msg) + print(error_msg) + return {"status": "error", "error": str(e), "required_skills": required_skills} + + @tool + async def registry_find_similar_agents(self, reference_agent_id: str) -> Dict[str, Any]: + """ + Find agents similar to a reference agent based on skill overlap. + + Args: + reference_agent_id: The ID of the reference agent + + Returns: + dict: List of similar agents with similarity scores + """ + try: + # Get reference agent + ref_result = await self._jsonrpc_request("get_agent", {"agent_id": reference_agent_id}) + reference_agent = ref_result.get("agent_card") + + if not reference_agent: + return {"status": "error", "error": f"Reference agent {reference_agent_id} not found"} + + # Handle both 'id' and 'name' fields for skills + reference_skills = set() + for skill in reference_agent.get("skills", []): + if "id" in skill: + reference_skills.add(skill["id"]) + if "name" in skill: + reference_skills.add(skill["name"]) + + # Get all agents and calculate similarity + all_agents_result = await self._jsonrpc_request("list_agents") + all_agents = all_agents_result.get("agents", []) + + similar_agents = [] + for agent in all_agents: + if agent["name"] == reference_agent_id: + continue + + # Handle both 'id' and 'name' fields for skills + agent_skills = set() + for skill in agent.get("skills", []): + if "id" in skill: + agent_skills.add(skill["id"]) + if "name" in skill: + agent_skills.add(skill["name"]) + overlap = len(reference_skills.intersection(agent_skills)) + + if overlap > 0: + similarity_score = overlap / len(reference_skills.union(agent_skills)) + agent_copy = agent.copy() + agent_copy["similarity_score"] = similarity_score + similar_agents.append(agent_copy) + + # Sort by similarity + similar_agents.sort(key=lambda x: x["similarity_score"], reverse=True) + + return { + "status": "success", + "similar_agents": similar_agents, + "reference_agent": reference_agent_id, + "total_found": len(similar_agents), + } + + except Exception as e: + error_msg = f"Error finding similar agents to {reference_agent_id}: {e}" + logger.exception(error_msg) + print(error_msg) + return {"status": "error", "error": str(e), "reference_agent": reference_agent_id} diff --git a/tests/test_a2a_registry_client.py b/tests/test_a2a_registry_client.py new file mode 100644 index 00000000..45b6bb83 --- /dev/null +++ b/tests/test_a2a_registry_client.py @@ -0,0 +1,332 @@ +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from strands_tools.a2a_registry_client import AgentRegistryToolProvider + + +def test_init_default_parameters(): + """Test initialization with default parameters.""" + provider = AgentRegistryToolProvider() + + assert provider.registry_url == "http://localhost:8000" + assert provider.timeout == 30 + assert provider.agent_auth is None + assert provider.transports == {} + assert provider._httpx_client is None + assert provider._client_factory is None + assert provider._request_id == 0 + assert provider._agent_cache == {} + + +def test_init_custom_parameters(): + """Test initialization with custom parameters.""" + registry_url = "http://custom-registry.com" + timeout = 60 + transports = {"custom": Mock()} + + provider = AgentRegistryToolProvider(registry_url=registry_url, timeout=timeout, transports=transports) + + assert provider.registry_url == registry_url + assert provider.timeout == timeout + assert provider.transports == transports + + +def test_tools_property(): + """Test that tools property returns decorated methods.""" + provider = AgentRegistryToolProvider() + tools = provider.tools + + tool_names = [tool.tool_name for tool in tools] + assert "registry_send_message_to_agent" in tool_names + assert "registry_find_and_message_agent" in tool_names + assert "registry_find_agents_by_skill" in tool_names + assert "registry_get_all_agents" in tool_names + assert "registry_find_best_agent_for_task" in tool_names + assert "registry_find_similar_agents" in tool_names + + +@pytest.mark.asyncio +async def test_ensure_httpx_client_creates_new_client(): + """Test _ensure_httpx_client creates new client when none exists.""" + provider = AgentRegistryToolProvider(timeout=45) + + with patch("httpx.AsyncClient") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + result = await provider._ensure_httpx_client() + + mock_client_class.assert_called_once_with(timeout=45) + assert result == mock_client + assert provider._httpx_client == mock_client + + +@pytest.mark.asyncio +async def test_ensure_httpx_client_reuses_existing(): + """Test _ensure_httpx_client reuses existing client.""" + provider = AgentRegistryToolProvider() + existing_client = Mock() + provider._httpx_client = existing_client + + result = await provider._ensure_httpx_client() + + assert result == existing_client + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_ensure_httpx_client") +async def test_ensure_client_factory_with_auth(mock_ensure_client): + """Test _ensure_client_factory with authentication.""" + provider = AgentRegistryToolProvider() + mock_httpx_client = Mock() + mock_ensure_client.return_value = mock_httpx_client + mock_auth = Mock() + mock_auth.__name__ = "MockAuth" + provider.agent_auth = mock_auth + + mock_agent_card = Mock() + mock_agent_card.name = "test_agent" + + with patch("strands_tools.a2a_registry_client.ClientFactory") as mock_factory_class: + mock_factory = Mock() + mock_factory_class.return_value = mock_factory + + result = await provider._ensure_client_factory(mock_agent_card) + + assert result == mock_factory + assert mock_httpx_client.auth == mock_auth.return_value + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_get_agent_card_from_registry_success(mock_jsonrpc): + """Test _get_agent_card_from_registry with successful response.""" + provider = AgentRegistryToolProvider() + agent_data = {"name": "test_agent", "url": "http://test.com"} + mock_jsonrpc.return_value = {"found": True, "agent_card": agent_data} + + with patch("strands_tools.a2a_registry_client.AgentCard") as mock_agent_card: + mock_card = Mock() + mock_agent_card.return_value = mock_card + + result = await provider._get_agent_card_from_registry("test_agent") + + assert result == mock_card + mock_jsonrpc.assert_called_once_with("get_agent", {"agent_id": "test_agent"}) + mock_agent_card.assert_called_once_with(**agent_data) + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_get_agent_card_from_registry_not_found(mock_jsonrpc): + """Test _get_agent_card_from_registry when agent not found.""" + provider = AgentRegistryToolProvider() + mock_jsonrpc.return_value = {"found": False} + + result = await provider._get_agent_card_from_registry("test_agent") + + assert result is None + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_get_agent_card_from_registry_error(mock_jsonrpc): + """Test _get_agent_card_from_registry handles errors.""" + provider = AgentRegistryToolProvider() + mock_jsonrpc.side_effect = Exception("Network error") + + result = await provider._get_agent_card_from_registry("test_agent") + + assert result is None + + +def test_next_id(): + """Test _next_id increments request ID.""" + provider = AgentRegistryToolProvider() + + assert provider._next_id() == 1 + assert provider._next_id() == 2 + assert provider._request_id == 2 + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_ensure_httpx_client") +async def test_jsonrpc_request_success(mock_ensure_client): + """Test _jsonrpc_request with successful response.""" + provider = AgentRegistryToolProvider() + mock_client = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"jsonrpc": "2.0", "result": {"data": "test"}, "id": 1} + mock_client.post = AsyncMock(return_value=mock_response) + mock_ensure_client.return_value = mock_client + + result = await provider._jsonrpc_request("test_method", {"param": "value"}) + + assert result == {"data": "test"} + mock_client.post.assert_called_once() + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_ensure_httpx_client") +async def test_jsonrpc_request_error_response(mock_ensure_client): + """Test _jsonrpc_request with error response.""" + provider = AgentRegistryToolProvider() + mock_client = Mock() + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"jsonrpc": "2.0", "error": {"code": -1, "message": "Test error"}, "id": 1} + mock_client.post = AsyncMock(return_value=mock_response) + mock_ensure_client.return_value = mock_client + + with pytest.raises(Exception, match="JSON-RPC Error"): + await provider._jsonrpc_request("test_method") + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_get_agent_card_from_registry") +@patch.object(AgentRegistryToolProvider, "_send_message_to_agent_direct") +async def test_registry_send_message_to_agent_success(mock_send_message, mock_get_agent): + """Test registry_send_message_to_agent with successful flow.""" + provider = AgentRegistryToolProvider() + mock_agent_card = Mock() + mock_agent_card.model_dump.return_value = {"name": "test_agent"} + mock_get_agent.return_value = mock_agent_card + mock_send_message.return_value = {"status": "success", "response": "test response"} + + result = await provider.registry_send_message_to_agent("test_agent", "Hello") + + assert result["status"] == "success" + mock_get_agent.assert_called_once_with("test_agent") + mock_send_message.assert_called_once() + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_get_agent_card_from_registry") +async def test_registry_send_message_to_agent_not_found(mock_get_agent): + """Test registry_send_message_to_agent when agent not found.""" + provider = AgentRegistryToolProvider() + mock_get_agent.return_value = None + + result = await provider.registry_send_message_to_agent("test_agent", "Hello") + + assert result["status"] == "error" + assert "not found in registry" in result["error"] + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_registry_find_agents_by_skill_success(mock_jsonrpc): + """Test registry_find_agents_by_skill with successful response.""" + provider = AgentRegistryToolProvider() + agents_data = [{"name": "agent1"}, {"name": "agent2"}] + mock_jsonrpc.return_value = {"agents": agents_data} + + result = await provider.registry_find_agents_by_skill("python") + + assert result["status"] == "success" + assert result["agents"] == agents_data + assert result["skill_searched"] == "python" + assert result["total_count"] == 2 + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_registry_get_all_agents_success(mock_jsonrpc): + """Test registry_get_all_agents with successful response.""" + provider = AgentRegistryToolProvider() + agents_data = [{"name": "agent1"}, {"name": "agent2"}] + mock_jsonrpc.return_value = {"agents": agents_data} + + result = await provider.registry_get_all_agents() + + assert result["status"] == "success" + assert result["agents"] == agents_data + assert result["total_count"] == 2 + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_registry_find_best_agent_for_task_success(mock_jsonrpc): + """Test registry_find_best_agent_for_task finds compatible agent.""" + provider = AgentRegistryToolProvider() + agents_data = [ + {"name": "agent1", "skills": [{"id": "python"}, {"id": "web"}]}, + {"name": "agent2", "skills": [{"id": "python"}]}, + ] + mock_jsonrpc.return_value = {"agents": agents_data} + + result = await provider.registry_find_best_agent_for_task(["python"]) + + assert result["status"] == "success" + assert result["best_agent"]["name"] == "agent1" # Has more skills + assert result["total_compatible"] == 2 + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_registry_find_best_agent_for_task_no_match(mock_jsonrpc): + """Test registry_find_best_agent_for_task when no agents match.""" + provider = AgentRegistryToolProvider() + agents_data = [{"name": "agent1", "skills": [{"id": "java"}]}] + mock_jsonrpc.return_value = {"agents": agents_data} + + result = await provider.registry_find_best_agent_for_task(["python"]) + + assert result["status"] == "success" + assert result["best_agent"] is None + assert "No agents found" in result["message"] + + +@pytest.mark.asyncio +@patch.object(AgentRegistryToolProvider, "_jsonrpc_request") +async def test_registry_find_similar_agents_success(mock_jsonrpc): + """Test registry_find_similar_agents finds similar agents.""" + provider = AgentRegistryToolProvider() + reference_agent = {"name": "ref_agent", "skills": [{"id": "python"}, {"id": "web"}]} + all_agents = [{"name": "agent1", "skills": [{"id": "python"}]}, {"name": "agent2", "skills": [{"id": "java"}]}] + + mock_jsonrpc.side_effect = [ + {"agent_card": reference_agent}, # get_agent call + {"agents": all_agents}, # list_agents call + ] + + result = await provider.registry_find_similar_agents("ref_agent") + + assert result["status"] == "success" + assert len(result["similar_agents"]) == 1 # Only agent1 has overlap + assert result["similar_agents"][0]["name"] == "agent1" + assert "similarity_score" in result["similar_agents"][0] + + +@pytest.mark.asyncio +async def test_registry_find_and_message_agent_success(): + """Test registry_find_and_message_agent with successful flow.""" + provider = AgentRegistryToolProvider() + best_agent_data = {"name": "best_agent"} + + with patch.object(provider, "registry_find_best_agent_for_task", new_callable=AsyncMock) as mock_find_best: + with patch.object(provider, "_send_message_to_agent_direct", new_callable=AsyncMock) as mock_send_message: + mock_find_best.return_value = {"status": "success", "best_agent": best_agent_data} + mock_send_message.return_value = {"status": "success", "response": "test response"} + + result = await provider.registry_find_and_message_agent(["python"], "Hello") + + assert result["status"] == "success" + assert result["selected_agent"] == "best_agent" + mock_find_best.assert_called_once_with(["python"]) + mock_send_message.assert_called_once_with(best_agent_data, "Hello") + + +@pytest.mark.asyncio +async def test_registry_find_and_message_agent_no_agent_found(): + """Test registry_find_and_message_agent when no suitable agent found.""" + provider = AgentRegistryToolProvider() + + with patch.object(provider, "registry_find_best_agent_for_task", new_callable=AsyncMock) as mock_find_best: + mock_find_best.return_value = {"status": "success", "best_agent": None} + + result = await provider.registry_find_and_message_agent(["python"], "Hello") + + assert result["status"] == "error" + assert "No suitable agent found" in result["error"]