diff --git a/src/strands_tools/batch.py b/src/strands_tools/batch.py index 1c19d258..ec6b53b5 100644 --- a/src/strands_tools/batch.py +++ b/src/strands_tools/batch.py @@ -79,7 +79,7 @@ def batch(tool: ToolUse, **kwargs) -> ToolResult: - If a tool function is not found or an error occurs, it will be captured in the results. - This tool is designed to work with agents that support dynamic tool invocation. - Sammple output: + Sample output: { "status": "success", "results": [ @@ -96,41 +96,83 @@ def batch(tool: ToolUse, **kwargs) -> ToolResult: agent = kwargs.get("agent") invocations = kwargs.get("invocations", []) results = [] + try: if not hasattr(agent, "tool") or agent.tool is None: raise AttributeError("Agent does not have a valid 'tool' attribute.") + for invocation in invocations: tool_name = invocation.get("name") arguments = invocation.get("arguments", {}) tool_fn = getattr(agent.tool, tool_name, None) + if callable(tool_fn): try: - # Only pass JSON-serializable arguments to the tool + # Call the tool function with the provided arguments result = tool_fn(**arguments) - - if result["status"] == "success": - results.append({"json": {"name": tool_name, "status": "success", "result": result}}) - else: - results.append( - {"toolUseId": tool_use_id, "status": "error", "content": [{"text": "Tool missing"}]} - ) + + # Create a consistent result structure + batch_result = { + "name": tool_name, + "status": "success", + "result": result + } + results.append(batch_result) + except Exception as e: - error_msg = f"Error in batch tool: {str(e)}\n{traceback.format_exc()}" - console.print(f"Error in batch tool: {str(e)}") - results.append({"toolUseId": tool_use_id, "status": "error", "content": [{"text": error_msg}]}) - else: - results.append( - { - "toolUseId": tool_use_id, + error_msg = f"Error executing tool '{tool_name}': {str(e)}" + console.print(error_msg) + + batch_result = { + "name": tool_name, "status": "error", - "content": [{"text": f"Tool '{tool_name}' not found in agent or tool call failed."}], + "error": str(e), + "traceback": traceback.format_exc() } - ) + results.append(batch_result) + else: + error_msg = f"Tool '{tool_name}' not found in agent" + console.print(error_msg) + + batch_result = { + "name": tool_name, + "status": "error", + "error": error_msg + } + results.append(batch_result) + + # Create a readable summary for the agent + summary_lines = [] + summary_lines.append(f"Batch execution completed with {len(results)} tool(s):") + + for result in results: + if result["status"] == "success": + summary_lines.append(f"✓ {result['name']}: Success") + else: + summary_lines.append(f"✗ {result['name']}: Error - {result['error']}") + + summary_text = "\n".join(summary_lines) + return { "toolUseId": tool_use_id, "status": "success", - "content": results, + "content": [ + { + "text": summary_text + }, + { + "json": { + "batch_summary": { + "total_tools": len(results), + "successful": len([r for r in results if r["status"] == "success"]), + "failed": len([r for r in results if r["status"] == "error"]) + }, + "results": results + } + } + ] } + except Exception as e: error_msg = f"Error in batch tool: {str(e)}\n{traceback.format_exc()}" console.print(f"Error in batch tool: {str(e)}") diff --git a/tests/test_batch.py b/tests/test_batch.py index be44c659..b73984d2 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -8,9 +8,29 @@ def mock_agent(): """Fixture to create a mock agent with tools.""" agent = MagicMock() - agent.tool.http_request = MagicMock(return_value={"status": "success", "result": {"ip": "127.0.0.1"}}) - agent.tool.use_aws = MagicMock(return_value={"status": "success", "result": {"buckets": ["bucket1", "bucket2"]}}) - agent.tool.error_tool = MagicMock(side_effect=Exception("Tool execution failed")) + + # Create a mock tool registry that mimics the real agent's tool access pattern + mock_tool_registry = MagicMock() + mock_tool_registry.registry = { + "http_request": MagicMock(return_value={"status": "success", "result": {"ip": "127.0.0.1"}}), + "use_aws": MagicMock(return_value={"status": "success", "result": {"buckets": ["bucket1", "bucket2"]}}), + "error_tool": MagicMock(side_effect=Exception("Tool execution failed")) + } + agent.tool_registry = mock_tool_registry + + # Create a custom mock tool object that properly handles getattr + class MockTool: + def __init__(self): + self.http_request = mock_tool_registry.registry["http_request"] + self.use_aws = mock_tool_registry.registry["use_aws"] + self.error_tool = mock_tool_registry.registry["error_tool"] + + def __getattr__(self, name): + # Return None for non-existent tools (this will make callable() return False) + return None + + agent.tool = MockTool() + return agent @@ -27,12 +47,26 @@ def test_batch_success(mock_agent): assert result["toolUseId"] == "mock_tool_id" assert result["status"] == "success" assert len(result["content"]) == 2 - assert result["content"][0]["json"]["name"] == "http_request" - assert result["content"][0]["json"]["status"] == "success" - assert result["content"][0]["json"]["result"]["result"]["ip"] == "127.0.0.1" - assert result["content"][1]["json"]["name"] == "use_aws" - assert result["content"][1]["json"]["status"] == "success" - assert result["content"][1]["json"]["result"]["result"]["buckets"] == ["bucket1", "bucket2"] + + # Check the summary text + assert "Batch execution completed with 2 tool(s):" in result["content"][0]["text"] + assert "✓ http_request: Success" in result["content"][0]["text"] + assert "✓ use_aws: Success" in result["content"][0]["text"] + + # Check the JSON results + json_content = result["content"][1]["json"] + assert json_content["batch_summary"]["total_tools"] == 2 + assert json_content["batch_summary"]["successful"] == 2 + assert json_content["batch_summary"]["failed"] == 0 + + results = json_content["results"] + assert len(results) == 2 + assert results[0]["name"] == "http_request" + assert results[0]["status"] == "success" + assert results[0]["result"]["result"]["ip"] == "127.0.0.1" + assert results[1]["name"] == "use_aws" + assert results[1]["status"] == "success" + assert results[1]["result"]["result"]["buckets"] == ["bucket1", "bucket2"] def test_batch_missing_tool(mock_agent): @@ -46,10 +80,23 @@ def test_batch_missing_tool(mock_agent): assert result["toolUseId"] == "mock_tool_id" assert result["status"] == "success" - assert len(result["content"]) == 1 - assert result["content"][0]["toolUseId"] == "mock_tool_id" - assert result["content"][0]["status"] == "error" - assert "Tool missing" in result["content"][0]["content"][0]["text"] + assert len(result["content"]) == 2 + + # Check the summary text + assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"] + assert "✗ non_existent_tool: Error" in result["content"][0]["text"] + + # Check the JSON results + json_content = result["content"][1]["json"] + assert json_content["batch_summary"]["total_tools"] == 1 + assert json_content["batch_summary"]["successful"] == 0 + assert json_content["batch_summary"]["failed"] == 1 + + results = json_content["results"] + assert len(results) == 1 + assert results[0]["name"] == "non_existent_tool" + assert results[0]["status"] == "error" + assert "not found in agent" in results[0]["error"] def test_batch_tool_error(mock_agent): @@ -63,10 +110,24 @@ def test_batch_tool_error(mock_agent): assert result["toolUseId"] == "mock_tool_id" assert result["status"] == "success" - assert len(result["content"]) == 1 - assert result["content"][0]["toolUseId"] == "mock_tool_id" - assert result["content"][0]["status"] == "error" - assert "Error in batch tool" in result["content"][0]["content"][0]["text"] + assert len(result["content"]) == 2 + + # Check the summary text + assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"] + assert "✗ error_tool: Error" in result["content"][0]["text"] + + # Check the JSON results + json_content = result["content"][1]["json"] + assert json_content["batch_summary"]["total_tools"] == 1 + assert json_content["batch_summary"]["successful"] == 0 + assert json_content["batch_summary"]["failed"] == 1 + + results = json_content["results"] + assert len(results) == 1 + assert results[0]["name"] == "error_tool" + assert results[0]["status"] == "error" + assert "Tool execution failed" in results[0]["error"] + assert "traceback" in results[0] def test_batch_no_invocations(mock_agent): @@ -78,7 +139,17 @@ def test_batch_no_invocations(mock_agent): assert result["toolUseId"] == "mock_tool_id" assert result["status"] == "success" - assert len(result["content"]) == 0 + assert len(result["content"]) == 2 + + # Check the summary text + assert "Batch execution completed with 0 tool(s):" in result["content"][0]["text"] + + # Check the JSON results + json_content = result["content"][1]["json"] + assert json_content["batch_summary"]["total_tools"] == 0 + assert json_content["batch_summary"]["successful"] == 0 + assert json_content["batch_summary"]["failed"] == 0 + assert len(json_content["results"]) == 0 def test_batch_top_level_error(mock_agent):