Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 20 additions & 30 deletions src/strands_tools/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,83 +96,73 @@ def batch(tool: ToolUse, **kwargs) -> ToolResult:
agent = kwargs.get("agent")
invocations = kwargs.get("invocations", [])
results = []

try:
if not hasattr(agent, "tool") or agent.tool is None:
raise AttributeError("Agent does not have a valid 'tool' attribute.")

for invocation in invocations:
tool_name = invocation.get("name")
arguments = invocation.get("arguments", {})
tool_fn = getattr(agent.tool, tool_name, None)

if callable(tool_fn):
try:
# Call the tool function with the provided arguments
result = tool_fn(**arguments)

# Create a consistent result structure
batch_result = {
"name": tool_name,
"status": "success",
"result": result
}
batch_result = {"name": tool_name, "status": "success", "result": result}
results.append(batch_result)

except Exception as e:
error_msg = f"Error executing tool '{tool_name}': {str(e)}"
console.print(error_msg)

batch_result = {
"name": tool_name,
"status": "error",
"error": str(e),
"traceback": traceback.format_exc()
"traceback": traceback.format_exc(),
}
results.append(batch_result)
else:
error_msg = f"Tool '{tool_name}' not found in agent"
console.print(error_msg)

batch_result = {
"name": tool_name,
"status": "error",
"error": error_msg
}

batch_result = {"name": tool_name, "status": "error", "error": error_msg}
results.append(batch_result)

# Create a readable summary for the agent
summary_lines = []
summary_lines.append(f"Batch execution completed with {len(results)} tool(s):")

for result in results:
if result["status"] == "success":
summary_lines.append(f"✓ {result['name']}: Success")
else:
summary_lines.append(f"✗ {result['name']}: Error - {result['error']}")

summary_text = "\n".join(summary_lines)

return {
"toolUseId": tool_use_id,
"status": "success",
"content": [
{
"text": summary_text
},
{"text": summary_text},
{
"json": {
"batch_summary": {
"total_tools": len(results),
"successful": len([r for r in results if r["status"] == "success"]),
"failed": len([r for r in results if r["status"] == "error"])
"failed": len([r for r in results if r["status"] == "error"]),
},
"results": results
"results": results,
}
}
]
},
],
}

except Exception as e:
error_msg = f"Error in batch tool: {str(e)}\n{traceback.format_exc()}"
console.print(f"Error in batch tool: {str(e)}")
Expand Down
5 changes: 3 additions & 2 deletions src/strands_tools/mem0_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,9 @@ def format_retrieve_response(memories: List[Dict]) -> Panel:
def format_retrieve_graph_response(memories: List[Dict]) -> Panel:
"""Format retrieve response for graph data"""
if not memories:
return Panel("No graph memories found matching the query.",
title="[bold yellow]No Matches", border_style="yellow")
return Panel(
"No graph memories found matching the query.", title="[bold yellow]No Matches", border_style="yellow"
)

table = Table(title="Search Results", show_header=True, header_style="bold magenta")
table.add_column("Source", style="cyan")
Expand Down
12 changes: 9 additions & 3 deletions src/strands_tools/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,11 @@ def start_workflow(self, workflow_id: str) -> Dict:
for task in ready_tasks[:current_batch_size]:
task_id = task["task_id"]
if task_id not in active_futures and task_id not in completed_tasks:
# Namespace task_id with workflow_id to prevent conflicts
namespaced_task_id = f"{workflow_id}:{task_id}"
tasks_to_submit.append(
(
task_id,
namespaced_task_id,
self.execute_task,
(task, workflow),
{},
Expand All @@ -633,9 +635,13 @@ def start_workflow(self, workflow_id: str) -> Dict:

# Process completed tasks
completed_task_ids = []
for task_id, future in active_futures.items():
for namespaced_task_id, future in active_futures.items():
if future in done:
completed_task_ids.append(task_id)
# Extract original task_id from namespaced version
task_id = (
namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
)
completed_task_ids.append(namespaced_task_id)
try:
result = future.result()

Expand Down
34 changes: 17 additions & 17 deletions tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,29 @@
def mock_agent():
"""Fixture to create a mock agent with tools."""
agent = MagicMock()

# Create a mock tool registry that mimics the real agent's tool access pattern
mock_tool_registry = MagicMock()
mock_tool_registry.registry = {
"http_request": MagicMock(return_value={"status": "success", "result": {"ip": "127.0.0.1"}}),
"use_aws": MagicMock(return_value={"status": "success", "result": {"buckets": ["bucket1", "bucket2"]}}),
"error_tool": MagicMock(side_effect=Exception("Tool execution failed"))
"error_tool": MagicMock(side_effect=Exception("Tool execution failed")),
}
agent.tool_registry = mock_tool_registry

# Create a custom mock tool object that properly handles getattr
class MockTool:
def __init__(self):
self.http_request = mock_tool_registry.registry["http_request"]
self.use_aws = mock_tool_registry.registry["use_aws"]
self.error_tool = mock_tool_registry.registry["error_tool"]

def __getattr__(self, name):
# Return None for non-existent tools (this will make callable() return False)
return None

agent.tool = MockTool()

return agent


Expand All @@ -47,18 +47,18 @@ def test_batch_success(mock_agent):
assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 2 tool(s):" in result["content"][0]["text"]
assert "✓ http_request: Success" in result["content"][0]["text"]
assert "✓ use_aws: Success" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 2
assert json_content["batch_summary"]["successful"] == 2
assert json_content["batch_summary"]["failed"] == 0

results = json_content["results"]
assert len(results) == 2
assert results[0]["name"] == "http_request"
Expand All @@ -81,17 +81,17 @@ def test_batch_missing_tool(mock_agent):
assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"]
assert "✗ non_existent_tool: Error" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 1
assert json_content["batch_summary"]["successful"] == 0
assert json_content["batch_summary"]["failed"] == 1

results = json_content["results"]
assert len(results) == 1
assert results[0]["name"] == "non_existent_tool"
Expand All @@ -111,17 +111,17 @@ def test_batch_tool_error(mock_agent):
assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 1 tool(s):" in result["content"][0]["text"]
assert "✗ error_tool: Error" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 1
assert json_content["batch_summary"]["successful"] == 0
assert json_content["batch_summary"]["failed"] == 1

results = json_content["results"]
assert len(results) == 1
assert results[0]["name"] == "error_tool"
Expand All @@ -140,10 +140,10 @@ def test_batch_no_invocations(mock_agent):
assert result["toolUseId"] == "mock_tool_id"
assert result["status"] == "success"
assert len(result["content"]) == 2

# Check the summary text
assert "Batch execution completed with 0 tool(s):" in result["content"][0]["text"]

# Check the JSON results
json_content = result["content"][1]["json"]
assert json_content["batch_summary"]["total_tools"] == 0
Expand Down
11 changes: 11 additions & 0 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,17 @@ def test_start_workflow_not_found(self, mock_parent_agent):
assert result["status"] == "error"
assert "not found" in result["content"][0]["text"]

def test_task_id_namespacing(self):
"""Test task ID namespacing and extraction logic."""
workflow_id = "test_workflow"
task_id = "task1"

namespaced_task_id = f"{workflow_id}:{task_id}"
assert namespaced_task_id == "test_workflow:task1"

extracted_id = namespaced_task_id.split(":", 1)[1] if ":" in namespaced_task_id else namespaced_task_id
assert extracted_id == "task1"


class TestWorkflowStatus:
"""Test workflow status functionality."""
Expand Down